Diffstat (limited to 'tests/InferenceModel.hpp')
 tests/InferenceModel.hpp | 38 ++++++++++++++++++++++++++------------
 1 file changed, 26 insertions(+), 12 deletions(-)
diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp
index 9d6096a3eb..3eb1e6a9e7 100644
--- a/tests/InferenceModel.hpp
+++ b/tests/InferenceModel.hpp
@@ -6,6 +6,7 @@
 #pragma once

 #include <armnn/ArmNN.hpp>
+#include <armnn/Threadpool.hpp>
 #include <armnn/Logging.hpp>
 #include <armnn/utility/Timer.hpp>
 #include <armnn/BackendRegistry.hpp>
@@ -415,7 +416,7 @@ public:
             armnn::IRuntime::CreationOptions options;
             options.m_EnableGpuProfiling = m_EnableProfiling;
             options.m_DynamicBackendsPath = m_DynamicBackendsPath;
-            m_Runtime = std::move(armnn::IRuntime::Create(options));
+            m_Runtime = armnn::IRuntime::Create(options);
         }

         std::string invalidBackends;
@@ -484,13 +485,25 @@ public:
             const auto loading_start_time = armnn::GetTimeNow();
             armnn::INetworkProperties networkProperties(params.m_AsyncEnabled,
                                                         armnn::MemorySource::Undefined,
-                                                        armnn::MemorySource::Undefined,
-                                                        params.m_ThreadPoolSize);
+                                                        armnn::MemorySource::Undefined);
             std::string errorMessage;
             ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, std::move(optNet), errorMessage, networkProperties);

             ARMNN_LOG(info) << "Network loading time: " << std::setprecision(2)
                             << std::fixed << armnn::GetTimeDuration(loading_start_time).count() << " ms\n";
+
+            if (params.m_AsyncEnabled && params.m_ThreadPoolSize > 0)
+            {
+                std::vector<std::shared_ptr<armnn::IWorkingMemHandle>> memHandles;
+                for (size_t i = 0; i < params.m_ThreadPoolSize; ++i)
+                {
+                    memHandles.emplace_back(m_Runtime->CreateWorkingMemHandle(m_NetworkIdentifier));
+                }
+
+                m_Threadpool = std::make_unique<armnn::Threadpool>(params.m_ThreadPoolSize,
+                                                                   m_Runtime.get(),
+                                                                   memHandles);
+            }
         }

         if (ret == armnn::Status::Failure)
@@ -579,10 +592,11 @@ public:
         }
     }

-    std::tuple<armnn::profiling::ProfilingGuid, std::chrono::duration<double, std::milli>> RunAsync(
+    std::tuple<unsigned int, std::chrono::duration<double, std::milli>> RunAsync(
         armnn::experimental::IWorkingMemHandle& workingMemHandleRef,
         const std::vector<TContainer>& inputContainers,
-        std::vector<TContainer>& outputContainers)
+        std::vector<TContainer>& outputContainers,
+        unsigned int inferenceID)
     {
         for (unsigned int i = 0; i < outputContainers.size(); ++i)
         {
@@ -614,7 +628,6 @@ public:
         armnn::Status ret = m_Runtime->Execute(workingMemHandleRef,
                                                MakeInputTensors(inputContainers),
                                                MakeOutputTensors(outputContainers));
-        auto inferenceID = workingMemHandleRef.GetInferenceId();

         const auto duration = armnn::GetTimeDuration(start_time);

@@ -638,7 +651,7 @@ public:

     void RunAsync(const std::vector<TContainer>& inputContainers,
                   std::vector<TContainer>& outputContainers,
-                  armnn::experimental::IAsyncExecutionCallbackPtr cb)
+                  std::shared_ptr<armnn::IAsyncExecutionCallback> cb)
     {
         for (unsigned int i = 0; i < outputContainers.size(); ++i)
         {
@@ -664,11 +677,11 @@ public:
             profiler->EnableProfiling(m_EnableProfiling);
         }

-        m_Runtime->Schedule(m_NetworkIdentifier,
-                            MakeInputTensors(inputContainers),
-                            MakeOutputTensors(outputContainers),
-                            armnn::QosExecPriority::Medium,
-                            cb);
+        m_Threadpool->Schedule(m_NetworkIdentifier,
+                               MakeInputTensors(inputContainers),
+                               MakeOutputTensors(outputContainers),
+                               armnn::QosExecPriority::Medium,
+                               cb);

         // if profiling is enabled print out the results
         if (profiler && profiler->IsProfilingEnabled())
@@ -731,6 +744,7 @@ public:

 private:
     armnn::NetworkId m_NetworkIdentifier;
     std::shared_ptr<armnn::IRuntime> m_Runtime;
+    std::unique_ptr<armnn::Threadpool> m_Threadpool;
     std::vector<armnn::BindingPointInfo> m_InputBindings;
     std::vector<armnn::BindingPointInfo> m_OutputBindings;
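
Read together, the hunks trace one API shift in ArmNN's experimental async execution: the thread pool is no longer configured through INetworkProperties and driven by IRuntime::Schedule, but is a standalone armnn::Threadpool that the caller constructs from the runtime plus a set of pre-allocated working-memory handles, then schedules work on directly. The sketch below pieces that flow together in one place. It is illustrative only: AsyncRunner is a made-up wrapper, optNet and cb are caller-supplied placeholders, and every ArmNN name and signature is taken from the patch above rather than from separate documentation.

#include <armnn/ArmNN.hpp>
#include <armnn/Threadpool.hpp>

#include <memory>
#include <string>
#include <vector>

// Mirrors the ownership layout of the patch: runtime and thread pool are
// long-lived members, so in-flight inferences cannot outlive either of them.
struct AsyncRunner // hypothetical wrapper, not part of the patch
{
    std::shared_ptr<armnn::IRuntime> m_Runtime;
    std::unique_ptr<armnn::Threadpool> m_Threadpool;
    armnn::NetworkId m_NetworkId{};

    // 'optNet' stands in for a network already run through armnn::Optimize().
    bool Load(armnn::IOptimizedNetworkPtr optNet, size_t threadPoolSize)
    {
        m_Runtime = armnn::IRuntime::Create(armnn::IRuntime::CreationOptions{});

        // INetworkProperties no longer carries the pool size (removed above);
        // it only marks the network as loaded for asynchronous execution.
        armnn::INetworkProperties networkProperties(true,
                                                    armnn::MemorySource::Undefined,
                                                    armnn::MemorySource::Undefined);
        std::string errorMessage;
        if (m_Runtime->LoadNetwork(m_NetworkId, std::move(optNet),
                                   errorMessage, networkProperties) != armnn::Status::Success)
        {
            return false; // errorMessage holds the reason
        }

        // The caller pre-allocates one working-memory handle per pool thread
        // and hands them, together with the runtime, to the new Threadpool.
        std::vector<std::shared_ptr<armnn::IWorkingMemHandle>> memHandles;
        for (size_t i = 0; i < threadPoolSize; ++i)
        {
            memHandles.emplace_back(m_Runtime->CreateWorkingMemHandle(m_NetworkId));
        }
        m_Threadpool = std::make_unique<armnn::Threadpool>(threadPoolSize,
                                                           m_Runtime.get(),
                                                           memHandles);
        return true;
    }

    // Scheduling goes through the pool rather than the runtime; completion is
    // reported through the caller-supplied callback 'cb'.
    void Run(const armnn::InputTensors& inputs,
             const armnn::OutputTensors& outputs,
             std::shared_ptr<armnn::IAsyncExecutionCallback> cb)
    {
        m_Threadpool->Schedule(m_NetworkId, inputs, outputs,
                               armnn::QosExecPriority::Medium, cb);
    }
};

Note how the pool size now lives entirely on the caller's side: the runtime's only involvement is handing out one working-memory handle per thread, which is also why the patch drops the m_ThreadPoolSize argument from INetworkProperties.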