From 4de83c5a6a57d0468d9f2f854c94bc4a760b66b6 Mon Sep 17 00:00:00 2001 From: Derek Lamberti Date: Tue, 17 Mar 2020 13:40:18 +0000 Subject: Less code duplication in HAL 1.2 Signed-off-by: Derek Lamberti Change-Id: Ic2e8964745a4323efb1e06d466c0699f17a70c55 --- ArmnnPreparedModel_1_2.hpp | 54 +++++++++++++++++++++++++++++++++------------- 1 file changed, 39 insertions(+), 15 deletions(-) (limited to 'ArmnnPreparedModel_1_2.hpp') diff --git a/ArmnnPreparedModel_1_2.hpp b/ArmnnPreparedModel_1_2.hpp index f609ef7e..e68614a0 100644 --- a/ArmnnPreparedModel_1_2.hpp +++ b/ArmnnPreparedModel_1_2.hpp @@ -19,18 +19,21 @@ namespace armnn_driver { -typedef std::function outputShapes, - const ::android::hardware::neuralnetworks::V1_2::Timing& timing, - std::string callingFunction)> armnnExecuteCallback_1_2; +using CallbackAsync_1_2 = std::function< + void(V1_0::ErrorStatus errorStatus, + std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes, + const ::android::hardware::neuralnetworks::V1_2::Timing& timing, + std::string callingFunction)>; -struct ArmnnCallback_1_2 +struct ExecutionContext_1_2 { - armnnExecuteCallback_1_2 callback; + ::android::hardware::neuralnetworks::V1_2::MeasureTiming measureTimings = + ::android::hardware::neuralnetworks::V1_2::MeasureTiming::NO; TimePoint driverStart; - MeasureTiming measureTiming; }; +using CallbackContext_1_2 = CallbackContext; + template class ArmnnPreparedModel_1_2 : public V1_2::IPreparedModel { @@ -62,19 +65,38 @@ public: configureExecutionBurst_cb cb) override; /// execute the graph prepared from the request - void ExecuteGraph(std::shared_ptr>& pMemPools, - std::shared_ptr& pInputTensors, - std::shared_ptr& pOutputTensors, - ArmnnCallback_1_2 callbackDescriptor); + template + bool ExecuteGraph(std::shared_ptr>& pMemPools, + armnn::InputTensors& inputTensors, + armnn::OutputTensors& outputTensors, + CallbackContext callback); /// Executes this model with dummy inputs (e.g. all zeroes). /// \return false on failure, otherwise true bool ExecuteWithDummyInputs(); private: - Return Execute(const V1_0::Request& request, - MeasureTiming measureTiming, - armnnExecuteCallback_1_2 callback); + Return Execute(const V1_0::Request& request, + MeasureTiming measureTiming, + CallbackAsync_1_2 callback); + + Return PrepareMemoryForInputs( + armnn::InputTensors& inputs, + const V1_0::Request& request, + const std::vector& memPools); + + Return PrepareMemoryForOutputs( + armnn::OutputTensors& outputs, + std::vector &outputShapes, + const V1_0::Request& request, + const std::vector& memPools); + + Return PrepareMemoryForIO( + armnn::InputTensors& inputs, + armnn::OutputTensors& outputs, + std::vector& memPools, + const V1_0::Request& request, + CallbackAsync_1_2 callback); template void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings); @@ -84,7 +106,9 @@ private: V1_2::Model m_Model; // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads // It is specific to this class, so it is declared as static here - static RequestThread m_RequestThread; + static RequestThread m_RequestThread; uint32_t m_RequestCount; const std::string& m_RequestInputsAndOutputsDumpDir; const bool m_GpuProfilingEnabled; -- cgit v1.2.1