diff options
Diffstat (limited to 'tests/InferenceModel.hpp')
-rw-r--r-- | tests/InferenceModel.hpp | 46 |
1 files changed, 45 insertions, 1 deletions
diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp index 3429598249..7c51011a22 100644 --- a/tests/InferenceModel.hpp +++ b/tests/InferenceModel.hpp @@ -102,6 +102,7 @@ struct Params unsigned int m_NumberOfThreads; std::string m_MLGOTuningFilePath; bool m_AsyncEnabled; + size_t m_ThreadPoolSize; Params() @@ -120,6 +121,7 @@ struct Params , m_NumberOfThreads(0) , m_MLGOTuningFilePath("") , m_AsyncEnabled(false) + , m_ThreadPoolSize(1) {} }; @@ -481,7 +483,8 @@ public: const auto loading_start_time = armnn::GetTimeNow(); armnn::INetworkProperties networkProperties(params.m_AsyncEnabled, armnn::MemorySource::Undefined, - armnn::MemorySource::Undefined); + armnn::MemorySource::Undefined, + params.m_ThreadPoolSize); std::string errorMessage; ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, std::move(optNet), errorMessage, networkProperties); @@ -632,6 +635,47 @@ public: } } + void RunAsync(const std::vector<TContainer>& inputContainers, + std::vector<TContainer>& outputContainers, + armnn::experimental::IAsyncExecutionCallbackPtr cb) + { + for (unsigned int i = 0; i < outputContainers.size(); ++i) + { + const unsigned int expectedOutputDataSize = GetOutputSize(i); + + mapbox::util::apply_visitor([expectedOutputDataSize, i](auto&& value) + { + const unsigned int actualOutputDataSize = armnn::numeric_cast<unsigned int>(value.size()); + if (actualOutputDataSize < expectedOutputDataSize) + { + unsigned int outputIndex = i; + throw armnn::Exception( + fmt::format("Not enough data for output #{0}: expected " + "{1} elements, got {2}", outputIndex, expectedOutputDataSize, actualOutputDataSize)); + } + }, + outputContainers[i]); + } + + std::shared_ptr<armnn::IProfiler> profiler = m_Runtime->GetProfiler(m_NetworkIdentifier); + if (profiler) + { + profiler->EnableProfiling(m_EnableProfiling); + } + + m_Runtime->Schedule(m_NetworkIdentifier, + MakeInputTensors(inputContainers), + MakeOutputTensors(outputContainers), + armnn::QosExecPriority::Medium, + cb); + + // if profiling is enabled print out the results + if (profiler && profiler->IsProfilingEnabled()) + { + profiler->Print(std::cout); + } + } + const armnn::BindingPointInfo& GetInputBindingInfo(unsigned int inputIndex = 0u) const { CheckInputIndexIsValid(inputIndex); |