diff options
Diffstat (limited to 'ArmnnPreparedModel_1_2.cpp')
-rw-r--r-- | ArmnnPreparedModel_1_2.cpp | 117 |
1 files changed, 114 insertions, 3 deletions
diff --git a/ArmnnPreparedModel_1_2.cpp b/ArmnnPreparedModel_1_2.cpp index a2148c29..c129fd69 100644 --- a/ArmnnPreparedModel_1_2.cpp +++ b/ArmnnPreparedModel_1_2.cpp @@ -6,6 +6,7 @@ #define LOG_TAG "ArmnnDriver" #include "ArmnnPreparedModel_1_2.hpp" + #include "Utils.hpp" #include <log/log.h> @@ -146,16 +147,23 @@ ArmnnPreparedModel_1_2<HalVersion>::ArmnnPreparedModel_1_2(armnn::NetworkId netw armnn::IRuntime* runtime, const V1_2::Model& model, const std::string& requestInputsAndOutputsDumpDir, - const bool gpuProfilingEnabled) + const bool gpuProfilingEnabled, + const bool asyncModelExecutionEnabled) : m_NetworkId(networkId) , m_Runtime(runtime) , m_Model(model) , m_RequestCount(0) , m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir) , m_GpuProfilingEnabled(gpuProfilingEnabled) + , m_AsyncModelExecutionEnabled(asyncModelExecutionEnabled) { // Enable profiling if required. m_Runtime->GetProfiler(m_NetworkId)->EnableProfiling(m_GpuProfilingEnabled); + + if (asyncModelExecutionEnabled) + { + m_WorkingMemHandle = m_Runtime->CreateWorkingMemHandle(networkId); + } } template<typename HalVersion> @@ -440,7 +448,17 @@ bool ArmnnPreparedModel_1_2<HalVersion>::ExecuteGraph( deviceStart = Now(); } - armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors); + armnn::Status status; + if (m_AsyncModelExecutionEnabled) + { + ALOGW("ArmnnPreparedModel_1_2::ExecuteGraph m_AsyncModelExecutionEnabled true"); + status = m_Runtime->Execute(*m_WorkingMemHandle, inputTensors, outputTensors); + } + else + { + ALOGW("ArmnnPreparedModel_1_2::ExecuteGraph m_AsyncModelExecutionEnabled false"); + status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors); + } if (cb.ctx.measureTimings == V1_2::MeasureTiming::YES) { @@ -567,12 +585,21 @@ Return <V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::Execute(const V1_ {} } - ALOGV("ArmnnPreparedModel_1_2::execute(...) before PostMsg"); // post the request for asynchronous execution CallbackContext_1_2 cb; cb.callback = callback; cb.ctx = ctx; + + if (m_AsyncModelExecutionEnabled) + { + ALOGV("ArmnnPreparedModel_1_2::execute(...) before ScheduleGraphForExecution"); + ScheduleGraphForExecution(memPools, inputTensors, outputTensors, cb); + ALOGV("ArmnnPreparedModel_1_2::execute(...) after ScheduleGraphForExecution"); + return V1_0::ErrorStatus::NONE; + } + + ALOGV("ArmnnPreparedModel_1_2::execute(...) before PostMsg"); m_RequestThread.PostMsg(this, memPools, inputTensors, outputTensors, cb); ALOGV("ArmnnPreparedModel_1_2::execute(...) after PostMsg"); return V1_0::ErrorStatus::NONE; @@ -602,6 +629,84 @@ Return<void> ArmnnPreparedModel_1_2<HalVersion>::configureExecutionBurst( return Void(); } +/// Schedule the graph prepared from the request for execution +template<typename HalVersion> +template<typename CallbackContext> +void ArmnnPreparedModel_1_2<HalVersion>::ScheduleGraphForExecution( + std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, + std::shared_ptr<armnn::InputTensors>& inputTensors, + std::shared_ptr<armnn::OutputTensors>& outputTensors, + CallbackContext callbackContext) +{ + ALOGV("ArmnnPreparedModel_1_2::ScheduleGraphForExecution(...)"); + + DumpTensorsIfRequired("Input", *inputTensors); + + unsigned int outputTensorSize = outputTensors.get()->size(); + std::vector<V1_2::OutputShape> outputShapes(outputTensorSize); + for (unsigned int i = 0; i < outputTensorSize; i++) + { + std::pair<int, armnn::Tensor> outputTensorPair = outputTensors.get()->at(i); + const armnn::Tensor outputTensor = outputTensorPair.second; + const armnn::TensorInfo outputTensorInfo = outputTensor.GetInfo(); + + outputShapes[i] = ComputeShape(outputTensorInfo); + } + + auto tpCb = std::make_shared< + ArmnnThreadPoolCallback_1_2<CallbackContext_1_2>>(this, + pMemPools, + outputShapes, + inputTensors, + outputTensors, + callbackContext); + + m_Runtime->Schedule(m_NetworkId, + *tpCb->m_InputTensors, + *tpCb->m_OutputTensors, + armnn::QosExecPriority::High, + tpCb); + ALOGV("ArmnnPreparedModel_1_2::ScheduleGraphForExecution end"); +} + +template<typename HalVersion> +template <typename CallbackContext> +void ArmnnPreparedModel_1_2<HalVersion>::ArmnnThreadPoolCallback_1_2<CallbackContext>::Notify( + armnn::Status status, armnn::InferenceTimingPair timeTaken) +{ + ALOGV("ArmnnPreparedModel_1_2::ArmnnThreadPoolCallback_1_2 Notify"); + + TimePoint driverEnd; + + CommitPools(*m_MemPools); + + m_Model->DumpTensorsIfRequired("Output", *m_OutputTensors); + + if (status != armnn::Status::Success) + { + ALOGW("ArmnnThreadPoolCallback::Notify EnqueueWorkload failed"); + m_CallbackContext.callback( + V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel::ExecuteGraph"); + return; + } + + if (m_CallbackContext.ctx.measureTimings == V1_2::MeasureTiming::YES) + { + driverEnd = std::chrono::steady_clock::now(); + V1_2::Timing timing; + timing.timeOnDevice = MicrosecondsDuration(timeTaken.second, timeTaken.first); + timing.timeInDriver = MicrosecondsDuration(driverEnd, m_CallbackContext.ctx.driverStart); + ALOGV("ArmnnPreparedModel_1_2::execute timing - Device = %lu Driver = %lu", timing.timeOnDevice, + timing.timeInDriver); + m_CallbackContext.callback( + V1_0::ErrorStatus::NONE, m_OutputShapes, timing, "ArmnnPreparedModel_1_2::ExecuteGraph"); + } else { + m_CallbackContext.callback( + V1_0::ErrorStatus::NONE, m_OutputShapes, g_NoTiming, "ArmnnPreparedModel_1_2::ExecuteGraph"); + } + return; +} + #if defined(ARMNN_ANDROID_NN_V1_2) || defined(ARMNN_ANDROID_NN_V1_3) template class ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>; template bool ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>::ExecuteGraph<CallbackContext_1_2>( @@ -609,6 +714,12 @@ template bool ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>::ExecuteGraph<CallbackC armnn::InputTensors& pInputTensors, armnn::OutputTensors& pOutputTensors, CallbackContext_1_2 cb); + +template void ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>::ScheduleGraphForExecution<CallbackContext_1_2>( + std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, + std::shared_ptr<armnn::InputTensors>& inputTensors, + std::shared_ptr<armnn::OutputTensors>& outputTensors, + CallbackContext_1_2 callbackContext); #endif } // namespace armnn_driver