Diffstat (limited to 'ArmnnPreparedModel_1_3.cpp')
-rw-r--r--  ArmnnPreparedModel_1_3.cpp  153
1 file changed, 145 insertions(+), 8 deletions(-)
diff --git a/ArmnnPreparedModel_1_3.cpp b/ArmnnPreparedModel_1_3.cpp
index 3d93b99d..5a370321 100644
--- a/ArmnnPreparedModel_1_3.cpp
+++ b/ArmnnPreparedModel_1_3.cpp
@@ -168,7 +168,8 @@ ArmnnPreparedModel_1_3<HalVersion>::ArmnnPreparedModel_1_3(armnn::NetworkId netw
const V1_3::Model& model,
const std::string& requestInputsAndOutputsDumpDir,
const bool gpuProfilingEnabled,
- V1_3::Priority priority)
+ V1_3::Priority priority,
+ const bool asyncModelExecutionEnabled)
: m_NetworkId(networkId)
, m_Runtime(runtime)
, m_Model(model)
@@ -176,9 +177,15 @@ ArmnnPreparedModel_1_3<HalVersion>::ArmnnPreparedModel_1_3(armnn::NetworkId netw
, m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir)
, m_GpuProfilingEnabled(gpuProfilingEnabled)
, m_ModelPriority(priority)
+ , m_AsyncModelExecutionEnabled(asyncModelExecutionEnabled)
{
// Enable profiling if required.
m_Runtime->GetProfiler(m_NetworkId)->EnableProfiling(m_GpuProfilingEnabled);
+
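+ // Pre-allocate the working memory the runtime needs for asynchronous
+ // Execute() calls on this network.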
+ if (asyncModelExecutionEnabled)
+ {
+ m_WorkingMemHandle = m_Runtime->CreateWorkingMemHandle(networkId);
+ }
}
template<typename HalVersion>
@@ -726,8 +733,17 @@ Return<V1_3::ErrorStatus> ArmnnPreparedModel_1_3<HalVersion>::ExecuteGraph(
{
cb.ctx.deviceStart = Now();
}
-
- armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
+ armnn::Status status;
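+ // Dispatch to the async path (pre-allocated working memory) or fall back
+ // to the synchronous EnqueueWorkload path.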
+ if (m_AsyncModelExecutionEnabled)
+ {
+ ALOGW("ArmnnPreparedModel_1_3::ExecuteGraph m_AsyncModelExecutionEnabled true");
+ status = m_Runtime->Execute(*m_WorkingMemHandle, inputTensors, outputTensors);
+ }
+ else
+ {
+ ALOGW("ArmnnPreparedModel_1_3::ExecuteGraph m_AsyncModelExecutionEnabled false");
+ status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
+ }
if (cb.ctx.measureTimings == V1_2::MeasureTiming::YES)
{
@@ -735,7 +751,7 @@ Return<V1_3::ErrorStatus> ArmnnPreparedModel_1_3<HalVersion>::ExecuteGraph(
}
if (status != armnn::Status::Success)
{
- ALOGW("EnqueueWorkload failed");
+ ALOGW("ArmnnPreparedModel_1_3::ExecuteGraph EnqueueWorkload failed");
cb.callback(V1_3::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_3::ExecuteGraph");
return V1_3::ErrorStatus::GENERAL_FAILURE;
}
@@ -773,6 +789,47 @@ Return <V1_3::ErrorStatus> ArmnnPreparedModel_1_3<HalVersion>::ExecuteGraph(
return V1_3::ErrorStatus::NONE;
}
+/// Schedule the graph prepared from the request for execution
+template<typename HalVersion>
+template<typename CallbackContext>
+void ArmnnPreparedModel_1_3<HalVersion>::ScheduleGraphForExecution(
+ std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+ std::shared_ptr<armnn::InputTensors>& inputTensors,
+ std::shared_ptr<armnn::OutputTensors>& outputTensors,
+ CallbackContext callbackContext,
+ armnn::QosExecPriority priority)
+{
+ ALOGV("ArmnnPreparedModel_1_3::ScheduleGraphForExecution(...)");
+
+ DumpTensorsIfRequired("Input", *inputTensors);
+
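+ // Record the output shapes now so they can be returned through the
+ // callback once the asynchronous execution completes.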
+ unsigned int outputTensorSize = outputTensors->size();
+ std::vector<V1_2::OutputShape> outputShapes(outputTensorSize);
+ for (unsigned int i = 0; i < outputTensorSize; i++)
+ {
+ const std::pair<int, armnn::Tensor>& outputTensorPair = outputTensors->at(i);
+ const armnn::Tensor outputTensor = outputTensorPair.second;
+ const armnn::TensorInfo outputTensorInfo = outputTensor.GetInfo();
+
+ outputShapes[i] = ComputeShape(outputTensorInfo);
+ }
+
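+ // Bundle the memory pools, tensors and callback so they outlive this call
+ // and remain valid until Notify() runs.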
+ auto tpCb = std::make_shared<
+ ArmnnThreadPoolCallback_1_3<CallbackContext>>(this,
+ pMemPools,
+ outputShapes,
+ inputTensors,
+ outputTensors,
+ callbackContext);
+
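+ // Queue the inference on the runtime's thread pool at the mapped priority.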
+ m_Runtime->Schedule(m_NetworkId,
+ *tpCb->m_InputTensors,
+ *tpCb->m_OutputTensors,
+ priority,
+ tpCb);
+ ALOGV("ArmnnPreparedModel_1_3::ScheduleGraphForExecution end");
+}
+
template<typename HalVersion>
bool ArmnnPreparedModel_1_3<HalVersion>::ExecuteWithDummyInputs()
{
@@ -862,13 +919,46 @@ Return<V1_3::ErrorStatus> ArmnnPreparedModel_1_3<HalVersion>::Execute(const V1_
default:
{}
}
-
- ALOGV("ArmnnPreparedModel_1_3::execute(...) before PostMsg");
-
- // post the request for asynchronous execution
CallbackContext_1_3 cb;
cb.callback = callback;
cb.ctx = ctx;
+
+ if (m_AsyncModelExecutionEnabled)
+ {
+ armnn::QosExecPriority priority;
+
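+ // Map the NNAPI model priority onto the runtime's QoS execution priority.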
+ switch (GetModelPriority())
+ {
+ case V1_3::Priority::LOW:
+ priority = armnn::QosExecPriority::Low;
+ break;
+ case V1_3::Priority::MEDIUM:
+ priority = armnn::QosExecPriority::Medium;
+ break;
+ case V1_3::Priority::HIGH:
+ priority = armnn::QosExecPriority::High;
+ break;
+ default:
+ priority = armnn::QosExecPriority::Medium;
+ break;
+ }
+
+ ALOGV("ArmnnPreparedModel_1_3::execute(...) before ScheduleGraphForExecution");
+ ScheduleGraphForExecution(memPools, inputTensors, outputTensors, cb, priority);
+ ALOGV("ArmnnPreparedModel_1_3::execute(...) after ScheduleGraphForExecution");
+ return V1_3::ErrorStatus::NONE;
+ }
+
+ ALOGV("ArmnnPreparedModel_1_3::execute(...) before PostMsg");
+ // post the request for asynchronous execution
m_RequestThread.PostMsg(this, memPools, inputTensors, outputTensors, cb);
ALOGV("ArmnnPreparedModel_1_3::execute(...) after PostMsg");
return V1_3::ErrorStatus::NONE;
@@ -880,6 +970,46 @@ V1_3::Priority ArmnnPreparedModel_1_3<HalVersion>::GetModelPriority()
return m_ModelPriority;
}
+template<typename HalVersion>
+template<typename CallbackContext>
+void ArmnnPreparedModel_1_3<HalVersion>::ArmnnThreadPoolCallback_1_3<CallbackContext>::Notify(
+ armnn::Status status, armnn::InferenceTimingPair timeTaken)
+{
+ ALOGV("ArmnnPreparedModel_1_3::ArmnnThreadPoolCallback_1_3<CallbackContext>::Notify");
+ CommitPools(*m_MemPools);
+
+ m_Model->DumpTensorsIfRequired("Output", *m_OutputTensors);
+
+ if (status != armnn::Status::Success)
+ {
+ ALOGW("ArmnnThreadPoolCallback_1_3::Notify EnqueueWorkload failed");
+ m_CallbackContext.callback(V1_3::ErrorStatus::GENERAL_FAILURE,
+ {},
+ g_NoTiming,
+ "ArmnnPreparedModel_1_3::ArmnnThreadPoolCallback_1_3");
+ return;
+ }
+
+ if (m_CallbackContext.ctx.measureTimings == V1_2::MeasureTiming::YES)
+ {
+ m_CallbackContext.ctx.deviceStart = timeTaken.first;
+ m_CallbackContext.ctx.deviceEnd = timeTaken.second;
+ m_CallbackContext.ctx.driverEnd = std::chrono::steady_clock::now();
+ V1_2::Timing timing;
+ timing.timeOnDevice = MicrosecondsDuration(m_CallbackContext.ctx.deviceEnd, m_CallbackContext.ctx.deviceStart);
+ timing.timeInDriver = MicrosecondsDuration(m_CallbackContext.ctx.driverEnd, m_CallbackContext.ctx.driverStart);
+ ALOGV("ArmnnPreparedModel_1_3::execute timing - Device = %lu Driver = %lu", timing.timeOnDevice,
+ timing.timeInDriver);
+ m_CallbackContext.callback(
+ V1_3::ErrorStatus::NONE, m_OutputShapes, timing, "ArmnnPreparedModel_1_3::ExecuteGraph");
+ }
+ else
+ {
+ m_CallbackContext.callback(
+ V1_3::ErrorStatus::NONE, m_OutputShapes, g_NoTiming, "ArmnnPreparedModel_1_3::ExecuteGraph");
+ }
+}
+
#ifdef ARMNN_ANDROID_NN_V1_3
template class ArmnnPreparedModel_1_3<hal_1_3::HalPolicy>;
template Return<V1_3::ErrorStatus> ArmnnPreparedModel_1_3<hal_1_3::HalPolicy>::ExecuteGraph<CallbackContext_1_3>(
@@ -887,6 +1017,13 @@ template Return<V1_3::ErrorStatus> ArmnnPreparedModel_1_3<hal_1_3::HalPolicy>::
armnn::InputTensors& pInputTensors,
armnn::OutputTensors& pOutputTensors,
CallbackContext_1_3 cb);
+
+template void ArmnnPreparedModel_1_3<hal_1_3::HalPolicy>::ScheduleGraphForExecution<CallbackContext_1_3>(
+ std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+ std::shared_ptr<armnn::InputTensors>& inputTensors,
+ std::shared_ptr<armnn::OutputTensors>& outputTensors,
+ CallbackContext_1_3 callbackContext,
+ armnn::QosExecPriority priority);
#endif
} // namespace armnn_driver
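
For readers new to this pattern, here is a minimal, self-contained sketch of the flag-driven sync/async dispatch this patch introduces. All names below are invented for illustration (std::async stands in for the runtime's thread pool; Dispatch, ExecMode and RunInference are not Arm NN API):

    #include <future>
    #include <iostream>

    enum class ExecMode { Sync, Async }; // stand-in for m_AsyncModelExecutionEnabled

    int RunInference(int input) { return input * 2; } // placeholder workload

    std::future<int> Dispatch(ExecMode mode, int input)
    {
        // Async: hand the work to another thread, analogous to Schedule()/Execute().
        if (mode == ExecMode::Async)
        {
            return std::async(std::launch::async, RunInference, input);
        }
        // Sync: run inline, analogous to EnqueueWorkload().
        std::promise<int> result;
        result.set_value(RunInference(input));
        return result.get_future();
    }

    int main()
    {
        std::cout << Dispatch(ExecMode::Async, 21).get() << '\n'; // prints 42
        std::cout << Dispatch(ExecMode::Sync, 21).get() << '\n';  // prints 42
        return 0;
    }

The patch follows the same shape: one member flag chosen at construction selects the execution path, and the asynchronous path reports its result through a callback (Notify) rather than a future.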