aboutsummaryrefslogtreecommitdiff
path: root/ArmnnPreparedModel_1_3.hpp
diff options
context:
space:
mode:
authorFinn Williams <Finn.Williams@arm.com>2021-05-19 20:52:00 +0100
committerFinn Williams <Finn.Williams@arm.com>2021-05-26 14:09:49 +0100
commitd8fb540568b29fd1d81a1cca667a1ad3e33ef5a1 (patch)
tree11754db57611c8069bfb3811eedd86b3304917ee /ArmnnPreparedModel_1_3.hpp
parent8920cae4be95ef68295ca458514f0cc257b14f80 (diff)
downloadandroid-nn-driver-d8fb540568b29fd1d81a1cca667a1ad3e33ef5a1.tar.gz
IVGCVSW-5781 Add Async Support to Android-NN-Driver
Signed-off-by: Finn Williams <Finn.Williams@arm.com> Change-Id: I1f13d04100fdb119495b9e3054425bf3babc59f1
Diffstat (limited to 'ArmnnPreparedModel_1_3.hpp')
-rw-r--r--ArmnnPreparedModel_1_3.hpp66
1 files changed, 65 insertions, 1 deletions
diff --git a/ArmnnPreparedModel_1_3.hpp b/ArmnnPreparedModel_1_3.hpp
index c6cdcdc7..11299cc4 100644
--- a/ArmnnPreparedModel_1_3.hpp
+++ b/ArmnnPreparedModel_1_3.hpp
@@ -51,7 +51,8 @@ public:
const HalModel& model,
const std::string& requestInputsAndOutputsDumpDir,
const bool gpuProfilingEnabled,
- V1_3::Priority priority = V1_3::Priority::MEDIUM);
+ V1_3::Priority priority = V1_3::Priority::MEDIUM,
+ const bool asyncModelExecutionEnabled = false);
virtual ~ArmnnPreparedModel_1_3();
@@ -109,6 +110,57 @@ public:
V1_3::Priority GetModelPriority();
private:
+
+ template<typename CallbackContext>
+ class ArmnnThreadPoolCallback_1_3 : public armnn::IAsyncExecutionCallback
+ {
+ public:
+ ArmnnThreadPoolCallback_1_3(ArmnnPreparedModel_1_3<HalVersion>* model,
+ std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+ std::vector<V1_2::OutputShape> outputShapes,
+ std::shared_ptr<armnn::InputTensors>& inputTensors,
+ std::shared_ptr<armnn::OutputTensors>& outputTensors,
+ CallbackContext callbackContext) :
+ m_Model(model),
+ m_MemPools(pMemPools),
+ m_OutputShapes(outputShapes),
+ m_InputTensors(inputTensors),
+ m_OutputTensors(outputTensors),
+ m_CallbackContext(callbackContext)
+ {}
+
+ void Notify(armnn::Status status, armnn::InferenceTimingPair timeTaken) override;
+
+ // Retrieve the Arm NN Status from the AsyncExecutionCallback that has been notified
+ virtual armnn::Status GetStatus() const override
+ {
+ return armnn::Status::Success;
+ }
+
+ // Block the calling thread until the AsyncExecutionCallback object allows it to proceed
+ virtual void Wait() const override
+ {}
+
+ // Retrieve the start time before executing the inference
+ virtual armnn::HighResolutionClock GetStartTime() const override
+ {
+ return std::chrono::high_resolution_clock::now();
+ }
+
+ // Retrieve the time after executing the inference
+ virtual armnn::HighResolutionClock GetEndTime() const override
+ {
+ return std::chrono::high_resolution_clock::now();
+ }
+
+ ArmnnPreparedModel_1_3<HalVersion>* m_Model;
+ std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
+ std::vector<V1_2::OutputShape> m_OutputShapes;
+ std::shared_ptr<armnn::InputTensors> m_InputTensors;
+ std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
+ CallbackContext m_CallbackContext;
+ };
+
Return <V1_3::ErrorStatus> Execute(const V1_3::Request& request,
V1_2::MeasureTiming measureTiming,
CallbackAsync_1_3 callback);
@@ -133,6 +185,15 @@ private:
template <typename TensorBindingCollection>
void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
+ /// schedule the graph prepared from the request for execution
+ template<typename CallbackContext>
+ void ScheduleGraphForExecution(
+ std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+ std::shared_ptr<armnn::InputTensors>& inputTensors,
+ std::shared_ptr<armnn::OutputTensors>& outputTensors,
+ CallbackContext m_CallbackContext,
+ armnn::QosExecPriority priority);
+
armnn::NetworkId m_NetworkId;
armnn::IRuntime* m_Runtime;
V1_3::Model m_Model;
@@ -143,6 +204,9 @@ private:
const std::string& m_RequestInputsAndOutputsDumpDir;
const bool m_GpuProfilingEnabled;
V1_3::Priority m_ModelPriority;
+
+ std::unique_ptr<IWorkingMemHandle> m_WorkingMemHandle;
+ const bool m_AsyncModelExecutionEnabled;
};
}