diff options
author | Finn Williams <Finn.Williams@arm.com> | 2021-05-19 20:52:00 +0100 |
---|---|---|
committer | Finn Williams <Finn.Williams@arm.com> | 2021-05-26 14:09:49 +0100 |
commit | d8fb540568b29fd1d81a1cca667a1ad3e33ef5a1 (patch) | |
tree | 11754db57611c8069bfb3811eedd86b3304917ee /ArmnnPreparedModel_1_2.hpp | |
parent | 8920cae4be95ef68295ca458514f0cc257b14f80 (diff) | |
download | android-nn-driver-d8fb540568b29fd1d81a1cca667a1ad3e33ef5a1.tar.gz |
IVGCVSW-5781 Add Async Support to Android-NN-Driver
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: I1f13d04100fdb119495b9e3054425bf3babc59f1
Diffstat (limited to 'ArmnnPreparedModel_1_2.hpp')
-rw-r--r-- | ArmnnPreparedModel_1_2.hpp | 65 |
1 file changed, 64 insertions, 1 deletion
diff --git a/ArmnnPreparedModel_1_2.hpp b/ArmnnPreparedModel_1_2.hpp index 13d7494e..6c630c56 100644 --- a/ArmnnPreparedModel_1_2.hpp +++ b/ArmnnPreparedModel_1_2.hpp @@ -44,7 +44,8 @@ public: armnn::IRuntime* runtime, const HalModel& model, const std::string& requestInputsAndOutputsDumpDir, - const bool gpuProfilingEnabled); + const bool gpuProfilingEnabled, + const bool asyncModelExecutionEnabled = false); virtual ~ArmnnPreparedModel_1_2(); @@ -76,6 +77,57 @@ public: bool ExecuteWithDummyInputs(); private: + + template<typename CallbackContext> + class ArmnnThreadPoolCallback_1_2 : public armnn::IAsyncExecutionCallback + { + public: + ArmnnThreadPoolCallback_1_2(ArmnnPreparedModel_1_2<HalVersion>* model, + std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, + std::vector<V1_2::OutputShape> outputShapes, + std::shared_ptr<armnn::InputTensors>& inputTensors, + std::shared_ptr<armnn::OutputTensors>& outputTensors, + CallbackContext callbackContext) : + m_Model(model), + m_MemPools(pMemPools), + m_OutputShapes(outputShapes), + m_InputTensors(inputTensors), + m_OutputTensors(outputTensors), + m_CallbackContext(callbackContext) + {} + + void Notify(armnn::Status status, armnn::InferenceTimingPair timeTaken) override; + + // Retrieve the Arm NN Status from the AsyncExecutionCallback that has been notified + virtual armnn::Status GetStatus() const override + { + return armnn::Status::Success; + } + + // Block the calling thread until the AsyncExecutionCallback object allows it to proceed + virtual void Wait() const override + {} + + // Retrieve the start time before executing the inference + virtual armnn::HighResolutionClock GetStartTime() const override + { + return std::chrono::high_resolution_clock::now(); + } + + // Retrieve the time after executing the inference + virtual armnn::HighResolutionClock GetEndTime() const override + { + return std::chrono::high_resolution_clock::now(); + } + + ArmnnPreparedModel_1_2<HalVersion>* m_Model; + std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools; + std::vector<V1_2::OutputShape> m_OutputShapes; + std::shared_ptr<armnn::InputTensors> m_InputTensors; + std::shared_ptr<armnn::OutputTensors> m_OutputTensors; + CallbackContext m_CallbackContext; + }; + Return<V1_0::ErrorStatus> Execute(const V1_0::Request& request, V1_2::MeasureTiming measureTiming, CallbackAsync_1_2 callback); @@ -101,6 +153,14 @@ private: template <typename TensorBindingCollection> void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings); + /// schedule the graph prepared from the request for execution + template<typename CallbackContext> + void ScheduleGraphForExecution( + std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, + std::shared_ptr<armnn::InputTensors>& inputTensors, + std::shared_ptr<armnn::OutputTensors>& outputTensors, + CallbackContext m_CallbackContext); + armnn::NetworkId m_NetworkId; armnn::IRuntime* m_Runtime; V1_2::Model m_Model; @@ -112,6 +172,9 @@ private: uint32_t m_RequestCount; const std::string& m_RequestInputsAndOutputsDumpDir; const bool m_GpuProfilingEnabled; + + std::unique_ptr<IWorkingMemHandle> m_WorkingMemHandle; + const bool m_AsyncModelExecutionEnabled; }; } |