path: root/tests
author     Finn Williams <Finn.Williams@arm.com>    2021-06-09 17:07:33 +0100
committer  Finn Williams <Finn.Williams@arm.com>    2021-06-23 17:14:53 +0100
commit     f364d5391b08e9071cd965f5765385ec9156b652 (patch)
tree       1ea93ed574a3eb51f5a1f4bb08dc1ad18aa1c6a2 /tests
parent     7a00eaa6ecf121623823b1951c0e6c9093271adf (diff)
download   armnn-f364d5391b08e9071cd965f5765385ec9156b652.tar.gz
IVGCVSW-6062 Rework the async threadpool
!android-nn-driver:5802

 * Extract the threadpool from LoadedNetwork/Runtime
 * Refactor the threadpool to handle multiple networks
 * Trim IAsyncExecutionCallback and add an InferenceId to AsyncExecutionCallback
 * Add AsyncCallbackManager class

Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: I36aa2ad29c16bc10ee0706adfeb6b27f60012afb
Diffstat (limited to 'tests')
-rw-r--r--  tests/ExecuteNetwork/ExecuteNetwork.cpp   27
-rw-r--r--  tests/InferenceModel.hpp                  38
2 files changed, 39 insertions(+), 26 deletions(-)
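
The gist of the ExecuteNetwork change is that per-iteration callbacks are no longer built by hand; an AsyncCallbackManager hands out callbacks and reports them back as inferences complete. The following is a minimal sketch of that flow, mirroring the names visible in the diff below (GetNewCallback, GetNotifiedCallback, GetInferenceId, RunAsync). The ScheduleAndCollect helper and the Model/TContainer template parameters are placeholders for illustration only, and a std::reference_wrapper is used where the patch maps inference ids to output references.

#include <armnn/ArmNN.hpp>   // assumed to provide AsyncCallbackManager / InferenceId via the patch
#include <functional>
#include <memory>
#include <unordered_map>
#include <vector>

template <typename Model, typename TContainer>
void ScheduleAndCollect(Model& model,
                        const std::vector<std::vector<TContainer>>& inputs,
                        std::vector<std::vector<TContainer>>& outputs)
{
    armnn::AsyncCallbackManager callbackManager;

    // Remember which output buffers belong to which inference, since callbacks
    // are collected in completion order rather than submission order.
    std::unordered_map<armnn::InferenceId,
                       std::reference_wrapper<std::vector<TContainer>>> inferenceOutputMap;

    // Schedule every iteration with a fresh callback from the manager.
    for (size_t i = 0; i < outputs.size(); ++i)
    {
        std::shared_ptr<armnn::AsyncExecutionCallback> cb = callbackManager.GetNewCallback();
        inferenceOutputMap.emplace(cb->GetInferenceId(), std::ref(outputs[i]));
        model.RunAsync(inputs[i], outputs[i], cb);
    }

    // Wait for each inference to finish and look up its outputs by id.
    for (size_t i = 0; i < outputs.size(); ++i)
    {
        auto cb = callbackManager.GetNotifiedCallback();
        std::vector<TContainer>& results  = inferenceOutputMap.at(cb->GetInferenceId());
        auto                     duration = cb->GetEndTime() - cb->GetStartTime();
        // ... check `results` and report `duration` here ...
    }
}
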
diff --git a/tests/ExecuteNetwork/ExecuteNetwork.cpp b/tests/ExecuteNetwork/ExecuteNetwork.cpp
index e8d5b1860c..48577c9990 100644
--- a/tests/ExecuteNetwork/ExecuteNetwork.cpp
+++ b/tests/ExecuteNetwork/ExecuteNetwork.cpp
@@ -445,14 +445,8 @@ int MainImpl(const ExecuteNetworkParams& params,
try
{
ARMNN_LOG(info) << "Asynchronous execution with Arm NN thread pool... \n";
- std::vector<armnn::experimental::IAsyncExecutionCallbackPtr> callbacks;
-
- // Create callbacks that will be checked post scheduling
- for (size_t i = 0; i < params.m_SimultaneousIterations; ++i)
- {
- // Point to ArmNN example implementation of AsyncExecutionCallback
- callbacks.emplace_back(std::make_shared<armnn::experimental::AsyncExecutionCallback>());
- }
+ armnn::AsyncCallbackManager callbackManager;
+ std::unordered_map<armnn::InferenceId, std::vector<TContainer>&> inferenceOutputMap;
// Declare the latest and earliest inference times here to be used when calculating overall time
std::chrono::high_resolution_clock::time_point earliestStartTime;
@@ -461,15 +455,19 @@ int MainImpl(const ExecuteNetworkParams& params,
// For the asynchronous execution, we are adding a pool of working memory handles (1 per thread) in the
// LoadedNetwork with each scheduled inference having a specific priority
- for (size_t i = 0; i < callbacks.size(); ++i)
+ for (size_t i = 0; i < params.m_SimultaneousIterations; ++i)
{
- model.RunAsync(inputs[i], outputs[i], callbacks[i]);
+ std::shared_ptr<armnn::AsyncExecutionCallback> cb = callbackManager.GetNewCallback();
+ inferenceOutputMap.insert({cb->GetInferenceId(), outputs[i]});
+ model.RunAsync(inputs[i], outputs[i], cb);
}
// Check the results
unsigned int j = 0;
- for (armnn::experimental::IAsyncExecutionCallbackPtr cb : callbacks)
+ for (size_t iteration = 0; iteration < params.m_SimultaneousIterations; ++iteration)
{
+ auto cb = callbackManager.GetNotifiedCallback();
+
// Get the results
auto endTime = time_point_cast<std::chrono::milliseconds>(cb->GetEndTime());
auto startTime = time_point_cast<std::chrono::milliseconds>(cb->GetStartTime());
@@ -507,7 +505,7 @@ int MainImpl(const ExecuteNetworkParams& params,
infoOut,
outputTensorFile,
params.m_DequantizeOutput);
- mapbox::util::apply_visitor(printer, outputs[j][i]);
+ mapbox::util::apply_visitor(printer, inferenceOutputMap.at(cb->GetInferenceId())[i]);
}
ARMNN_LOG(info) << "\nInference time: " << std::setprecision(2)
@@ -549,7 +547,7 @@ int MainImpl(const ExecuteNetworkParams& params,
try
{
ARMNN_LOG(info) << "Asynchronous Execution with std::launch:async... \n";
- std::vector<std::future<std::tuple<armnn::profiling::ProfilingGuid,
+ std::vector<std::future<std::tuple<unsigned int,
std::chrono::duration<double, std::milli>>>> inferenceResults;
inferenceResults.reserve(params.m_SimultaneousIterations);
@@ -567,9 +565,10 @@ int MainImpl(const ExecuteNetworkParams& params,
for (unsigned int i = 0; i < params.m_SimultaneousIterations; ++i)
{
armnn::experimental::IWorkingMemHandle& workingMemHandleRef = *workingMemHandles[i].get();
+
inferenceResults.push_back(std::async(
std::launch::async, [&model, &workingMemHandleRef, &inputs, &outputs, i]() {
- return model.RunAsync(workingMemHandleRef, inputs[i], outputs[i]);
+ return model.RunAsync(workingMemHandleRef, inputs[i], outputs[i], i);
}
));
}
diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp
index 9d6096a3eb..3eb1e6a9e7 100644
--- a/tests/InferenceModel.hpp
+++ b/tests/InferenceModel.hpp
@@ -6,6 +6,7 @@
#pragma once
#include <armnn/ArmNN.hpp>
+#include <armnn/Threadpool.hpp>
#include <armnn/Logging.hpp>
#include <armnn/utility/Timer.hpp>
#include <armnn/BackendRegistry.hpp>
@@ -415,7 +416,7 @@ public:
armnn::IRuntime::CreationOptions options;
options.m_EnableGpuProfiling = m_EnableProfiling;
options.m_DynamicBackendsPath = m_DynamicBackendsPath;
- m_Runtime = std::move(armnn::IRuntime::Create(options));
+ m_Runtime = armnn::IRuntime::Create(options);
}
std::string invalidBackends;
@@ -484,13 +485,25 @@ public:
const auto loading_start_time = armnn::GetTimeNow();
armnn::INetworkProperties networkProperties(params.m_AsyncEnabled,
armnn::MemorySource::Undefined,
- armnn::MemorySource::Undefined,
- params.m_ThreadPoolSize);
+ armnn::MemorySource::Undefined);
std::string errorMessage;
ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, std::move(optNet), errorMessage, networkProperties);
ARMNN_LOG(info) << "Network loading time: " << std::setprecision(2)
<< std::fixed << armnn::GetTimeDuration(loading_start_time).count() << " ms\n";
+
+ if (params.m_AsyncEnabled && params.m_ThreadPoolSize > 0)
+ {
+ std::vector<std::shared_ptr<armnn::IWorkingMemHandle>> memHandles;
+ for (size_t i = 0; i < params.m_ThreadPoolSize; ++i)
+ {
+ memHandles.emplace_back(m_Runtime->CreateWorkingMemHandle(m_NetworkIdentifier));
+ }
+
+ m_Threadpool = std::make_unique<armnn::Threadpool>(params.m_ThreadPoolSize,
+ m_Runtime.get(),
+ memHandles);
+ }
}
if (ret == armnn::Status::Failure)
@@ -579,10 +592,11 @@ public:
}
}
- std::tuple<armnn::profiling::ProfilingGuid, std::chrono::duration<double, std::milli>> RunAsync(
+ std::tuple<unsigned int, std::chrono::duration<double, std::milli>> RunAsync(
armnn::experimental::IWorkingMemHandle& workingMemHandleRef,
const std::vector<TContainer>& inputContainers,
- std::vector<TContainer>& outputContainers)
+ std::vector<TContainer>& outputContainers,
+ unsigned int inferenceID)
{
for (unsigned int i = 0; i < outputContainers.size(); ++i)
{
@@ -614,7 +628,6 @@ public:
armnn::Status ret = m_Runtime->Execute(workingMemHandleRef,
MakeInputTensors(inputContainers),
MakeOutputTensors(outputContainers));
- auto inferenceID = workingMemHandleRef.GetInferenceId();
const auto duration = armnn::GetTimeDuration(start_time);
@@ -638,7 +651,7 @@ public:
void RunAsync(const std::vector<TContainer>& inputContainers,
std::vector<TContainer>& outputContainers,
- armnn::experimental::IAsyncExecutionCallbackPtr cb)
+ std::shared_ptr<armnn::IAsyncExecutionCallback> cb)
{
for (unsigned int i = 0; i < outputContainers.size(); ++i)
{
@@ -664,11 +677,11 @@ public:
profiler->EnableProfiling(m_EnableProfiling);
}
- m_Runtime->Schedule(m_NetworkIdentifier,
- MakeInputTensors(inputContainers),
- MakeOutputTensors(outputContainers),
- armnn::QosExecPriority::Medium,
- cb);
+ m_Threadpool->Schedule(m_NetworkIdentifier,
+ MakeInputTensors(inputContainers),
+ MakeOutputTensors(outputContainers),
+ armnn::QosExecPriority::Medium,
+ cb);
// if profiling is enabled print out the results
if (profiler && profiler->IsProfilingEnabled())
@@ -731,6 +744,7 @@ public:
private:
armnn::NetworkId m_NetworkIdentifier;
std::shared_ptr<armnn::IRuntime> m_Runtime;
+ std::unique_ptr<armnn::Threadpool> m_Threadpool;
std::vector<armnn::BindingPointInfo> m_InputBindings;
std::vector<armnn::BindingPointInfo> m_OutputBindings;
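
For reference, the InferenceModel.hpp changes above boil down to the following load-time setup and scheduling calls. This is a sketch assembled from the hunks in this patch, assuming the armnn::Threadpool constructor and Schedule() signature shown there; threadPoolSize, runtime, networkId, inputTensors, outputTensors and callback stand in for the corresponding InferenceModel members and arguments.

// One working memory handle per worker thread, created at network load time.
std::vector<std::shared_ptr<armnn::IWorkingMemHandle>> memHandles;
for (size_t i = 0; i < threadPoolSize; ++i)
{
    memHandles.emplace_back(runtime->CreateWorkingMemHandle(networkId));
}

// The thread pool now lives in the client (InferenceModel) rather than in the
// runtime/LoadedNetwork.
auto threadpool = std::make_unique<armnn::Threadpool>(threadPoolSize,
                                                      runtime.get(),
                                                      memHandles);

// Each asynchronous inference is then handed to the pool instead of going
// through m_Runtime->Schedule(...).
threadpool->Schedule(networkId,
                     inputTensors,
                     outputTensors,
                     armnn::QosExecPriority::Medium,
                     callback);
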