aboutsummaryrefslogtreecommitdiff
path: root/tests/InferenceModel.hpp
diff options
context:
space:
mode:
authorKevin May <kevin.may@arm.com>2021-05-21 16:42:21 +0100
committerKevin May <kevin.may@arm.com>2021-05-26 11:56:54 +0000
commitb4b3ac91990eb5deaffca2300319f2ddf7aa0886 (patch)
treee480826fe604d652877459ce4bbf4314a461e4b2 /tests/InferenceModel.hpp
parent401c1c3f973da1a2e2cef7f88a5aac2cf295fac7 (diff)
downloadarmnn-b4b3ac91990eb5deaffca2300319f2ddf7aa0886.tar.gz
IVGCVSW-6009 Integrate threadpool into ExNet
* Remove concurrent flag from ExecuteNetwork as it is possible to deduce
  if SimultaneousIterations > 1
* Add void RunAsync()
* Refactor some unit tests

Change-Id: I7021d4821b0e460470908294cbd9462850e8b361
Signed-off-by: Keith Davis <keith.davis@arm.com>
Signed-off-by: Kevin May <kevin.may@arm.com>
Diffstat (limited to 'tests/InferenceModel.hpp')
-rw-r--r--tests/InferenceModel.hpp46
1 file changed, 45 insertions, 1 deletion
diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp
index 3429598249..7c51011a22 100644
--- a/tests/InferenceModel.hpp
+++ b/tests/InferenceModel.hpp
@@ -102,6 +102,7 @@ struct Params
unsigned int m_NumberOfThreads;
std::string m_MLGOTuningFilePath;
bool m_AsyncEnabled;
+ size_t m_ThreadPoolSize;
Params()
@@ -120,6 +121,7 @@ struct Params
, m_NumberOfThreads(0)
, m_MLGOTuningFilePath("")
, m_AsyncEnabled(false)
+ , m_ThreadPoolSize(1)
{}
};
@@ -481,7 +483,8 @@ public:
const auto loading_start_time = armnn::GetTimeNow();
armnn::INetworkProperties networkProperties(params.m_AsyncEnabled,
armnn::MemorySource::Undefined,
- armnn::MemorySource::Undefined);
+ armnn::MemorySource::Undefined,
+ params.m_ThreadPoolSize);
std::string errorMessage;
ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, std::move(optNet), errorMessage, networkProperties);
@@ -632,6 +635,47 @@ public:
}
}
+ void RunAsync(const std::vector<TContainer>& inputContainers,
+ std::vector<TContainer>& outputContainers,
+ armnn::experimental::IAsyncExecutionCallbackPtr cb)
+ {
+ for (unsigned int i = 0; i < outputContainers.size(); ++i)
+ {
+ const unsigned int expectedOutputDataSize = GetOutputSize(i);
+
+ mapbox::util::apply_visitor([expectedOutputDataSize, i](auto&& value)
+ {
+ const unsigned int actualOutputDataSize = armnn::numeric_cast<unsigned int>(value.size());
+ if (actualOutputDataSize < expectedOutputDataSize)
+ {
+ unsigned int outputIndex = i;
+ throw armnn::Exception(
+ fmt::format("Not enough data for output #{0}: expected "
+ "{1} elements, got {2}", outputIndex, expectedOutputDataSize, actualOutputDataSize));
+ }
+ },
+ outputContainers[i]);
+ }
+
+ std::shared_ptr<armnn::IProfiler> profiler = m_Runtime->GetProfiler(m_NetworkIdentifier);
+ if (profiler)
+ {
+ profiler->EnableProfiling(m_EnableProfiling);
+ }
+
+ m_Runtime->Schedule(m_NetworkIdentifier,
+ MakeInputTensors(inputContainers),
+ MakeOutputTensors(outputContainers),
+ armnn::QosExecPriority::Medium,
+ cb);
+
+ // if profiling is enabled print out the results
+ if (profiler && profiler->IsProfilingEnabled())
+ {
+ profiler->Print(std::cout);
+ }
+ }
+
const armnn::BindingPointInfo& GetInputBindingInfo(unsigned int inputIndex = 0u) const
{
CheckInputIndexIsValid(inputIndex);