Diffstat (limited to 'ArmnnPreparedModel.cpp')
-rw-r--r--  ArmnnPreparedModel.cpp |  96 ++++++++++++++++++++++++++++++++++++++++++++++++----
1 file changed, 92 insertions(+), 4 deletions(-)
diff --git a/ArmnnPreparedModel.cpp b/ArmnnPreparedModel.cpp
index 60beac4f..978f3787 100644
--- a/ArmnnPreparedModel.cpp
+++ b/ArmnnPreparedModel.cpp
@@ -112,16 +112,23 @@ ArmnnPreparedModel<HalVersion>::ArmnnPreparedModel(armnn::NetworkId networkId,
armnn::IRuntime* runtime,
const HalModel& model,
const std::string& requestInputsAndOutputsDumpDir,
- const bool gpuProfilingEnabled)
+ const bool gpuProfilingEnabled,
+ const bool asyncModelExecutionEnabled)
: m_NetworkId(networkId)
, m_Runtime(runtime)
, m_Model(model)
, m_RequestCount(0)
, m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir)
, m_GpuProfilingEnabled(gpuProfilingEnabled)
+ , m_AsyncModelExecutionEnabled(asyncModelExecutionEnabled)
{
// Enable profiling if required.
m_Runtime->GetProfiler(m_NetworkId)->EnableProfiling(m_GpuProfilingEnabled);
+
+ if (asyncModelExecutionEnabled)
+ {
+ m_WorkingMemHandle = m_Runtime->CreateWorkingMemHandle(networkId);
+ }
}
template<typename HalVersion>
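The matching header change is not part of this diff; for orientation, a minimal sketch of the members the constructor now initialises (assumed declarations based on the ArmNN client API, not lines from ArmnnPreparedModel.hpp):

    // Assumed additions to the ArmnnPreparedModel class declaration:
    bool m_AsyncModelExecutionEnabled;
    // IRuntime::CreateWorkingMemHandle() returns the handle that owns the
    // per-execution intermediate tensor memory used by IRuntime::Execute().
    std::unique_ptr<armnn::experimental::IWorkingMemHandle> m_WorkingMemHandle;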
@@ -225,8 +232,6 @@ Return<V1_0::ErrorStatus> ArmnnPreparedModel<HalVersion>::execute(
return V1_0::ErrorStatus::GENERAL_FAILURE;
}
- ALOGV("ArmnnPreparedModel::execute(...) before PostMsg");
-
auto cb = [callback](V1_0::ErrorStatus errorStatus, std::string callingFunction)
{
NotifyCallbackAndCheck(callback, errorStatus, callingFunction);
@@ -234,7 +239,17 @@ Return<V1_0::ErrorStatus> ArmnnPreparedModel<HalVersion>::execute(
CallbackContext_1_0 armnnCb;
armnnCb.callback = cb;
+
+ if (m_AsyncModelExecutionEnabled)
+ {
+ ALOGV("ArmnnPreparedModel::execute(...) before ScheduleGraphForExecution");
+ ScheduleGraphForExecution(pMemPools, pInputTensors, pOutputTensors, armnnCb);
+ ALOGV("ArmnnPreparedModel::execute(...) after ScheduleGraphForExecution");
+ return V1_0::ErrorStatus::NONE;
+ }
+
// post the request for asynchronous execution
+ ALOGV("ArmnnPreparedModel::execute(...) before PostMsg");
m_RequestThread.PostMsg(this, pMemPools, pInputTensors, pOutputTensors, armnnCb);
ALOGV("ArmnnPreparedModel::execute(...) after PostMsg");
return V1_0::ErrorStatus::NONE; // successfully queued
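Both dispatch paths are fire-and-forget from the HAL's point of view: execute() returns V1_0::ErrorStatus::NONE immediately and the real result is delivered later through armnnCb.callback. A condensed sketch of the control flow this hunk adds (not new behaviour, just the two paths side by side):

    if (m_AsyncModelExecutionEnabled)
    {
        // Hand the inference to the ArmNN runtime's internal scheduler.
        ScheduleGraphForExecution(pMemPools, pInputTensors, pOutputTensors, armnnCb);
    }
    else
    {
        // Queue it on the driver's own worker thread, which runs ExecuteGraph().
        m_RequestThread.PostMsg(this, pMemPools, pInputTensors, pOutputTensors, armnnCb);
    }
    return V1_0::ErrorStatus::NONE;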
@@ -254,7 +269,18 @@ void ArmnnPreparedModel<HalVersion>::ExecuteGraph(
// run it
try
{
- armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
+ armnn::Status status;
+ if (m_AsyncModelExecutionEnabled)
+ {
+ ALOGW("ArmnnPreparedModel::ExecuteGraph m_AsyncModelExecutionEnabled true");
+ status = m_Runtime->Execute(*m_WorkingMemHandle, inputTensors, outputTensors);
+ }
+ else
+ {
+ ALOGW("ArmnnPreparedModel::ExecuteGraph m_AsyncModelExecutionEnabled false");
+ status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
+ }
+
if (status != armnn::Status::Success)
{
ALOGW("EnqueueWorkload failed");
@@ -340,11 +366,73 @@ bool ArmnnPreparedModel<HalVersion>::ExecuteWithDummyInputs()
return true;
}
+/// Schedule the graph prepared from the request for asynchronous execution
+template<typename HalVersion>
+template<typename CallbackContext>
+void ArmnnPreparedModel<HalVersion>::ScheduleGraphForExecution(
+ std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+ std::shared_ptr<armnn::InputTensors>& inputTensors,
+ std::shared_ptr<armnn::OutputTensors>& outputTensors,
+ CallbackContext callbackContext)
+{
+ ALOGV("ArmnnPreparedModel::ScheduleGraphForExecution(...)");
+
+ DumpTensorsIfRequired("Input", *inputTensors);
+
+ auto tpCb = std::make_shared<
+        ArmnnThreadPoolCallback<CallbackContext>>(this,
+ pMemPools,
+ inputTensors,
+ outputTensors,
+ callbackContext);
+
+ m_Runtime->Schedule(m_NetworkId,
+ *tpCb->m_InputTensors,
+ *tpCb->m_OutputTensors,
+ armnn::QosExecPriority::High,
+ tpCb);
+ ALOGV("ArmnnPreparedModel::ScheduleGraphForExecution end");
+}
+
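Schedule() expects a std::shared_ptr to an armnn::experimental::IAsyncExecutionCallback, which is why ArmnnThreadPoolCallback is allocated with std::make_shared above. A sketch of that interface as it appeared in the ArmNN API of this era (an approximation, not a verbatim copy of the header):

    class IAsyncExecutionCallback
    {
    public:
        virtual ~IAsyncExecutionCallback() {}
        // Invoked by the runtime's thread pool when the inference finishes;
        // timeTaken carries the start/end timestamps of the execution.
        virtual void Notify(armnn::Status status, armnn::InferenceTimingPair timeTaken) = 0;
    };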
+template<typename HalVersion>
+template <typename CallbackContext>
+void ArmnnPreparedModel<HalVersion>::ArmnnThreadPoolCallback<CallbackContext>::Notify(
+ armnn::Status status, armnn::InferenceTimingPair timeTaken)
+{
+ armnn::IgnoreUnused(status, timeTaken);
+ ALOGV("ArmnnPreparedModel::ArmnnThreadPoolCallback_1_2 Notify");
+
+ m_Model->DumpTensorsIfRequired("Output", *m_OutputTensors);
+
+ // Commit output buffers.
+ // Note that we update *all* pools, even if they aren't actually used as outputs -
+ // this is simpler and is what the CpuExecutor does.
+ for (android::nn::RunTimePoolInfo& pool : *m_MemPools)
+ {
+        // The android::nn::RunTimePoolInfo type changed between Android P/Q and
+        // Android R: update() was removed and flush() was added.
+ #if defined(ARMNN_ANDROID_R) || defined(ARMNN_ANDROID_S) // Use the new Android implementation.
+ pool.flush();
+ #else
+ pool.update();
+ #endif
+ }
+
+    m_CallbackContext.callback(V1_0::ErrorStatus::NONE, "ArmnnPreparedModel::ArmnnThreadPoolCallback::Notify");
+}
+
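The flush()/update() split above recurs wherever the driver commits shared-memory pools, so it could be hoisted into a small helper; a sketch under that assumption (CommitPools is a hypothetical name, not a function in this driver):

    // Commit all request pools back to shared memory, picking the
    // RunTimePoolInfo API available on the target Android release.
    inline void CommitPools(std::vector<android::nn::RunTimePoolInfo>& memPools)
    {
        for (android::nn::RunTimePoolInfo& pool : memPools)
        {
    #if defined(ARMNN_ANDROID_R) || defined(ARMNN_ANDROID_S)
            pool.flush();   // Android R/S
    #else
            pool.update();  // Android P/Q
    #endif
        }
    }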
///
/// Class template specializations
///
template class ArmnnPreparedModel<hal_1_0::HalPolicy>;
+template void ArmnnPreparedModel<hal_1_0::HalPolicy>::ScheduleGraphForExecution<CallbackContext_1_0>(
+ std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+ std::shared_ptr<armnn::InputTensors>& inputTensors,
+ std::shared_ptr<armnn::OutputTensors>& outputTensors,
+ CallbackContext_1_0 callbackContext);
#ifdef ARMNN_ANDROID_NN_V1_1
template class ArmnnPreparedModel<hal_1_1::HalPolicy>;