aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Runtime.hpp
blob: f5dfadf948fcd7ed678e66dd46ffd8ad2daedb41 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "LoadedNetwork.hpp"
#include "DeviceSpec.hpp"

#include <armnn/INetwork.hpp>
#include <armnn/IRuntime.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/BackendId.hpp>

#include <armnn/backends/DynamicBackend.hpp>

#include <client/include/IInitialiseProfilingService.hpp>
#include <client/include/IProfilingService.hpp>
#include <client/include/IReportStructure.hpp>

#include <mutex>
#include <unordered_map>

namespace armnn
{
/// Container type for networks held by the runtime: maps a NetworkId to the LoadedNetwork it owns.
using LoadedNetworks = std::unordered_map<NetworkId, std::unique_ptr<LoadedNetwork>>;

/// Profiling-client interfaces from arm::pipe, re-exported under their unqualified names in armnn.
using IReportStructure            = arm::pipe::IReportStructure;
using IInitialiseProfilingService = arm::pipe::IInitialiseProfilingService;

/// Runtime implementation object.
///
/// Holds the set of loaded networks (keyed by NetworkId), the per-backend contexts, any dynamic
/// backends and a profiling service instance, and exposes the operations to load, query, execute
/// and unload networks. It also implements the arm::pipe IReportStructure and
/// IInitialiseProfilingService interfaces.
struct RuntimeImpl final :  public IReportStructure, public IInitialiseProfilingService
{
public:
    /// Loads a complete network into the Runtime.
    /// @param [out] networkIdOut - Unique identifier for the network is returned in this reference.
    /// @param [in] network - Complete network to load into the Runtime.
    /// The runtime takes ownership of the network once passed in.
    /// @return armnn::Status
    Status LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr network);

    /// Load a complete network into the IRuntime.
    /// @param [out] networkIdOut Unique identifier for the network is returned in this reference.
    /// @param [in] network Complete network to load into the IRuntime.
    /// @param [out] errorMessage Error message if there were any errors.
    /// The runtime takes ownership of the network once passed in.
    /// @return armnn::Status
    Status LoadNetwork(NetworkId& networkIdOut,
                       IOptimizedNetworkPtr network,
                       std::string& errorMessage);

    /// Load a complete network into the IRuntime, supplying explicit network properties.
    /// @param [out] networkIdOut Unique identifier for the network is returned in this reference.
    /// @param [in] network Complete network to load into the IRuntime.
    /// @param [out] errorMessage Error message if there were any errors.
    /// @param [in] networkProperties Properties to associate with the network being loaded.
    /// The runtime takes ownership of the network once passed in.
    /// @return armnn::Status
    Status LoadNetwork(NetworkId& networkIdOut,
                       IOptimizedNetworkPtr network,
                       std::string& errorMessage,
                       const INetworkProperties& networkProperties);

    /// Gets the tensor info of the input/output identified by layerId in the network networkId.
    /// NOTE(review): behaviour for an unknown networkId or layerId is defined in Runtime.cpp - confirm there.
    armnn::TensorInfo GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const;
    armnn::TensorInfo GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const;

    /// Imports the given input/output tensors for the network networkId and returns their ids.
    /// Presumably the returned ids are the ones accepted as pre-imported ids by EnqueueWorkload()/Execute()
    /// and released by ClearImportedInputs()/ClearImportedOutputs() - verify against Runtime.cpp.
    std::vector<ImportedInputId> ImportInputs(NetworkId networkId, const InputTensors& inputTensors,
                                              MemorySource forceImportMemorySource);
    std::vector<ImportedOutputId> ImportOutputs(NetworkId networkId, const OutputTensors& outputTensors,
                                                MemorySource forceImportMemorySource);

    /// Clears the imported inputs/outputs with the given ids for the network networkId.
    /// NOTE(review): the id vectors are taken by value, so each call copies them. Switching to a
    /// const reference would have to be mirrored in the out-of-line definition and is left as-is here.
    void ClearImportedInputs(NetworkId networkId, const std::vector<ImportedInputId> inputIds);
    void ClearImportedOutputs(NetworkId networkId, const std::vector<ImportedOutputId> outputIds);

    /// Evaluates network using input in inputTensors, outputs filled into outputTensors.
    /// preImportedInputIds and preImportedOutputIds default to empty.
    Status EnqueueWorkload(NetworkId networkId,
                           const InputTensors& inputTensors,
                           const OutputTensors& outputTensors,
                           std::vector<ImportedInputId> preImportedInputIds = {},
                           std::vector<ImportedOutputId> preImportedOutputIds = {});

    /// This is an experimental function.
    /// Evaluates a network using input in inputTensors and outputs filled into outputTensors.
    /// This function performs a thread safe execution of the network. Returns once execution is complete.
    /// Will block until this and any other thread using the same workingMem object completes.
    Status Execute(IWorkingMemHandle& workingMemHandle,
                   const InputTensors& inputTensors,
                   const OutputTensors& outputTensors,
                   std::vector<ImportedInputId> preImportedInputs,
                   std::vector<ImportedOutputId> preImportedOutputs);

    /// Unloads a network from the Runtime.
    /// At the moment this only removes the network from the runtime's collection of loaded networks.
    /// NOTE(review): the previous wording named "m_Impl->m_Network", which is not a member of this
    /// struct; presumably m_LoadedNetworks is meant - confirm against Runtime.cpp.
    /// This might need more work in the future to be AndroidNN compliant.
    /// @param [in] networkId Unique identifier for the network to be unloaded. Generated in LoadNetwork().
    /// @return armnn::Status
    Status UnloadNetwork(NetworkId networkId);

    /// Returns the device specification held by this runtime.
    const IDeviceSpec& GetDeviceSpec() const { return m_DeviceSpec; }

    /// Gets the profiler corresponding to the given network id.
    /// @param networkId The id of the network for which to get the profile.
    /// @return A pointer to the requested profiler, or nullptr if not found.
    /// NOTE(review): the top-level const on the by-value return type has no effect for callers other
    /// than preventing a move from the temporary; kept to match the out-of-line definition.
    const std::shared_ptr<IProfiler> GetProfiler(NetworkId networkId) const;

    /// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have
    /// overlapped Execution by calling this function from different threads.
    std::unique_ptr<IWorkingMemHandle> CreateWorkingMemHandle(NetworkId networkId);

    /// Registers a callback function to debug layers performing custom computations on intermediate tensors.
    /// @param networkId The id of the network to register the callback.
    /// @param func callback function to pass to the debug layer.
    void RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func);

    /// Creates a runtime for workload execution.
    /// Marked explicit so that an IRuntime::CreationOptions value can never be implicitly converted
    /// into an entire runtime; construct with RuntimeImpl(options) / new RuntimeImpl(options).
    explicit RuntimeImpl(const IRuntime::CreationOptions& options);

    ~RuntimeImpl();

    /// IReportStructure implementation.
    /// NOTE: the profiling service reference is not needed here, but passing the service in this
    /// way facilitates other implementations down the road.
    void ReportStructure(arm::pipe::IProfilingService& profilingService) override;

    /// IInitialiseProfilingService implementation.
    void InitialiseProfilingService(arm::pipe::IProfilingService& profilingService) override;

private:
    friend void RuntimeLoadedNetworksReserve(RuntimeImpl* runtime); // See RuntimeTests.cpp

    friend arm::pipe::IProfilingService& GetProfilingService(RuntimeImpl* runtime); // See RuntimeTests.cpp

    /// Produces the id for a newly loaded network (presumably derived from m_NetworkIdCounter).
    int GenerateNetworkId();

    /// Returns the LoadedNetwork registered under networkId.
    /// NOTE(review): locking and not-found behaviour are defined in Runtime.cpp - confirm there.
    LoadedNetwork* GetLoadedNetworkPtr(NetworkId networkId) const;

    /// Invokes f on the loaded network registered under networkId, if there is one.
    ///
    /// Unless ARMNN_DISABLE_THREADS is defined, m_Mutex is held for the lookup and for the whole
    /// call to f. std::mutex is not recursive, so f must not re-enter code that locks m_Mutex.
    /// If networkId is not present in m_LoadedNetworks, f is not called and nothing happens.
    /// @param networkId Id of the network to operate on.
    /// @param f Callable taking a LoadedNetwork* (never null when invoked from here).
    template<typename Func>
    void LoadedNetworkFuncSafe(NetworkId networkId, Func f)
    {
#if !defined(ARMNN_DISABLE_THREADS)
        std::lock_guard<std::mutex> lockGuard(m_Mutex);
#endif
        auto iter = m_LoadedNetworks.find(networkId);
        if (iter != m_LoadedNetworks.end())
        {
            f(iter->second.get());
        }
    }

    /// Loads any available/compatible dynamic backend in the runtime.
    void LoadDynamicBackends(const std::string& overrideBackendPath);

#if !defined(ARMNN_DISABLE_THREADS)
    /// Mutex taken by LoadedNetworkFuncSafe() around access to m_LoadedNetworks.
    /// Declared mutable so that it can also be locked from const member functions.
    mutable std::mutex m_Mutex;
#endif

    /// Map of Loaded Networks with associated GUID as key
    LoadedNetworks m_LoadedNetworks;

    /// Backend contexts keyed by backend id.
    std::unordered_map<BackendId, IBackendInternal::IBackendContextPtr> m_BackendContexts;

    /// Counter for network ids (presumably the state behind GenerateNetworkId() - verify in Runtime.cpp).
    int m_NetworkIdCounter;

    /// Device specification returned by GetDeviceSpec().
    DeviceSpec m_DeviceSpec;

    /// List of dynamic backends loaded in the runtime
    std::vector<DynamicBackendPtr> m_DynamicBackends;

    /// Profiling Service Instance
    std::unique_ptr<arm::pipe::IProfilingService> m_ProfilingService;
};

} // namespace armnn