Diffstat (limited to 'ArmnnPreparedModel_1_3.cpp')
-rw-r--r-- | ArmnnPreparedModel_1_3.cpp | 153 |
1 file changed, 145 insertions, 8 deletions
diff --git a/ArmnnPreparedModel_1_3.cpp b/ArmnnPreparedModel_1_3.cpp
index 3d93b99d..5a370321 100644
--- a/ArmnnPreparedModel_1_3.cpp
+++ b/ArmnnPreparedModel_1_3.cpp
@@ -168,7 +168,8 @@ ArmnnPreparedModel_1_3<HalVersion>::ArmnnPreparedModel_1_3(armnn::NetworkId netw
                                                            const V1_3::Model& model,
                                                            const std::string& requestInputsAndOutputsDumpDir,
                                                            const bool gpuProfilingEnabled,
-                                                           V1_3::Priority priority)
+                                                           V1_3::Priority priority,
+                                                           const bool asyncModelExecutionEnabled)
     : m_NetworkId(networkId)
     , m_Runtime(runtime)
     , m_Model(model)
@@ -176,9 +177,15 @@ ArmnnPreparedModel_1_3<HalVersion>::ArmnnPreparedModel_1_3(armnn::NetworkId netw
     , m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir)
     , m_GpuProfilingEnabled(gpuProfilingEnabled)
     , m_ModelPriority(priority)
+    , m_AsyncModelExecutionEnabled(asyncModelExecutionEnabled)
 {
     // Enable profiling if required.
     m_Runtime->GetProfiler(m_NetworkId)->EnableProfiling(m_GpuProfilingEnabled);
+
+    if (asyncModelExecutionEnabled)
+    {
+        m_WorkingMemHandle = m_Runtime->CreateWorkingMemHandle(networkId);
+    }
 }
 
 template<typename HalVersion>
@@ -726,8 +733,17 @@ Return<V1_3::ErrorStatus> ArmnnPreparedModel_1_3<HalVersion>::ExecuteGraph(
     {
         cb.ctx.deviceStart = Now();
     }
-
-    armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
+    armnn::Status status;
+    if (m_AsyncModelExecutionEnabled)
+    {
+        ALOGW("ArmnnPreparedModel_1_3::ExecuteGraph m_AsyncModelExecutionEnabled true");
+        status = m_Runtime->Execute(*m_WorkingMemHandle, inputTensors, outputTensors);
+    }
+    else
+    {
+        ALOGW("ArmnnPreparedModel_1_3::ExecuteGraph m_AsyncModelExecutionEnabled false");
+        status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
+    }
 
     if (cb.ctx.measureTimings == V1_2::MeasureTiming::YES)
     {
@@ -735,7 +751,7 @@ Return<V1_3::ErrorStatus> ArmnnPreparedModel_1_3<HalVersion>::ExecuteGraph(
     }
     if (status != armnn::Status::Success)
     {
-        ALOGW("EnqueueWorkload failed");
+        ALOGW("ArmnnPreparedModel_1_3::ExecuteGraph EnqueueWorkload failed");
         cb.callback(V1_3::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_3::ExecuteGraph");
         return V1_3::ErrorStatus::GENERAL_FAILURE;
     }
@@ -773,6 +789,47 @@ Return<V1_3::ErrorStatus> ArmnnPreparedModel_1_3<HalVersion>::ExecuteGraph(
     return V1_3::ErrorStatus::NONE;
 }
 
+/// Schedule the graph prepared from the request for execution
+template<typename HalVersion>
+template<typename CallbackContext>
+void ArmnnPreparedModel_1_3<HalVersion>::ScheduleGraphForExecution(
+        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+        std::shared_ptr<armnn::InputTensors>& inputTensors,
+        std::shared_ptr<armnn::OutputTensors>& outputTensors,
+        CallbackContext callbackContext,
+        armnn::QosExecPriority priority)
+{
+    ALOGV("ArmnnPreparedModel_1_3::ScheduleGraphForExecution(...)");
+
+    DumpTensorsIfRequired("Input", *inputTensors);
+
+    unsigned int outputTensorSize = outputTensors.get()->size();
+    std::vector<V1_2::OutputShape> outputShapes(outputTensorSize);
+    for (unsigned int i = 0; i < outputTensorSize; i++)
+    {
+        std::pair<int, armnn::Tensor> outputTensorPair = outputTensors.get()->at(i);
+        const armnn::Tensor outputTensor = outputTensorPair.second;
+        const armnn::TensorInfo outputTensorInfo = outputTensor.GetInfo();
+
+        outputShapes[i] = ComputeShape(outputTensorInfo);
+    }
+
+    auto tpCb = std::make_shared<
+        ArmnnThreadPoolCallback_1_3<CallbackContext_1_3>>(this,
+                                                          pMemPools,
+                                                          outputShapes,
+                                                          inputTensors,
+                                                          outputTensors,
+                                                          callbackContext);
+
+    m_Runtime->Schedule(m_NetworkId,
+                        *tpCb->m_InputTensors,
+                        *tpCb->m_OutputTensors,
+                        priority,
+                        tpCb);
+    ALOGV("ArmnnPreparedModel_1_3::ScheduleGraphForExecution end");
+}
+
 template<typename HalVersion>
 bool ArmnnPreparedModel_1_3<HalVersion>::ExecuteWithDummyInputs()
 {
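The constructor and ExecuteGraph hunks above wire the new flag to Arm NN's working-memory API: CreateWorkingMemHandle() is called once at network-load time, and each execution then chooses between Execute() and EnqueueWorkload(). A minimal standalone sketch of that branch, assuming a network already loaded into the runtime as netId; the function name Run and the exact IWorkingMemHandle header/namespace are assumptions, not part of the change:

    #include <armnn/IRuntime.hpp>

    // Sketch only: selects between the asynchronous working-memory path and
    // the classic blocking path, mirroring the branch added to ExecuteGraph.
    armnn::Status Run(armnn::IRuntime& runtime,
                      armnn::NetworkId netId,
                      armnn::experimental::IWorkingMemHandle* workingMemHandle, // assumed namespace
                      const armnn::InputTensors& inputs,
                      const armnn::OutputTensors& outputs,
                      bool asyncModelExecutionEnabled)
    {
        if (asyncModelExecutionEnabled && workingMemHandle != nullptr)
        {
            // Async path: per-execution intermediate memory lives in the handle,
            // so several inferences on one network can be in flight at once.
            return runtime.Execute(*workingMemHandle, inputs, outputs);
        }
        // Default path: blocking execution on the runtime's own memory.
        return runtime.EnqueueWorkload(netId, inputs, outputs);
    }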
@@ -862,13 +919,46 @@ Return<V1_3::ErrorStatus> ArmnnPreparedModel_1_3<HalVersion>::Execute(const V1_
         default:
             {}
     }
-
-    ALOGV("ArmnnPreparedModel_1_3::execute(...) before PostMsg");
-
-    // post the request for asynchronous execution
     CallbackContext_1_3 cb;
     cb.callback = callback;
     cb.ctx = ctx;
+
+
+    enum class QosExecPriority
+    {
+        Low    = 0,
+        Medium = 1,
+        High   = 2
+    };
+
+
+    if (m_AsyncModelExecutionEnabled)
+    {
+        armnn::QosExecPriority priority;
+
+        switch (GetModelPriority()) {
+            case V1_3::Priority::LOW:
+                priority = armnn::QosExecPriority::Low;
+                break;
+            case V1_3::Priority::MEDIUM:
+                priority = armnn::QosExecPriority::Medium;
+                break;
+            case V1_3::Priority::HIGH:
+                priority = armnn::QosExecPriority::High;
+                break;
+            default:
+                priority = armnn::QosExecPriority::Medium;
+
+        }
+
+        ALOGV("ArmnnPreparedModel_1_3::execute(...) before ScheduleGraphForExecution");
+        ScheduleGraphForExecution(memPools, inputTensors, outputTensors, cb, priority);
+        ALOGV("ArmnnPreparedModel_1_3::execute(...) after ScheduleGraphForExecution");
+        return V1_3::ErrorStatus::NONE;
+    }
+
+    ALOGV("ArmnnPreparedModel_1_3::execute(...) before PostMsg");
+    // post the request for asynchronous execution
     m_RequestThread.PostMsg(this, memPools, inputTensors, outputTensors, cb);
     ALOGV("ArmnnPreparedModel_1_3::execute(...) after PostMsg");
     return V1_3::ErrorStatus::NONE;
@@ -880,6 +970,46 @@ V1_3::Priority ArmnnPreparedModel_1_3<HalVersion>::GetModelPriority()
     return m_ModelPriority;
 }
 
+template<typename HalVersion>
+template <typename CallbackContext>
+void ArmnnPreparedModel_1_3<HalVersion>::ArmnnThreadPoolCallback_1_3<CallbackContext>::Notify(
+        armnn::Status status, armnn::InferenceTimingPair timeTaken)
+{
+    ALOGV("ArmnnPreparedModel_1_3::ArmnnThreadPoolCallback_1_3<CallbackContext>::Notify");
+    CommitPools(*m_MemPools);
+
+    m_Model->DumpTensorsIfRequired("Output", *m_OutputTensors);
+
+    if (status != armnn::Status::Success)
+    {
+        ALOGW("ArmnnThreadPoolCallback_1_3::Notify EnqueueWorkload failed");
+        m_CallbackContext.callback(V1_3::ErrorStatus::GENERAL_FAILURE,
+                                   {},
+                                   g_NoTiming,
+                                   "ArmnnPreparedModel_1_3::ArmnnThreadPoolCallback_1_3");
+        return;
+    }
+
+    if (m_CallbackContext.ctx.measureTimings == V1_2::MeasureTiming::YES)
+    {
+        m_CallbackContext.ctx.deviceStart = timeTaken.first;
+        m_CallbackContext.ctx.deviceEnd = timeTaken.second;
+        m_CallbackContext.ctx.driverEnd = std::chrono::steady_clock::now();
+        V1_2::Timing timing;
+        timing.timeOnDevice = MicrosecondsDuration(m_CallbackContext.ctx.deviceEnd, m_CallbackContext.ctx.deviceStart);
+        timing.timeInDriver = MicrosecondsDuration(m_CallbackContext.ctx.driverEnd, m_CallbackContext.ctx.driverStart);
+        ALOGV("ArmnnPreparedModel_1_3::execute timing - Device = %lu Driver = %lu", timing.timeOnDevice,
+              timing.timeInDriver);
+        m_CallbackContext.callback(
+            V1_3::ErrorStatus::NONE, m_OutputShapes, timing, "ArmnnPreparedModel_1_3::ExecuteGraph");
+    } else
+    {
+        m_CallbackContext.callback(
+            V1_3::ErrorStatus::NONE, m_OutputShapes, g_NoTiming, "ArmnnPreparedModel_1_3::ExecuteGraph");
+    }
+    return;
+}
+
 #ifdef ARMNN_ANDROID_NN_V1_3
 template class ArmnnPreparedModel_1_3<hal_1_3::HalPolicy>;
 template Return<V1_3::ErrorStatus> ArmnnPreparedModel_1_3<hal_1_3::HalPolicy>::ExecuteGraph<CallbackContext_1_3>(
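The switch added to Execute() is a plain three-way translation from the NNAPI model priority to Arm NN's QoS priority, with Medium as the fallback for unknown values. The same mapping as a small helper, for clarity; ToQosPriority is an illustrative name, not introduced by this change, and V1_3::Priority is the android.hardware.neuralnetworks@1.3 type already used throughout the file:

    #include <armnn/Types.hpp>

    // Illustrative refactor of the switch in Execute(); behaviour is
    // identical, including the Medium fallback in the default arm.
    armnn::QosExecPriority ToQosPriority(V1_3::Priority modelPriority)
    {
        switch (modelPriority)
        {
            case V1_3::Priority::LOW:    return armnn::QosExecPriority::Low;
            case V1_3::Priority::MEDIUM: return armnn::QosExecPriority::Medium;
            case V1_3::Priority::HIGH:   return armnn::QosExecPriority::High;
            default:                     return armnn::QosExecPriority::Medium;
        }
    }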
@@ -887,6 +1017,13 @@ template Return<V1_3::ErrorStatus> ArmnnPreparedModel_1_3<hal_1_3::HalPolicy>::
         armnn::InputTensors& pInputTensors,
         armnn::OutputTensors& pOutputTensors,
         CallbackContext_1_3 cb);
+
+template void ArmnnPreparedModel_1_3<hal_1_3::HalPolicy>::ScheduleGraphForExecution<CallbackContext_1_3>(
+        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+        std::shared_ptr<armnn::InputTensors>& inputTensors,
+        std::shared_ptr<armnn::OutputTensors>& outputTensors,
+        CallbackContext_1_3 callbackContext,
+        armnn::QosExecPriority priority);
 #endif
 
 } // namespace armnn_driver
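Notify() converts the runtime's InferenceTimingPair into the microsecond fields of V1_2::Timing via MicrosecondsDuration. A self-contained sketch of that conversion; the real helper lives in the driver's utilities, so the exact signature below is an assumption:

    #include <chrono>
    #include <cstdint>

    using TimePoint = std::chrono::steady_clock::time_point;

    // Assumed shape of the driver's MicrosecondsDuration helper: whole
    // microseconds between two steady_clock time points, as logged with %lu.
    static uint64_t MicrosecondsDuration(TimePoint endPoint, TimePoint startPoint)
    {
        return static_cast<uint64_t>(
            std::chrono::duration_cast<std::chrono::microseconds>(endPoint - startPoint).count());
    }

    // Usage as in Notify(): timeTaken.first/.second bound the device window,
    // driverStart/driverEnd bound the full driver call:
    //     timing.timeOnDevice = MicrosecondsDuration(deviceEnd, deviceStart);
    //     timing.timeInDriver = MicrosecondsDuration(driverEnd, driverStart);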