From b4b3ac91990eb5deaffca2300319f2ddf7aa0886 Mon Sep 17 00:00:00 2001 From: Kevin May Date: Fri, 21 May 2021 16:42:21 +0100 Subject: IVGCVSW-6009 Integrate threadpool into ExNet * Remove concurrent flag from ExecuteNetwork as it is possible to deduce if SimultaneousIterations > 1 * Add void RunAsync() * Refactor some unit tests Change-Id: I7021d4821b0e460470908294cbd9462850e8b361 Signed-off-by: Keith Davis Signed-off-by: Kevin May --- tests/InferenceModel.hpp | 46 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) (limited to 'tests/InferenceModel.hpp') diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp index 3429598249..7c51011a22 100644 --- a/tests/InferenceModel.hpp +++ b/tests/InferenceModel.hpp @@ -102,6 +102,7 @@ struct Params unsigned int m_NumberOfThreads; std::string m_MLGOTuningFilePath; bool m_AsyncEnabled; + size_t m_ThreadPoolSize; Params() @@ -120,6 +121,7 @@ struct Params , m_NumberOfThreads(0) , m_MLGOTuningFilePath("") , m_AsyncEnabled(false) + , m_ThreadPoolSize(1) {} }; @@ -481,7 +483,8 @@ public: const auto loading_start_time = armnn::GetTimeNow(); armnn::INetworkProperties networkProperties(params.m_AsyncEnabled, armnn::MemorySource::Undefined, - armnn::MemorySource::Undefined); + armnn::MemorySource::Undefined, + params.m_ThreadPoolSize); std::string errorMessage; ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, std::move(optNet), errorMessage, networkProperties); @@ -632,6 +635,47 @@ public: } } + void RunAsync(const std::vector& inputContainers, + std::vector& outputContainers, + armnn::experimental::IAsyncExecutionCallbackPtr cb) + { + for (unsigned int i = 0; i < outputContainers.size(); ++i) + { + const unsigned int expectedOutputDataSize = GetOutputSize(i); + + mapbox::util::apply_visitor([expectedOutputDataSize, i](auto&& value) + { + const unsigned int actualOutputDataSize = armnn::numeric_cast(value.size()); + if (actualOutputDataSize < expectedOutputDataSize) + { + unsigned int outputIndex = i; + throw armnn::Exception( + fmt::format("Not enough data for output #{0}: expected " + "{1} elements, got {2}", outputIndex, expectedOutputDataSize, actualOutputDataSize)); + } + }, + outputContainers[i]); + } + + std::shared_ptr profiler = m_Runtime->GetProfiler(m_NetworkIdentifier); + if (profiler) + { + profiler->EnableProfiling(m_EnableProfiling); + } + + m_Runtime->Schedule(m_NetworkIdentifier, + MakeInputTensors(inputContainers), + MakeOutputTensors(outputContainers), + armnn::QosExecPriority::Medium, + cb); + + // if profiling is enabled print out the results + if (profiler && profiler->IsProfilingEnabled()) + { + profiler->Print(std::cout); + } + } + const armnn::BindingPointInfo& GetInputBindingInfo(unsigned int inputIndex = 0u) const { CheckInputIndexIsValid(inputIndex); -- cgit v1.2.1