about | summary | refs | log | tree | commit | diff
path: root/shim/sl/canonical/ArmnnDriverImpl.hpp
blob: 836bf469ccaa26bcfa6d9ddac49cea60155ca31e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "DriverOptions.hpp"

#include <armnn/ArmNN.hpp>

#include <nnapi/IPreparedModel.h>
#include <nnapi/Result.h>
#include <nnapi/TypeUtils.h>
#include <nnapi/Types.h>
#include <nnapi/Validation.h>

using namespace android::nn;

namespace armnn_driver
{

class ArmnnDriverImpl
{
public:
    static GeneralResult<SharedPreparedModel> PrepareArmnnModel(
        const armnn::IRuntimePtr& runtime,
        const armnn::IGpuAccTunedParametersPtr& clTunedParameters,
        const DriverOptions& options,
        const Model& model,
        const std::vector<SharedHandle>& modelCacheHandle,
        const std::vector<SharedHandle>& dataCacheHandle,
        const CacheToken& token,
        bool float32ToFloat16 = false,
        Priority priority = Priority::MEDIUM);

    static GeneralResult<SharedPreparedModel> PrepareArmnnModelFromCache(
        const armnn::IRuntimePtr& runtime,
        const armnn::IGpuAccTunedParametersPtr& clTunedParameters,
        const DriverOptions& options,
        const std::vector<SharedHandle>& modelCacheHandle,
        const std::vector<SharedHandle>& dataCacheHandle,
        const CacheToken& token,
        bool float32ToFloat16 = false);

    static const Capabilities& GetCapabilities(const armnn::IRuntimePtr& runtime);

    static std::vector<armnn::NetworkId>& GetLoadedNetworks();

    static void ClearNetworks();

private:
    static bool ValidateSharedHandle(const SharedHandle& sharedHandle);
    static bool ValidateDataCacheHandle(const std::vector<SharedHandle>& dataCacheHandle, const size_t dataSize);

    static std::vector<armnn::NetworkId> m_NetworkIDs;
};

} // namespace armnn_driver