aboutsummaryrefslogtreecommitdiff
path: root/ArmnnPreparedModel_1_2.cpp
diff options
context:
space:
mode:
authorFinn Williams <Finn.Williams@arm.com>2021-05-19 20:52:00 +0100
committerFinn Williams <Finn.Williams@arm.com>2021-05-26 14:09:49 +0100
commitd8fb540568b29fd1d81a1cca667a1ad3e33ef5a1 (patch)
tree11754db57611c8069bfb3811eedd86b3304917ee /ArmnnPreparedModel_1_2.cpp
parent8920cae4be95ef68295ca458514f0cc257b14f80 (diff)
downloadandroid-nn-driver-d8fb540568b29fd1d81a1cca667a1ad3e33ef5a1.tar.gz
IVGCVSW-5781 Add Async Support to Android-NN-Driver
Signed-off-by: Finn Williams <Finn.Williams@arm.com> Change-Id: I1f13d04100fdb119495b9e3054425bf3babc59f1
Diffstat (limited to 'ArmnnPreparedModel_1_2.cpp')
-rw-r--r--ArmnnPreparedModel_1_2.cpp117
1 files changed, 114 insertions, 3 deletions
diff --git a/ArmnnPreparedModel_1_2.cpp b/ArmnnPreparedModel_1_2.cpp
index a2148c29..c129fd69 100644
--- a/ArmnnPreparedModel_1_2.cpp
+++ b/ArmnnPreparedModel_1_2.cpp
@@ -6,6 +6,7 @@
#define LOG_TAG "ArmnnDriver"
#include "ArmnnPreparedModel_1_2.hpp"
+
#include "Utils.hpp"
#include <log/log.h>
@@ -146,16 +147,23 @@ ArmnnPreparedModel_1_2<HalVersion>::ArmnnPreparedModel_1_2(armnn::NetworkId netw
armnn::IRuntime* runtime,
const V1_2::Model& model,
const std::string& requestInputsAndOutputsDumpDir,
- const bool gpuProfilingEnabled)
+ const bool gpuProfilingEnabled,
+ const bool asyncModelExecutionEnabled)
: m_NetworkId(networkId)
, m_Runtime(runtime)
, m_Model(model)
, m_RequestCount(0)
, m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir)
, m_GpuProfilingEnabled(gpuProfilingEnabled)
+ , m_AsyncModelExecutionEnabled(asyncModelExecutionEnabled)
{
// Enable profiling if required.
m_Runtime->GetProfiler(m_NetworkId)->EnableProfiling(m_GpuProfilingEnabled);
+
+ if (asyncModelExecutionEnabled)
+ {
+ m_WorkingMemHandle = m_Runtime->CreateWorkingMemHandle(networkId);
+ }
}
template<typename HalVersion>
@@ -440,7 +448,17 @@ bool ArmnnPreparedModel_1_2<HalVersion>::ExecuteGraph(
deviceStart = Now();
}
- armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
+ armnn::Status status;
+ if (m_AsyncModelExecutionEnabled)
+ {
+ ALOGW("ArmnnPreparedModel_1_2::ExecuteGraph m_AsyncModelExecutionEnabled true");
+ status = m_Runtime->Execute(*m_WorkingMemHandle, inputTensors, outputTensors);
+ }
+ else
+ {
+ ALOGW("ArmnnPreparedModel_1_2::ExecuteGraph m_AsyncModelExecutionEnabled false");
+ status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
+ }
if (cb.ctx.measureTimings == V1_2::MeasureTiming::YES)
{
@@ -567,12 +585,21 @@ Return <V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::Execute(const V1_
{}
}
- ALOGV("ArmnnPreparedModel_1_2::execute(...) before PostMsg");
// post the request for asynchronous execution
CallbackContext_1_2 cb;
cb.callback = callback;
cb.ctx = ctx;
+
+ if (m_AsyncModelExecutionEnabled)
+ {
+ ALOGV("ArmnnPreparedModel_1_2::execute(...) before ScheduleGraphForExecution");
+ ScheduleGraphForExecution(memPools, inputTensors, outputTensors, cb);
+ ALOGV("ArmnnPreparedModel_1_2::execute(...) after ScheduleGraphForExecution");
+ return V1_0::ErrorStatus::NONE;
+ }
+
+ ALOGV("ArmnnPreparedModel_1_2::execute(...) before PostMsg");
m_RequestThread.PostMsg(this, memPools, inputTensors, outputTensors, cb);
ALOGV("ArmnnPreparedModel_1_2::execute(...) after PostMsg");
return V1_0::ErrorStatus::NONE;
@@ -602,6 +629,84 @@ Return<void> ArmnnPreparedModel_1_2<HalVersion>::configureExecutionBurst(
return Void();
}
+/// Schedule the graph prepared from the request for execution
+template<typename HalVersion>
+template<typename CallbackContext>
+void ArmnnPreparedModel_1_2<HalVersion>::ScheduleGraphForExecution(
+ std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+ std::shared_ptr<armnn::InputTensors>& inputTensors,
+ std::shared_ptr<armnn::OutputTensors>& outputTensors,
+ CallbackContext callbackContext)
+{
+ ALOGV("ArmnnPreparedModel_1_2::ScheduleGraphForExecution(...)");
+
+ DumpTensorsIfRequired("Input", *inputTensors);
+
+ unsigned int outputTensorSize = outputTensors.get()->size();
+ std::vector<V1_2::OutputShape> outputShapes(outputTensorSize);
+ for (unsigned int i = 0; i < outputTensorSize; i++)
+ {
+ std::pair<int, armnn::Tensor> outputTensorPair = outputTensors.get()->at(i);
+ const armnn::Tensor outputTensor = outputTensorPair.second;
+ const armnn::TensorInfo outputTensorInfo = outputTensor.GetInfo();
+
+ outputShapes[i] = ComputeShape(outputTensorInfo);
+ }
+
+ auto tpCb = std::make_shared<
+ ArmnnThreadPoolCallback_1_2<CallbackContext_1_2>>(this,
+ pMemPools,
+ outputShapes,
+ inputTensors,
+ outputTensors,
+ callbackContext);
+
+ m_Runtime->Schedule(m_NetworkId,
+ *tpCb->m_InputTensors,
+ *tpCb->m_OutputTensors,
+ armnn::QosExecPriority::High,
+ tpCb);
+ ALOGV("ArmnnPreparedModel_1_2::ScheduleGraphForExecution end");
+}
+
+template<typename HalVersion>
+template <typename CallbackContext>
+void ArmnnPreparedModel_1_2<HalVersion>::ArmnnThreadPoolCallback_1_2<CallbackContext>::Notify(
+ armnn::Status status, armnn::InferenceTimingPair timeTaken)
+{
+ ALOGV("ArmnnPreparedModel_1_2::ArmnnThreadPoolCallback_1_2 Notify");
+
+ TimePoint driverEnd;
+
+ CommitPools(*m_MemPools);
+
+ m_Model->DumpTensorsIfRequired("Output", *m_OutputTensors);
+
+ if (status != armnn::Status::Success)
+ {
+ ALOGW("ArmnnThreadPoolCallback::Notify EnqueueWorkload failed");
+ m_CallbackContext.callback(
+ V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel::ExecuteGraph");
+ return;
+ }
+
+ if (m_CallbackContext.ctx.measureTimings == V1_2::MeasureTiming::YES)
+ {
+ driverEnd = std::chrono::steady_clock::now();
+ V1_2::Timing timing;
+ timing.timeOnDevice = MicrosecondsDuration(timeTaken.second, timeTaken.first);
+ timing.timeInDriver = MicrosecondsDuration(driverEnd, m_CallbackContext.ctx.driverStart);
+ ALOGV("ArmnnPreparedModel_1_2::execute timing - Device = %lu Driver = %lu", timing.timeOnDevice,
+ timing.timeInDriver);
+ m_CallbackContext.callback(
+ V1_0::ErrorStatus::NONE, m_OutputShapes, timing, "ArmnnPreparedModel_1_2::ExecuteGraph");
+ } else {
+ m_CallbackContext.callback(
+ V1_0::ErrorStatus::NONE, m_OutputShapes, g_NoTiming, "ArmnnPreparedModel_1_2::ExecuteGraph");
+ }
+ return;
+}
+
#if defined(ARMNN_ANDROID_NN_V1_2) || defined(ARMNN_ANDROID_NN_V1_3)
template class ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>;
template bool ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>::ExecuteGraph<CallbackContext_1_2>(
@@ -609,6 +714,12 @@ template bool ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>::ExecuteGraph<CallbackC
armnn::InputTensors& pInputTensors,
armnn::OutputTensors& pOutputTensors,
CallbackContext_1_2 cb);
+
+template void ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>::ScheduleGraphForExecution<CallbackContext_1_2>(
+ std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+ std::shared_ptr<armnn::InputTensors>& inputTensors,
+ std::shared_ptr<armnn::OutputTensors>& outputTensors,
+ CallbackContext_1_2 callbackContext);
#endif
} // namespace armnn_driver