ArmnnPreparedModel.hpp

//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "ArmnnDriver.hpp"
#include "ArmnnDriverImpl.hpp"
#include "RequestThread.hpp"

#include <NeuralNetworks.h>
#include <armnn/ArmNN.hpp>

#include <chrono>
#include <functional>
#include <memory>
#include <string>
#include <vector>

namespace armnn_driver
{
using armnnExecuteCallback_1_0 = std::function<void(V1_0::ErrorStatus status, std::string callingFunction)>;
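// For illustration, a callback of this type might look like the following
// (ALOGV and the lambda body are assumptions, not part of this header):
//
//   armnnExecuteCallback_1_0 cb = [](V1_0::ErrorStatus status, std::string callingFunction)
//   {
//       ALOGV("%s finished with status %d", callingFunction.c_str(), static_cast<int>(status));
//   };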

struct ArmnnCallback_1_0
{
    armnnExecuteCallback_1_0 callback;
};

struct ExecutionContext_1_0 {};

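// CallbackContext_1_0 bundles the HAL 1.0 execution callback with its (empty)
// execution context; the CallbackContext template itself is defined elsewhere
// in the driver sources.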
using CallbackContext_1_0 = CallbackContext<armnnExecuteCallback_1_0, ExecutionContext_1_0>;

template <typename HalVersion>
class ArmnnPreparedModel : public V1_0::IPreparedModel
{
public:
    using HalModel = typename HalVersion::Model;

    ArmnnPreparedModel(armnn::NetworkId networkId,
                       armnn::IRuntime* runtime,
                       const HalModel& model,
                       const std::string& requestInputsAndOutputsDumpDir,
                       const bool gpuProfilingEnabled,
                       const bool asyncModelExecutionEnabled = false);

    virtual ~ArmnnPreparedModel();

    virtual Return<V1_0::ErrorStatus> execute(const V1_0::Request& request,
                                              const ::android::sp<V1_0::IExecutionCallback>& callback) override;

    /// Execute the graph prepared from the request.
    void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
                      armnn::InputTensors& inputTensors,
                      armnn::OutputTensors& outputTensors,
                      CallbackContext_1_0 callback);

    /// Executes this model with dummy inputs (e.g. all zeroes).
    /// \return false on failure, otherwise true
    bool ExecuteWithDummyInputs();

private:

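    /// Callback given to the Arm NN runtime when asynchronous execution is
    /// enabled. It keeps the request's memory pools and tensors alive for the
    /// duration of the inference; the runtime invokes Notify() once the
    /// scheduled inference completes.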
    template<typename CallbackContext>
    class ArmnnThreadPoolCallback : public armnn::IAsyncExecutionCallback
    {
    public:
        ArmnnThreadPoolCallback(ArmnnPreparedModel<HalVersion>* model,
                                std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
                                std::shared_ptr<armnn::InputTensors>& inputTensors,
                                std::shared_ptr<armnn::OutputTensors>& outputTensors,
                                CallbackContext callbackContext) :
                m_Model(model),
                m_MemPools(pMemPools),
                m_InputTensors(inputTensors),
                m_OutputTensors(outputTensors),
                m_CallbackContext(callbackContext)
        {}

        void Notify(armnn::Status status, armnn::InferenceTimingPair timeTaken) override;

        // Stub implementation: completion status is delivered through Notify(),
        // so this always reports Success.
        virtual armnn::Status GetStatus() const override
        {
            return armnn::Status::Success;
        }

        // Stub implementation: Notify() drives completion, so callers never
        // need to block here.
        virtual void Wait() const override
        {}

        // Stub implementation: returns the current time rather than a recorded
        // inference start time.
        virtual armnn::HighResolutionClock GetStartTime() const override
        {
            return std::chrono::high_resolution_clock::now();
        }

        // Stub implementation: returns the current time rather than a recorded
        // inference end time.
        virtual armnn::HighResolutionClock GetEndTime() const override
        {
            return std::chrono::high_resolution_clock::now();
        }

        ArmnnPreparedModel<HalVersion>* m_Model;
        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
        std::shared_ptr<armnn::InputTensors> m_InputTensors;
        std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
        CallbackContext m_CallbackContext;
    };

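    /// Dump the given collection of input or output tensor bindings to the
    /// configured dump directory (if one was set), prefixing each dump with
    /// tensorNamePrefix.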
    template <typename TensorBindingCollection>
    void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);

    /// Schedule the graph prepared from the request for execution.
    template<typename CallbackContext>
    void ScheduleGraphForExecution(
            std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
            std::shared_ptr<armnn::InputTensors>& inputTensors,
            std::shared_ptr<armnn::OutputTensors>& outputTensors,
            CallbackContext callbackContext);

    armnn::NetworkId                                                        m_NetworkId;
    armnn::IRuntime*                                                        m_Runtime;
    HalModel                                                                m_Model;
    // A single RequestThread is shared by all ArmnnPreparedModel objects to ensure serial execution of workloads.
    // It is specific to this class, so it is declared static here.
    static RequestThread<ArmnnPreparedModel, HalVersion, CallbackContext_1_0> m_RequestThread;
    uint32_t                                                                m_RequestCount;
    const std::string&                                                      m_RequestInputsAndOutputsDumpDir;
    const bool                                                              m_GpuProfilingEnabled;

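    // Only used when asynchronous model execution is enabled: holds the
    // working memory the Arm NN runtime requires for its asynchronous
    // execution path.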
    std::unique_ptr<armnn::IWorkingMemHandle> m_WorkingMemHandle;
    const bool m_AsyncModelExecutionEnabled;
};
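
// A minimal usage sketch, illustrative only: the driver normally constructs
// prepared models inside its prepareModel implementation, and the HalPolicy
// type, variable names, and the pre-loaded 'networkId' below are assumptions
// rather than anything this header mandates.
//
//   // 'runtime' is an armnn::IRuntime* into which a converted network has
//   // already been loaded, yielding 'networkId'; 'model' is the HAL model
//   // it was converted from.
//   android::sp<ArmnnPreparedModel<hal_1_0::HalPolicy>> preparedModel =
//       new ArmnnPreparedModel<hal_1_0::HalPolicy>(networkId,
//                                                  runtime,
//                                                  model,
//                                                  /*requestInputsAndOutputsDumpDir=*/"",
//                                                  /*gpuProfilingEnabled=*/false);
//
//   // Sanity-check the network with zeroed inputs before serving requests.
//   if (!preparedModel->ExecuteWithDummyInputs())
//   {
//       // handle failure
//   }
//
//   // Serve a HAL request; 'request' and 'halCallback' arrive from the NNAPI
//   // runtime through the V1_0::IPreparedModel interface.
//   preparedModel->execute(request, halCallback);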

} // namespace armnn_driver