aboutsummaryrefslogtreecommitdiff
path: root/ArmnnPreparedModel_1_2.hpp
blob: 57deb98ca12c80873a83ef5c718668d7cbb2c00f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "ArmnnDriver.hpp"
#include "ArmnnDriverImpl.hpp"
#include "RequestThread.hpp"
#include "ModelToINetworkConverter.hpp"

#include <NeuralNetworks.h>
#include <armnn/ArmNN.hpp>
#include <armnn/Threadpool.hpp>

#include <string>
#include <vector>

namespace armnn_driver
{

/// Completion callback invoked when an asynchronous execution finishes.
/// Receives the HAL error status, the (possibly updated) output shapes,
/// the measured timing, and the name of the calling function (used for logging).
using CallbackAsync_1_2 = std::function<
                                void(V1_0::ErrorStatus errorStatus,
                                     std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes,
                                     const ::android::hardware::neuralnetworks::V1_2::Timing& timing,
                                     std::string callingFunction)>;

/// Per-execution state threaded through an inference request.
struct ExecutionContext_1_2
{
    // Whether the caller asked for execution timing to be measured (defaults to NO).
    ::android::hardware::neuralnetworks::V1_2::MeasureTiming    measureTimings =
        ::android::hardware::neuralnetworks::V1_2::MeasureTiming::NO;
    // Timestamp taken when the driver started handling the request; used to
    // compute driver-side durations when measureTimings is YES.
    TimePoint driverStart;
};

using CallbackContext_1_2 = CallbackContext<CallbackAsync_1_2, ExecutionContext_1_2>;

/// Implementation of the NN HAL V1_2::IPreparedModel interface backed by an
/// Arm NN network that has already been loaded into an armnn::IRuntime.
/// Supports synchronous, callback-based asynchronous, and (when
/// m_AsyncModelExecutionEnabled is set) threadpool-scheduled execution.
template <typename HalVersion>
class ArmnnPreparedModel_1_2 : public V1_2::IPreparedModel
{
public:
    using HalModel = typename V1_2::Model;

    /// Construct from a freshly converted model; retains a copy of the HAL model
    /// (e.g. for input/output dumping — see m_RequestInputsAndOutputsDumpDir).
    /// \param networkId    id of the network already loaded into \p runtime
    /// \param runtime      non-owning pointer to the Arm NN runtime
    /// \param numberOfThreads  threadpool size when asyncModelExecutionEnabled is true
    ArmnnPreparedModel_1_2(armnn::NetworkId networkId,
                           armnn::IRuntime* runtime,
                           const HalModel& model,
                           const std::string& requestInputsAndOutputsDumpDir,
                           const bool gpuProfilingEnabled,
                           const bool asyncModelExecutionEnabled = false,
                           const unsigned int numberOfThreads = 1,
                           const bool importEnabled = false,
                           const bool exportEnabled = false);

    /// Construct without the original HAL model — used when the prepared model
    /// is restored from cache (see m_PreparedFromCache).
    ArmnnPreparedModel_1_2(armnn::NetworkId networkId,
                           armnn::IRuntime* runtime,
                           const std::string& requestInputsAndOutputsDumpDir,
                           const bool gpuProfilingEnabled,
                           const bool asyncModelExecutionEnabled = false,
                           const unsigned int numberOfThreads = 1,
                           const bool importEnabled = false,
                           const bool exportEnabled = false,
                           const bool preparedFromCache = false);

    virtual ~ArmnnPreparedModel_1_2();

    /// V1_0 asynchronous execute entry point (result delivered via \p callback).
    virtual Return<V1_0::ErrorStatus> execute(const V1_0::Request& request,
                                              const ::android::sp<V1_0::IExecutionCallback>& callback) override;

    /// V1_2 asynchronous execute entry point with optional timing measurement.
    virtual Return<V1_0::ErrorStatus> execute_1_2(const V1_0::Request& request, V1_2::MeasureTiming measure,
                                                  const ::android::sp<V1_2::IExecutionCallback>& callback) override;

    /// Blocking execution; results (status, output shapes, timing) are passed to \p cb.
    virtual Return<void> executeSynchronously(const V1_0::Request &request,
                                              V1_2::MeasureTiming measure,
                                              V1_2::IPreparedModel::executeSynchronously_cb cb) override;

    /// Sets up the fast-message-queue burst execution path required by the
    /// V1_2 HAL (request/result FMQ channel descriptors supplied by the client).
    virtual Return<void> configureExecutionBurst(
            const ::android::sp<V1_2::IBurstCallback>& callback,
            const android::hardware::MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
            const android::hardware::MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
            configureExecutionBurst_cb cb) override;

    /// execute the graph prepared from the request
    /// \param pMemPools shared memory pools backing the input/output tensors;
    ///        kept alive for the duration of the execution
    /// \return presumably false on failure, true otherwise — confirm against the .cpp
    template<typename CallbackContext>
    bool ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
                      armnn::InputTensors& inputTensors,
                      armnn::OutputTensors& outputTensors,
                      CallbackContext callback);

    /// Executes this model with dummy inputs (e.g. all zeroes).
    /// \return false on failure, otherwise true
    bool ExecuteWithDummyInputs(unsigned int numInputs, unsigned int numOutputs);

private:

    /// Callback object handed to the Arm NN threadpool for async execution;
    /// owns (via shared_ptr) the memory pools and tensors so they outlive the
    /// scheduled inference, and forwards completion through m_CallbackContext.
    template<typename CallbackContext>
    class ArmnnThreadPoolCallback_1_2 : public armnn::IAsyncExecutionCallback
    {
    public:
        ArmnnThreadPoolCallback_1_2(ArmnnPreparedModel_1_2<HalVersion>* model,
                                    std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
                                    std::vector<V1_2::OutputShape> outputShapes,
                                    std::shared_ptr<armnn::InputTensors>& inputTensors,
                                    std::shared_ptr<armnn::OutputTensors>& outputTensors,
                                    CallbackContext callbackContext) :
                m_Model(model),
                m_MemPools(pMemPools),
                m_OutputShapes(outputShapes),
                m_InputTensors(inputTensors),
                m_OutputTensors(outputTensors),
                m_CallbackContext(callbackContext)
        {}

        /// Invoked by the Arm NN threadpool when the inference completes.
        void Notify(armnn::Status status, armnn::InferenceTimingPair timeTaken) override;

        ArmnnPreparedModel_1_2<HalVersion>* m_Model;    // non-owning back-pointer
        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
        std::vector<V1_2::OutputShape> m_OutputShapes;
        std::shared_ptr<armnn::InputTensors> m_InputTensors;
        std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
        CallbackContext m_CallbackContext;
    };

    /// Common implementation behind execute/execute_1_2.
    Return<V1_0::ErrorStatus> Execute(const V1_0::Request& request,
                                      V1_2::MeasureTiming measureTiming,
                                      CallbackAsync_1_2 callback);

    /// Binds the request's input pools to Arm NN input tensors.
    Return<V1_0::ErrorStatus> PrepareMemoryForInputs(
            armnn::InputTensors& inputs,
            const V1_0::Request& request,
            const std::vector<android::nn::RunTimePoolInfo>& memPools);

    /// Binds the request's output pools to Arm NN output tensors and records
    /// the resulting output shapes in \p outputShapes.
    Return<V1_0::ErrorStatus> PrepareMemoryForOutputs(
            armnn::OutputTensors& outputs,
            std::vector<V1_2::OutputShape> &outputShapes,
            const V1_0::Request& request,
            const std::vector<android::nn::RunTimePoolInfo>& memPools);

    /// Prepares both inputs and outputs; on failure reports the error through
    /// \p callback rather than (only) the return value.
    Return <V1_0::ErrorStatus> PrepareMemoryForIO(
            armnn::InputTensors& inputs,
            armnn::OutputTensors& outputs,
            std::vector<android::nn::RunTimePoolInfo>& memPools,
            const V1_0::Request& request,
            CallbackAsync_1_2 callback);

    /// Dumps tensor contents for debugging when a dump directory is configured.
    template <typename TensorBindingCollection>
    void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);

    /// schedule the graph prepared from the request for execution
    /// (async threadpool path; tensors/pools are passed as shared_ptr so the
    /// ArmnnThreadPoolCallback_1_2 can keep them alive)
    template<typename CallbackContext>
    void ScheduleGraphForExecution(
            std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
            std::shared_ptr<armnn::InputTensors>& inputTensors,
            std::shared_ptr<armnn::OutputTensors>& outputTensors,
            CallbackContext m_CallbackContext);

    armnn::NetworkId                          m_NetworkId;          // id of the loaded network in m_Runtime
    armnn::IRuntime*                          m_Runtime;            // non-owning; outlives this object
    V1_2::Model                               m_Model;              // copy of the source HAL model (empty when prepared from cache)
    // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
    // It is specific to this class, so it is declared as static here
    static RequestThread<ArmnnPreparedModel_1_2,
                         HalVersion,
                         CallbackContext_1_2> m_RequestThread;
    uint32_t                                  m_RequestCount;       // number of requests seen by this instance (logging/dump naming)
    const std::string&                        m_RequestInputsAndOutputsDumpDir; // empty => no dumping; NOTE(review): reference member — caller-owned string must outlive this object
    const bool                                m_GpuProfilingEnabled;
    // Static to allow sharing of threadpool between ArmnnPreparedModel instances
    static std::unique_ptr<armnn::Threadpool> m_Threadpool;
    // NOTE(review): IWorkingMemHandle is unqualified here (other armnn types are
    // fully qualified) — presumably armnn::IWorkingMemHandle; confirm a
    // using-declaration exists in an included header.
    std::shared_ptr<IWorkingMemHandle>        m_WorkingMemHandle;   // per-instance working memory for async execution
    const bool                                m_AsyncModelExecutionEnabled;
    const bool                                m_EnableImport;       // allow importing of input buffers
    const bool                                m_EnableExport;       // allow exporting of output buffers
    const bool                                m_PreparedFromCache;  // true when built via the cache-restore constructor
};

}