diff options
author | Sadik Armagan <sadik.armagan@arm.com> | 2021-04-27 10:02:10 +0100 |
---|---|---|
committer | Sadik Armagan <sadik.armagan@arm.com> | 2021-04-29 08:46:09 +0000 |
commit | a04a9d7c11f28c7e932435535e80223782f369f2 (patch) | |
tree | 9c1e86b0b4878dad12a359e60a8d2e8e051d2def /tests/InferenceModel.hpp | |
parent | 484d5ebb00c0db76efd76a601b5bbaa460cd2ccb (diff) | |
download | armnn-a04a9d7c11f28c7e932435535e80223782f369f2.tar.gz |
IVGCVSW-5775 'Add Async Support to ExecuteNetwork'
* Enabled async mode with '-n, concurrent' and 'simultaneous-iterations'
in ExecuteNetwork
* The number of input files provided should equal the number of network
inputs multiplied by the number of simultaneous iterations, with the file names separated by commas
!armnn:5443
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: Ibeb318010430bf4ae61a02b18b1bf88f3657774c
Diffstat (limited to 'tests/InferenceModel.hpp')
-rw-r--r-- | tests/InferenceModel.hpp | 71 |
1 file changed, 67 insertions, 4 deletions
diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp index cab594ed48..88c704c10e 100644 --- a/tests/InferenceModel.hpp +++ b/tests/InferenceModel.hpp @@ -101,6 +101,7 @@ struct Params std::string m_CachedNetworkFilePath; unsigned int m_NumberOfThreads; std::string m_MLGOTuningFilePath; + bool m_AsyncEnabled; Params() @@ -118,6 +119,7 @@ struct Params , m_CachedNetworkFilePath("") , m_NumberOfThreads(0) , m_MLGOTuningFilePath("") + , m_AsyncEnabled(false) {} }; @@ -472,14 +474,14 @@ public: optNet->SerializeToDot(file); } - - armnn::Status ret; { ARMNN_SCOPED_HEAP_PROFILING("LoadNetwork"); const auto loading_start_time = armnn::GetTimeNow(); - ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, std::move(optNet)); + armnn::INetworkProperties networkProperties(false, false, params.m_AsyncEnabled); + std::string errorMessage; + ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, std::move(optNet), errorMessage, networkProperties); ARMNN_LOG(info) << "Network loading time: " << std::setprecision(2) << std::fixed << armnn::GetTimeDuration(loading_start_time).count() << " ms\n"; @@ -553,7 +555,6 @@ public: armnn::Status ret = m_Runtime->EnqueueWorkload(m_NetworkIdentifier, MakeInputTensors(inputContainers), MakeOutputTensors(outputContainers)); - const auto duration = armnn::GetTimeDuration(start_time); // if profiling is enabled print out the results @@ -572,6 +573,63 @@ public: } } + std::tuple<armnn::profiling::ProfilingGuid, std::chrono::duration<double, std::milli>> RunAsync( + armnn::experimental::IWorkingMemHandle& workingMemHandleRef, + const std::vector<TContainer>& inputContainers, + std::vector<TContainer>& outputContainers) + { + for (unsigned int i = 0; i < outputContainers.size(); ++i) + { + const unsigned int expectedOutputDataSize = GetOutputSize(i); + + mapbox::util::apply_visitor([expectedOutputDataSize, i](auto&& value) + { + const unsigned int actualOutputDataSize = armnn::numeric_cast<unsigned int>(value.size()); + if (actualOutputDataSize < 
expectedOutputDataSize) + { + unsigned int outputIndex = i; + throw armnn::Exception( + fmt::format("Not enough data for output #{0}: expected " + "{1} elements, got {2}", outputIndex, expectedOutputDataSize, actualOutputDataSize)); + } + }, + outputContainers[i]); + } + + std::shared_ptr<armnn::IProfiler> profiler = m_Runtime->GetProfiler(m_NetworkIdentifier); + if (profiler) + { + profiler->EnableProfiling(m_EnableProfiling); + } + + // Start timer to record inference time in EnqueueWorkload (in milliseconds) + const auto start_time = armnn::GetTimeNow(); + + armnn::Status ret = m_Runtime->Execute(workingMemHandleRef, + MakeInputTensors(inputContainers), + MakeOutputTensors(outputContainers)); + auto inferenceID = workingMemHandleRef.GetInferenceId(); + + const auto duration = armnn::GetTimeDuration(start_time); + + // if profiling is enabled print out the results + if (profiler && profiler->IsProfilingEnabled()) + { + profiler->Print(std::cout); + } + + if (ret == armnn::Status::Failure) + { + throw armnn::Exception( + fmt::format("IRuntime::Execute asynchronously failed for network #{0} on inference #{1}", + m_NetworkIdentifier, inferenceID)); + } + else + { + return std::make_tuple(inferenceID, duration); + } + } + const armnn::BindingPointInfo& GetInputBindingInfo(unsigned int inputIndex = 0u) const { CheckInputIndexIsValid(inputIndex); @@ -618,6 +676,11 @@ public: return quantizationParams; } + std::unique_ptr<armnn::experimental::IWorkingMemHandle> CreateWorkingMemHandle() + { + return m_Runtime->CreateWorkingMemHandle(m_NetworkIdentifier); + } + private: armnn::NetworkId m_NetworkIdentifier; std::shared_ptr<armnn::IRuntime> m_Runtime; |