aboutsummaryrefslogtreecommitdiff
path: root/ArmnnPreparedModel_1_2.hpp
blob: 6c630c56f1aa21b4403da0cee54cb9c333668447 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "ArmnnDriver.hpp"
#include "ArmnnDriverImpl.hpp"
#include "RequestThread.hpp"
#include "ModelToINetworkConverter.hpp"

#include <NeuralNetworks.h>
#include <armnn/ArmNN.hpp>

#include <chrono>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace armnn_driver
{

/// Callable type used to report the outcome of a HAL 1.2 execution.
///
/// The callable receives, in order:
///  - errorStatus:     the V1_0::ErrorStatus of the execution
///  - outputShapes:    the V1_2::OutputShape entries, passed by value
///  - timing:          the V1_2::Timing information
///  - callingFunction: a string naming the caller (presumably used for logging - confirm in the .cpp)
using CallbackAsync_1_2 =
    std::function<void(V1_0::ErrorStatus errorStatus,
                       std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes,
                       const ::android::hardware::neuralnetworks::V1_2::Timing& timing,
                       std::string callingFunction)>;

/// Plain data carried alongside the callback for a HAL 1.2 execution
/// (it is the second template argument of CallbackContext_1_2).
struct ExecutionContext_1_2
{
    /// Timing-measurement setting for the execution; defaults to MeasureTiming::NO.
    ::android::hardware::neuralnetworks::V1_2::MeasureTiming measureTimings{
        ::android::hardware::neuralnetworks::V1_2::MeasureTiming::NO};

    /// Time point associated with the start of the driver's work.
    /// NOTE(review): where this is assigned is not visible in this header - confirm in the .cpp.
    TimePoint driverStart;
};

/// CallbackContext specialisation pairing the HAL 1.2 callback type with its execution context.
using CallbackContext_1_2 =
    CallbackContext<CallbackAsync_1_2, ExecutionContext_1_2>;

/// V1_2::IPreparedModel implementation that runs a network previously loaded into an Arm NN runtime.
///
/// The class overrides the HAL entry points execute, execute_1_2, executeSynchronously and
/// configureExecutionBurst, and shares one static RequestThread between all instances of a given
/// HalVersion instantiation.
///
/// \tparam HalVersion HAL policy type; it is also forwarded to the shared RequestThread.
template <typename HalVersion>
class ArmnnPreparedModel_1_2 : public V1_2::IPreparedModel
{
public:
    using HalModel = typename V1_2::Model;

    /// \param networkId                      identifier of the network inside \p runtime
    /// \param runtime                        Arm NN runtime, held as a raw pointer
    ///                                       (NOTE(review): presumably non-owning - confirm in the .cpp)
    /// \param model                          HAL model this prepared model was built from
    /// \param requestInputsAndOutputsDumpDir directory used when dumping request tensors; copied into
    ///                                       the object, so the argument need not outlive it
    /// \param gpuProfilingEnabled            stored in m_GpuProfilingEnabled
    /// \param asyncModelExecutionEnabled     stored in m_AsyncModelExecutionEnabled; defaults to false
    ArmnnPreparedModel_1_2(armnn::NetworkId networkId,
                           armnn::IRuntime* runtime,
                           const HalModel& model,
                           const std::string& requestInputsAndOutputsDumpDir,
                           const bool gpuProfilingEnabled,
                           const bool asyncModelExecutionEnabled = false);

    virtual ~ArmnnPreparedModel_1_2();

    /// HAL 1.0 asynchronous execution entry point; the result is reported through \p callback.
    virtual Return<V1_0::ErrorStatus> execute(const V1_0::Request& request,
                                              const ::android::sp<V1_0::IExecutionCallback>& callback) override;

    /// HAL 1.2 asynchronous execution entry point, with an explicit timing-measurement setting.
    virtual Return<V1_0::ErrorStatus> execute_1_2(const V1_0::Request& request, V1_2::MeasureTiming measure,
                                                  const ::android::sp<V1_2::IExecutionCallback>& callback) override;

    /// HAL 1.2 synchronous execution entry point; the result is handed to \p cb.
    virtual Return<void> executeSynchronously(const V1_0::Request &request,
                                              V1_2::MeasureTiming measure,
                                              V1_2::IPreparedModel::executeSynchronously_cb cb) override;

    /// HAL 1.2 burst-execution configuration entry point over the given request/result channels.
    virtual Return<void> configureExecutionBurst(
            const ::android::sp<V1_2::IBurstCallback>& callback,
            const android::hardware::MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
            const android::hardware::MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
            configureExecutionBurst_cb cb) override;

    /// execute the graph prepared from the request
    /// \return NOTE(review): presumably true on success, mirroring ExecuteWithDummyInputs - confirm
    template<typename CallbackContext>
    bool ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
                      armnn::InputTensors& inputTensors,
                      armnn::OutputTensors& outputTensors,
                      CallbackContext callback);

    /// Executes this model with dummy inputs (e.g. all zeroes).
    /// \return false on failure, otherwise true
    bool ExecuteWithDummyInputs();

private:

    /// armnn::IAsyncExecutionCallback implementation that bundles everything belonging to one
    /// scheduled execution: the owning model, the memory pools, the output shapes, the input/output
    /// tensors and the callback context.
    ///
    /// Only Notify() is defined out of line; the remaining overrides below are fixed stubs.
    template<typename CallbackContext>
    class ArmnnThreadPoolCallback_1_2 : public armnn::IAsyncExecutionCallback
    {
    public:
        /// The shared_ptr arguments are copied (sharing ownership); the by-value arguments
        /// \p outputShapes and \p callbackContext are moved into the members to avoid a second copy.
        ArmnnThreadPoolCallback_1_2(ArmnnPreparedModel_1_2<HalVersion>* model,
                                    std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
                                    std::vector<V1_2::OutputShape> outputShapes,
                                    std::shared_ptr<armnn::InputTensors>& inputTensors,
                                    std::shared_ptr<armnn::OutputTensors>& outputTensors,
                                    CallbackContext callbackContext) :
                m_Model(model),
                m_MemPools(pMemPools),
                m_OutputShapes(std::move(outputShapes)),
                m_InputTensors(inputTensors),
                m_OutputTensors(outputTensors),
                m_CallbackContext(std::move(callbackContext))
        {}

        /// Receives the execution status and the inference timing pair; defined out of line.
        void Notify(armnn::Status status, armnn::InferenceTimingPair timeTaken) override;

        // Stub: always reports armnn::Status::Success, independent of what Notify() received
        virtual armnn::Status GetStatus() const override
        {
            return armnn::Status::Success;
        }

        // Stub: returns immediately without blocking the calling thread
        virtual void Wait() const override
        {}

        // Stub: returns the current time at the moment of the call, not a recorded start time
        virtual armnn::HighResolutionClock GetStartTime() const override
        {
            return std::chrono::high_resolution_clock::now();
        }

        // Stub: returns the current time at the moment of the call, not a recorded end time
        virtual armnn::HighResolutionClock GetEndTime() const override
        {
            return std::chrono::high_resolution_clock::now();
        }

        ArmnnPreparedModel_1_2<HalVersion>* m_Model;
        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
        std::vector<V1_2::OutputShape> m_OutputShapes;
        std::shared_ptr<armnn::InputTensors> m_InputTensors;
        std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
        CallbackContext m_CallbackContext;
    };

    /// Common execution path taking the request, the timing-measurement setting and the callback.
    Return<V1_0::ErrorStatus> Execute(const V1_0::Request& request,
                                      V1_2::MeasureTiming measureTiming,
                                      CallbackAsync_1_2 callback);

    /// Fills \p inputs for \p request using \p memPools.
    Return<V1_0::ErrorStatus> PrepareMemoryForInputs(
            armnn::InputTensors& inputs,
            const V1_0::Request& request,
            const std::vector<android::nn::RunTimePoolInfo>& memPools);

    /// Fills \p outputs and \p outputShapes for \p request using \p memPools.
    Return<V1_0::ErrorStatus> PrepareMemoryForOutputs(
            armnn::OutputTensors& outputs,
            std::vector<V1_2::OutputShape> &outputShapes,
            const V1_0::Request& request,
            const std::vector<android::nn::RunTimePoolInfo>& memPools);

    /// Prepares \p memPools, \p inputs and \p outputs for \p request.
    /// NOTE(review): \p callback is presumably used to report preparation failures - confirm in the .cpp.
    Return <V1_0::ErrorStatus> PrepareMemoryForIO(
            armnn::InputTensors& inputs,
            armnn::OutputTensors& outputs,
            std::vector<android::nn::RunTimePoolInfo>& memPools,
            const V1_0::Request& request,
            CallbackAsync_1_2 callback);

    /// Dumps the given tensor bindings when dumping is requested.
    /// NOTE(review): presumably keyed off m_RequestInputsAndOutputsDumpDir - confirm in the .cpp.
    template <typename TensorBindingCollection>
    void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);

    /// schedule the graph prepared from the request for execution
    template<typename CallbackContext>
    void ScheduleGraphForExecution(
            std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
            std::shared_ptr<armnn::InputTensors>& inputTensors,
            std::shared_ptr<armnn::OutputTensors>& outputTensors,
            CallbackContext callbackContext);

    armnn::NetworkId                                                            m_NetworkId;
    armnn::IRuntime*                                                            m_Runtime;
    V1_2::Model                                                                 m_Model;
    // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
    // It is specific to this class, so it is declared as static here
    static RequestThread<ArmnnPreparedModel_1_2,
                         HalVersion,
                         CallbackContext_1_2>                                   m_RequestThread;
    uint32_t                                                                    m_RequestCount;
    // Held by value (not by reference) so the object never depends on the lifetime of the
    // string that was passed to the constructor.
    const std::string                                                           m_RequestInputsAndOutputsDumpDir;
    const bool                                                                  m_GpuProfilingEnabled;

    std::unique_ptr<IWorkingMemHandle> m_WorkingMemHandle;
    const bool m_AsyncModelExecutionEnabled;
};

}