aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Runtime.hpp
blob: 3c90c51bb296619e70aa1b9cf8424156f195daef (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "LoadedNetwork.hpp"
#include "DeviceSpec.hpp"

#include <armnn/INetwork.hpp>
#include <armnn/IRuntime.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/BackendId.hpp>

#include <armnn/backends/DynamicBackend.hpp>

#include <ProfilingService.hpp>

#include <IProfilingService.hpp>
#include <IReportStructure.hpp>

#include <mutex>
#include <unordered_map>

namespace armnn
{
/// Associative container mapping a NetworkId to the uniquely-owned LoadedNetwork stored under that id.
using LoadedNetworks = std::unordered_map<NetworkId, std::unique_ptr<LoadedNetwork>>;
/// Shorthand that lets profiling::IReportStructure be named without qualification inside namespace armnn.
using IReportStructure = profiling::IReportStructure;

class Runtime final :  public IRuntime,
                       public IReportStructure
{
public:
    /// Loads a complete network into the Runtime.
    /// @param [out] networkIdOut - Unique identifier for the network is returned in this reference.
    /// @param [in] network - Complete network to load into the Runtime.
    /// The runtime takes ownership of the network once passed in.
    /// @return armnn::Status
    virtual Status LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr network) override;

    /// Load a complete network into the IRuntime.
    /// @param [out] networkIdOut Unique identifier for the network is returned in this reference.
    /// @param [in] network Complete network to load into the IRuntime.
    /// @param [out] errorMessage Error message if there were any errors.
    /// The runtime takes ownership of the network once passed in.
    /// @return armnn::Status
    virtual Status LoadNetwork(NetworkId& networkIdOut,
                               IOptimizedNetworkPtr network,
                               std::string& errorMessage) override;

    virtual Status LoadNetwork(NetworkId& networkIdOut,
                               IOptimizedNetworkPtr network,
                               std::string& errorMessage,
                               const INetworkProperties& networkProperties) override;

    virtual TensorInfo GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override;
    virtual TensorInfo GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override;

    // Evaluates network using input in inputTensors, outputs filled into outputTensors.
    virtual Status EnqueueWorkload(NetworkId networkId,
        const InputTensors& inputTensors,
        const OutputTensors& outputTensors) override;

    /// Unloads a network from the Runtime.
    /// At the moment this only removes the network from the m_Impl->m_Network.
    /// This might need more work in the future to be AndroidNN compliant.
    /// @param [in] networkId Unique identifier for the network to be unloaded. Generated in LoadNetwork().
    /// @return armnn::Status
    virtual Status UnloadNetwork(NetworkId networkId) override;

    virtual const IDeviceSpec& GetDeviceSpec() const override { return m_DeviceSpec; }

    /// Gets the profiler corresponding to the given network id.
    /// @param networkId The id of the network for which to get the profile.
    /// @return A pointer to the requested profiler, or nullptr if not found.
    virtual const std::shared_ptr<IProfiler> GetProfiler(NetworkId networkId) const override;

    /// Registers a callback function to debug layers performing custom computations on intermediate tensors.
    /// @param networkId The id of the network to register the callback.
    /// @param func callback function to pass to the debug layer.
    virtual void RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func) override;

    /// Creates a runtime for workload execution.
    Runtime(const CreationOptions& options);

    ~Runtime();

    //NOTE: we won't need the profiling service reference but it is good to pass the service
    // in this way to facilitate other implementations down the road
    virtual void ReportStructure() override;

private:
    friend void RuntimeLoadedNetworksReserve(armnn::Runtime* runtime); // See RuntimeTests.cpp

    friend profiling::ProfilingService& GetProfilingService(armnn::Runtime* runtime); // See RuntimeTests.cpp

    int GenerateNetworkId();

    LoadedNetwork* GetLoadedNetworkPtr(NetworkId networkId) const;

    template<typename Func>
    void LoadedNetworkFuncSafe(NetworkId networkId, Func f)
    {
        std::lock_guard<std::mutex> lockGuard(m_Mutex);
        auto iter = m_LoadedNetworks.find(networkId);
        if (iter != m_LoadedNetworks.end())
        {
            f(iter->second.get());
        }
    }

    /// Loads any available/compatible dynamic backend in the runtime.
    void LoadDynamicBackends(const std::string& overrideBackendPath);

    mutable std::mutex m_Mutex;

    /// Map of Loaded Networks with associated GUID as key
    LoadedNetworks m_LoadedNetworks;

    std::unordered_map<BackendId, IBackendInternal::IBackendContextPtr> m_BackendContexts;

    int m_NetworkIdCounter;

    DeviceSpec m_DeviceSpec;

    /// List of dynamic backends loaded in the runtime
    std::vector<DynamicBackendPtr> m_DynamicBackends;

    /// Profiling Service Instance
    profiling::ProfilingService m_ProfilingService;
};

} // namespace armnn