From d8fb540568b29fd1d81a1cca667a1ad3e33ef5a1 Mon Sep 17 00:00:00 2001 From: Finn Williams Date: Wed, 19 May 2021 20:52:00 +0100 Subject: IVGCVSW-5781 Add Async Support to Android-NN-Driver Signed-off-by: Finn Williams Change-Id: I1f13d04100fdb119495b9e3054425bf3babc59f1 --- ArmnnPreparedModel_1_3.cpp | 153 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 145 insertions(+), 8 deletions(-) (limited to 'ArmnnPreparedModel_1_3.cpp') diff --git a/ArmnnPreparedModel_1_3.cpp b/ArmnnPreparedModel_1_3.cpp index 3d93b99d..5a370321 100644 --- a/ArmnnPreparedModel_1_3.cpp +++ b/ArmnnPreparedModel_1_3.cpp @@ -168,7 +168,8 @@ ArmnnPreparedModel_1_3::ArmnnPreparedModel_1_3(armnn::NetworkId netw const V1_3::Model& model, const std::string& requestInputsAndOutputsDumpDir, const bool gpuProfilingEnabled, - V1_3::Priority priority) + V1_3::Priority priority, + const bool asyncModelExecutionEnabled) : m_NetworkId(networkId) , m_Runtime(runtime) , m_Model(model) @@ -176,9 +177,15 @@ ArmnnPreparedModel_1_3::ArmnnPreparedModel_1_3(armnn::NetworkId netw , m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir) , m_GpuProfilingEnabled(gpuProfilingEnabled) , m_ModelPriority(priority) + , m_AsyncModelExecutionEnabled(asyncModelExecutionEnabled) { // Enable profiling if required. m_Runtime->GetProfiler(m_NetworkId)->EnableProfiling(m_GpuProfilingEnabled); + + if (asyncModelExecutionEnabled) + { + m_WorkingMemHandle = m_Runtime->CreateWorkingMemHandle(networkId); + } } template @@ -726,8 +733,17 @@ Return ArmnnPreparedModel_1_3::ExecuteGraph( { cb.ctx.deviceStart = Now(); } - - armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors); + armnn::Status status; + if (m_AsyncModelExecutionEnabled) + { + ALOGW("ArmnnPreparedModel_1_3::ExecuteGraph m_AsyncModelExecutionEnabled true"); + status = m_Runtime->Execute(*m_WorkingMemHandle, inputTensors, outputTensors); + } + else + { + ALOGW("ArmnnPreparedModel_1_3::ExecuteGraph m_AsyncModelExecutionEnabled false"); + status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors); + } if (cb.ctx.measureTimings == V1_2::MeasureTiming::YES) { @@ -735,7 +751,7 @@ Return ArmnnPreparedModel_1_3::ExecuteGraph( } if (status != armnn::Status::Success) { - ALOGW("EnqueueWorkload failed"); + ALOGW("ArmnnPreparedModel_1_3::ExecuteGraph EnqueueWorkload failed"); cb.callback(V1_3::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_3::ExecuteGraph"); return V1_3::ErrorStatus::GENERAL_FAILURE; } @@ -773,6 +789,47 @@ Return ArmnnPreparedModel_1_3::ExecuteGraph( return V1_3::ErrorStatus::NONE; } +/// Schedule the graph prepared from the request for execution +template +template +void ArmnnPreparedModel_1_3::ScheduleGraphForExecution( + std::shared_ptr>& pMemPools, + std::shared_ptr& inputTensors, + std::shared_ptr& outputTensors, + CallbackContext callbackContext, + armnn::QosExecPriority priority) +{ + ALOGV("ArmnnPreparedModel_1_3::ScheduleGraphForExecution(...)"); + + DumpTensorsIfRequired("Input", *inputTensors); + + unsigned int outputTensorSize = outputTensors.get()->size(); + std::vector outputShapes(outputTensorSize); + for (unsigned int i = 0; i < outputTensorSize; i++) + { + std::pair outputTensorPair = outputTensors.get()->at(i); + const armnn::Tensor outputTensor = outputTensorPair.second; + const armnn::TensorInfo outputTensorInfo = outputTensor.GetInfo(); + + outputShapes[i] = ComputeShape(outputTensorInfo); + } + + auto tpCb = std::make_shared< + ArmnnThreadPoolCallback_1_3>(this, + pMemPools, + outputShapes, + inputTensors, + outputTensors, + callbackContext); + + m_Runtime->Schedule(m_NetworkId, + *tpCb->m_InputTensors, + *tpCb->m_OutputTensors, + priority, + tpCb); + ALOGV("ArmnnPreparedModel_1_3::ScheduleGraphForExecution end"); +} + template bool ArmnnPreparedModel_1_3::ExecuteWithDummyInputs() { @@ -862,13 +919,46 @@ Return ArmnnPreparedModel_1_3::Execute(const V1_ default: {} } - - ALOGV("ArmnnPreparedModel_1_3::execute(...) before PostMsg"); - - // post the request for asynchronous execution CallbackContext_1_3 cb; cb.callback = callback; cb.ctx = ctx; + + + enum class QosExecPriority + { + Low = 0, + Medium = 1, + High = 2 + }; + + + if (m_AsyncModelExecutionEnabled) + { + armnn::QosExecPriority priority; + + switch (GetModelPriority()) { + case V1_3::Priority::LOW: + priority = armnn::QosExecPriority::Low; + break; + case V1_3::Priority::MEDIUM: + priority = armnn::QosExecPriority::Medium; + break; + case V1_3::Priority::HIGH: + priority = armnn::QosExecPriority::High; + break; + default: + priority = armnn::QosExecPriority::Medium; + + } + + ALOGV("ArmnnPreparedModel_1_3::execute(...) before ScheduleGraphForExecution"); + ScheduleGraphForExecution(memPools, inputTensors, outputTensors, cb, priority); + ALOGV("ArmnnPreparedModel_1_3::execute(...) after ScheduleGraphForExecution"); + return V1_3::ErrorStatus::NONE; + } + + ALOGV("ArmnnPreparedModel_1_3::execute(...) before PostMsg"); + // post the request for asynchronous execution m_RequestThread.PostMsg(this, memPools, inputTensors, outputTensors, cb); ALOGV("ArmnnPreparedModel_1_3::execute(...) after PostMsg"); return V1_3::ErrorStatus::NONE; @@ -880,6 +970,46 @@ V1_3::Priority ArmnnPreparedModel_1_3::GetModelPriority() return m_ModelPriority; } +template +template +void ArmnnPreparedModel_1_3::ArmnnThreadPoolCallback_1_3::Notify( + armnn::Status status, armnn::InferenceTimingPair timeTaken) +{ + ALOGV("ArmnnPreparedModel_1_3::ArmnnThreadPoolCallback_1_3::Notify"); + CommitPools(*m_MemPools); + + m_Model->DumpTensorsIfRequired("Output", *m_OutputTensors); + + if (status != armnn::Status::Success) + { + ALOGW("ArmnnThreadPoolCallback_1_3::Notify EnqueueWorkload failed"); + m_CallbackContext.callback(V1_3::ErrorStatus::GENERAL_FAILURE, + {}, + g_NoTiming, + "ArmnnPreparedModel_1_3::ArmnnThreadPoolCallback_1_3"); + return; + } + + if (m_CallbackContext.ctx.measureTimings == V1_2::MeasureTiming::YES) + { + m_CallbackContext.ctx.deviceStart = timeTaken.first; + m_CallbackContext.ctx.deviceEnd = timeTaken.second; + m_CallbackContext.ctx.driverEnd = std::chrono::steady_clock::now(); + V1_2::Timing timing; + timing.timeOnDevice = MicrosecondsDuration(m_CallbackContext.ctx.deviceEnd, m_CallbackContext.ctx.deviceStart); + timing.timeInDriver = MicrosecondsDuration(m_CallbackContext.ctx.driverEnd, m_CallbackContext.ctx.driverStart); + ALOGV("ArmnnPreparedModel_1_3::execute timing - Device = %lu Driver = %lu", timing.timeOnDevice, + timing.timeInDriver); + m_CallbackContext.callback( + V1_3::ErrorStatus::NONE, m_OutputShapes, timing, "ArmnnPreparedModel_1_3::ExecuteGraph"); + } else + { + m_CallbackContext.callback( + V1_3::ErrorStatus::NONE, m_OutputShapes, g_NoTiming, "ArmnnPreparedModel_1_3::ExecuteGraph"); + } + return; +} + #ifdef ARMNN_ANDROID_NN_V1_3 template class ArmnnPreparedModel_1_3; template Return ArmnnPreparedModel_1_3::ExecuteGraph( @@ -887,6 +1017,13 @@ template Return ArmnnPreparedModel_1_3:: armnn::InputTensors& pInputTensors, armnn::OutputTensors& pOutputTensors, CallbackContext_1_3 cb); + +template void ArmnnPreparedModel_1_3::ScheduleGraphForExecution( + std::shared_ptr>& pMemPools, + std::shared_ptr& inputTensors, + std::shared_ptr& outputTensors, + CallbackContext_1_3 callbackContext, + armnn::QosExecPriority priority); #endif } // namespace armnn_driver -- cgit v1.2.1