diff options
Diffstat (limited to 'ArmnnPreparedModel_1_2.hpp')
-rw-r--r-- | ArmnnPreparedModel_1_2.hpp | 54 |
1 files changed, 39 insertions, 15 deletions
diff --git a/ArmnnPreparedModel_1_2.hpp b/ArmnnPreparedModel_1_2.hpp index f609ef7e..e68614a0 100644 --- a/ArmnnPreparedModel_1_2.hpp +++ b/ArmnnPreparedModel_1_2.hpp @@ -19,18 +19,21 @@ namespace armnn_driver { -typedef std::function<void(::android::hardware::neuralnetworks::V1_0::ErrorStatus status, - std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes, - const ::android::hardware::neuralnetworks::V1_2::Timing& timing, - std::string callingFunction)> armnnExecuteCallback_1_2; +using CallbackAsync_1_2 = std::function< + void(V1_0::ErrorStatus errorStatus, + std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes, + const ::android::hardware::neuralnetworks::V1_2::Timing& timing, + std::string callingFunction)>; -struct ArmnnCallback_1_2 +struct ExecutionContext_1_2 { - armnnExecuteCallback_1_2 callback; + ::android::hardware::neuralnetworks::V1_2::MeasureTiming measureTimings = + ::android::hardware::neuralnetworks::V1_2::MeasureTiming::NO; TimePoint driverStart; - MeasureTiming measureTiming; }; +using CallbackContext_1_2 = CallbackContext<CallbackAsync_1_2, ExecutionContext_1_2>; + template <typename HalVersion> class ArmnnPreparedModel_1_2 : public V1_2::IPreparedModel { @@ -62,19 +65,38 @@ public: configureExecutionBurst_cb cb) override; /// execute the graph prepared from the request - void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, - std::shared_ptr<armnn::InputTensors>& pInputTensors, - std::shared_ptr<armnn::OutputTensors>& pOutputTensors, - ArmnnCallback_1_2 callbackDescriptor); + template<typename CallbackContext> + bool ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, + armnn::InputTensors& inputTensors, + armnn::OutputTensors& outputTensors, + CallbackContext callback); /// Executes this model with dummy inputs (e.g. all zeroes). /// \return false on failure, otherwise true bool ExecuteWithDummyInputs(); private: - Return <V1_0::ErrorStatus> Execute(const V1_0::Request& request, - MeasureTiming measureTiming, - armnnExecuteCallback_1_2 callback); + Return<V1_0::ErrorStatus> Execute(const V1_0::Request& request, + MeasureTiming measureTiming, + CallbackAsync_1_2 callback); + + Return<V1_0::ErrorStatus> PrepareMemoryForInputs( + armnn::InputTensors& inputs, + const V1_0::Request& request, + const std::vector<android::nn::RunTimePoolInfo>& memPools); + + Return<V1_0::ErrorStatus> PrepareMemoryForOutputs( + armnn::OutputTensors& outputs, + std::vector<OutputShape> &outputShapes, + const V1_0::Request& request, + const std::vector<android::nn::RunTimePoolInfo>& memPools); + + Return <V1_0::ErrorStatus> PrepareMemoryForIO( + armnn::InputTensors& inputs, + armnn::OutputTensors& outputs, + std::vector<android::nn::RunTimePoolInfo>& memPools, + const V1_0::Request& request, + CallbackAsync_1_2 callback); template <typename TensorBindingCollection> void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings); @@ -84,7 +106,9 @@ private: V1_2::Model m_Model; // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads // It is specific to this class, so it is declared as static here - static RequestThread<ArmnnPreparedModel_1_2, HalVersion, ArmnnCallback_1_2> m_RequestThread; + static RequestThread<ArmnnPreparedModel_1_2, + HalVersion, + CallbackContext_1_2> m_RequestThread; uint32_t m_RequestCount; const std::string& m_RequestInputsAndOutputsDumpDir; const bool m_GpuProfilingEnabled; |