From f364d5391b08e9071cd965f5765385ec9156b652 Mon Sep 17 00:00:00 2001
From: Finn Williams
Date: Wed, 9 Jun 2021 17:07:33 +0100
Subject: IVGCVSW-6062 Rework the async threadpool

!android-nn-driver:5802

 * Extract the threadpool from LoadedNetwork/Runtime
 * Refactor the threadpool to handle multiple networks
 * Trim IAsyncExecutionCallback and add an InferenceId to AsyncExecutionCallback
 * Add AsyncCallbackManager class

Signed-off-by: Finn Williams
Change-Id: I36aa2ad29c16bc10ee0706adfeb6b27f60012afb
---
 tests/ExecuteNetwork/ExecuteNetwork.cpp | 27 +++++++++++------------
 tests/InferenceModel.hpp                | 38 ++++++++++++++++++++++-----------
 2 files changed, 39 insertions(+), 26 deletions(-)
(limited to 'tests')

diff --git a/tests/ExecuteNetwork/ExecuteNetwork.cpp b/tests/ExecuteNetwork/ExecuteNetwork.cpp
index e8d5b1860c..48577c9990 100644
--- a/tests/ExecuteNetwork/ExecuteNetwork.cpp
+++ b/tests/ExecuteNetwork/ExecuteNetwork.cpp
@@ -445,14 +445,8 @@ int MainImpl(const ExecuteNetworkParams& params,
     try
     {
         ARMNN_LOG(info) << "Asynchronous execution with Arm NN thread pool... \n";
-        std::vector<armnn::experimental::IAsyncExecutionCallbackPtr> callbacks;
-
-        // Create callbacks that will be checked post scheduling
-        for (size_t i = 0; i < params.m_SimultaneousIterations; ++i)
-        {
-            // Point to ArmNN example implementation of AsyncExecutionCallback
-            callbacks.emplace_back(std::make_shared<armnn::experimental::AsyncExecutionCallback>());
-        }
+        armnn::AsyncCallbackManager callbackManager;
+        std::unordered_map<armnn::InferenceId, std::vector<TContainer>&> inferenceOutputMap;

         // Declare the latest and earliest inference times here to be used when calculating overall time
         std::chrono::high_resolution_clock::time_point earliestStartTime;
@@ -461,15 +455,19 @@ int MainImpl(const ExecuteNetworkParams& params,
         // For the asynchronous execution, we are adding a pool of working memory handles (1 per thread) in the
         // LoadedNetwork with each scheduled inference having a specific priority
-        for (size_t i = 0; i < callbacks.size(); ++i)
+        for (size_t i = 0; i < params.m_SimultaneousIterations; ++i)
         {
-            model.RunAsync(inputs[i], outputs[i], callbacks[i]);
+            std::shared_ptr<armnn::AsyncExecutionCallback> cb = callbackManager.GetNewCallback();
+            inferenceOutputMap.insert({cb->GetInferenceId(), outputs[i]});
+            model.RunAsync(inputs[i], outputs[i], cb);
         }

         // Check the results
         unsigned int j = 0;
-        for (armnn::experimental::IAsyncExecutionCallbackPtr cb : callbacks)
+        for (size_t iteration = 0; iteration < params.m_SimultaneousIterations; ++iteration)
         {
+            auto cb = callbackManager.GetNotifiedCallback();
+
             // Get the results
             auto endTime = time_point_cast<std::chrono::milliseconds>(cb->GetEndTime());
             auto startTime = time_point_cast<std::chrono::milliseconds>(cb->GetStartTime());
@@ -507,7 +505,7 @@ int MainImpl(const ExecuteNetworkParams& params,
                                                      infoOut,
                                                      outputTensorFile,
                                                      params.m_DequantizeOutput);
-                mapbox::util::apply_visitor(printer, outputs[j][i]);
+                mapbox::util::apply_visitor(printer, inferenceOutputMap.at(cb->GetInferenceId())[i]);
             }

             ARMNN_LOG(info) << "\nInference time: " << std::setprecision(2)
@@ -549,7 +547,7 @@ int MainImpl(const ExecuteNetworkParams& params,
     try
     {
         ARMNN_LOG(info) << "Asynchronous Execution with std::launch:async... \n";
\n"; - std::vector>>> inferenceResults; inferenceResults.reserve(params.m_SimultaneousIterations); @@ -567,9 +565,10 @@ int MainImpl(const ExecuteNetworkParams& params, for (unsigned int i = 0; i < params.m_SimultaneousIterations; ++i) { armnn::experimental::IWorkingMemHandle& workingMemHandleRef = *workingMemHandles[i].get(); + inferenceResults.push_back(std::async( std::launch::async, [&model, &workingMemHandleRef, &inputs, &outputs, i]() { - return model.RunAsync(workingMemHandleRef, inputs[i], outputs[i]); + return model.RunAsync(workingMemHandleRef, inputs[i], outputs[i], i); } )); } diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp index 9d6096a3eb..3eb1e6a9e7 100644 --- a/tests/InferenceModel.hpp +++ b/tests/InferenceModel.hpp @@ -6,6 +6,7 @@ #pragma once #include +#include #include #include #include @@ -415,7 +416,7 @@ public: armnn::IRuntime::CreationOptions options; options.m_EnableGpuProfiling = m_EnableProfiling; options.m_DynamicBackendsPath = m_DynamicBackendsPath; - m_Runtime = std::move(armnn::IRuntime::Create(options)); + m_Runtime = armnn::IRuntime::Create(options); } std::string invalidBackends; @@ -484,13 +485,25 @@ public: const auto loading_start_time = armnn::GetTimeNow(); armnn::INetworkProperties networkProperties(params.m_AsyncEnabled, armnn::MemorySource::Undefined, - armnn::MemorySource::Undefined, - params.m_ThreadPoolSize); + armnn::MemorySource::Undefined); std::string errorMessage; ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, std::move(optNet), errorMessage, networkProperties); ARMNN_LOG(info) << "Network loading time: " << std::setprecision(2) << std::fixed << armnn::GetTimeDuration(loading_start_time).count() << " ms\n"; + + if (params.m_AsyncEnabled && params.m_ThreadPoolSize > 0) + { + std::vector> memHandles; + for (size_t i = 0; i < params.m_ThreadPoolSize; ++i) + { + memHandles.emplace_back(m_Runtime->CreateWorkingMemHandle(m_NetworkIdentifier)); + } + + m_Threadpool = std::make_unique(params.m_ThreadPoolSize, + m_Runtime.get(), + memHandles); + } } if (ret == armnn::Status::Failure) @@ -579,10 +592,11 @@ public: } } - std::tuple> RunAsync( + std::tuple> RunAsync( armnn::experimental::IWorkingMemHandle& workingMemHandleRef, const std::vector& inputContainers, - std::vector& outputContainers) + std::vector& outputContainers, + unsigned int inferenceID) { for (unsigned int i = 0; i < outputContainers.size(); ++i) { @@ -614,7 +628,6 @@ public: armnn::Status ret = m_Runtime->Execute(workingMemHandleRef, MakeInputTensors(inputContainers), MakeOutputTensors(outputContainers)); - auto inferenceID = workingMemHandleRef.GetInferenceId(); const auto duration = armnn::GetTimeDuration(start_time); @@ -638,7 +651,7 @@ public: void RunAsync(const std::vector& inputContainers, std::vector& outputContainers, - armnn::experimental::IAsyncExecutionCallbackPtr cb) + std::shared_ptr cb) { for (unsigned int i = 0; i < outputContainers.size(); ++i) { @@ -664,11 +677,11 @@ public: profiler->EnableProfiling(m_EnableProfiling); } - m_Runtime->Schedule(m_NetworkIdentifier, - MakeInputTensors(inputContainers), - MakeOutputTensors(outputContainers), - armnn::QosExecPriority::Medium, - cb); + m_Threadpool->Schedule(m_NetworkIdentifier, + MakeInputTensors(inputContainers), + MakeOutputTensors(outputContainers), + armnn::QosExecPriority::Medium, + cb); // if profiling is enabled print out the results if (profiler && profiler->IsProfilingEnabled()) @@ -731,6 +744,7 @@ public: private: armnn::NetworkId m_NetworkIdentifier; std::shared_ptr m_Runtime; + 
+    std::unique_ptr<armnn::Threadpool> m_Threadpool;

     std::vector<armnn::BindingPointInfo> m_InputBindings;
     std::vector<armnn::BindingPointInfo> m_OutputBindings;
--
cgit v1.2.1
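
The patch only shows the harness-side call sites, so here is a minimal sketch of how the reworked pieces fit together: a Threadpool built from pre-created working memory handles, callbacks handed out by AsyncCallbackManager, and results matched back to their buffers via the callback's InferenceId. The function name, the single float input/output with binding id 0, and the exact header/namespace of AsyncCallbackManager and InferenceId are illustrative assumptions, not taken from the patch.

```cpp
// Sketch only: assumes a network is already loaded for async execution
// (INetworkProperties with asyncEnabled = true), as in InferenceModel.hpp above.
#include <armnn/ArmNN.hpp>
#include <armnn/Threadpool.hpp>

#include <memory>
#include <unordered_map>
#include <vector>

void RunInferencesWithThreadpool(armnn::IRuntime* runtime,              // hypothetical helper
                                 armnn::NetworkId netId,
                                 std::vector<std::vector<float>>& inputs,
                                 std::vector<std::vector<float>>& outputs,
                                 size_t numThreads)
{
    // The pool now lives outside the runtime: it is built from one
    // pre-created working memory handle per worker thread.
    std::vector<std::shared_ptr<armnn::experimental::IWorkingMemHandle>> memHandles;
    for (size_t i = 0; i < numThreads; ++i)
    {
        memHandles.emplace_back(runtime->CreateWorkingMemHandle(netId));
    }
    armnn::Threadpool threadpool(numThreads, runtime, memHandles);

    // Callbacks come from the new AsyncCallbackManager; the InferenceId they
    // carry is the key used to route a completed inference back to its buffers.
    armnn::AsyncCallbackManager callbackManager;
    std::unordered_map<armnn::InferenceId, size_t> slotForInference; // InferenceId type assumed

    for (size_t i = 0; i < inputs.size(); ++i)
    {
        std::shared_ptr<armnn::AsyncExecutionCallback> cb = callbackManager.GetNewCallback();
        slotForInference.insert({cb->GetInferenceId(), i});

        armnn::InputTensors inputTensors
            {{0, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), inputs[i].data())}};
        armnn::OutputTensors outputTensors
            {{0, armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputs[i].data())}};

        threadpool.Schedule(netId, inputTensors, outputTensors,
                            armnn::QosExecPriority::Medium, cb);
    }

    // Collect results in completion order, not submission order.
    for (size_t i = 0; i < inputs.size(); ++i)
    {
        std::shared_ptr<armnn::AsyncExecutionCallback> cb = callbackManager.GetNotifiedCallback();
        size_t slot = slotForInference.at(cb->GetInferenceId());
        auto latency = cb->GetEndTime() - cb->GetStartTime();
        (void)slot;     // outputs[slot] now holds this inference's results
        (void)latency;  // per-inference time, as ExecuteNetwork reports it
    }
}
```

Because GetNotifiedCallback hands back callbacks in completion order, ExecuteNetwork now keeps an InferenceId-keyed map of output containers instead of indexing them by submission order, which is exactly the change visible in the first hunk above.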
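For comparison, a similar sketch of the std::launch::async mode the patch also touches: each inference gets its own working memory handle and calls IRuntime::Execute directly, and the caller-supplied index plays the role of the new inferenceID parameter on InferenceModel::RunAsync, since the working memory handle no longer supplies an id. Function and variable names are illustrative; the tensor-setup assumptions are the same as in the previous sketch.

```cpp
// Sketch only: async-enabled network already loaded under `netId`.
#include <armnn/ArmNN.hpp>

#include <future>
#include <memory>
#include <vector>

void RunInferencesWithStdAsync(armnn::IRuntime* runtime,                // hypothetical helper
                               armnn::NetworkId netId,
                               std::vector<std::vector<float>>& inputs,
                               std::vector<std::vector<float>>& outputs)
{
    std::vector<std::future<unsigned int>> results;
    results.reserve(inputs.size());

    for (unsigned int i = 0; i < inputs.size(); ++i)
    {
        // One working memory handle per in-flight inference.
        std::shared_ptr<armnn::experimental::IWorkingMemHandle> handle =
            runtime->CreateWorkingMemHandle(netId);

        results.push_back(std::async(std::launch::async,
            [runtime, netId, handle, &inputs, &outputs, i]() -> unsigned int
            {
                armnn::InputTensors inputTensors
                    {{0, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), inputs[i].data())}};
                armnn::OutputTensors outputTensors
                    {{0, armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputs[i].data())}};

                // Execute() blocks this task until the inference finishes.
                runtime->Execute(*handle, inputTensors, outputTensors);
                return i; // return the caller-assigned id, as RunAsync now does with inferenceID
            }));
    }

    for (auto& result : results)
    {
        unsigned int finishedId = result.get(); // outputs[finishedId] is now populated
        (void)finishedId;
    }
}
```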