From 386ff1a721cdca3689b009ba31f2d3ac8bea2fae Mon Sep 17 00:00:00 2001 From: Mike Kelly Date: Mon, 29 Mar 2021 15:04:50 +0100 Subject: IVGCVSW-5790 Merge async prototype * Added thread safe execution mechanism for armnn * Removed duplicate function bool Compare(T a, T b, float tolerance) * Added StridedSliceAsyncEndToEndTest * Fixed memory leak Signed-off-by: Mike Kelly Change-Id: I2d367fc77ee7c01b8953138543e76af5e691211f --- src/armnn/Runtime.cpp | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) (limited to 'src/armnn/Runtime.cpp') diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp index 9cc7b2cb81..5dc1ef9cc5 100644 --- a/src/armnn/Runtime.cpp +++ b/src/armnn/Runtime.cpp @@ -64,6 +64,14 @@ Status IRuntime::LoadNetwork(NetworkId& networkIdOut, return pRuntimeImpl->LoadNetwork(networkIdOut, std::move(network), errorMessage, networkProperties); } +std::unique_ptr IRuntime::CreateAsyncNetwork(NetworkId& networkIdOut, + IOptimizedNetworkPtr network, + std::string& errorMessage, + const INetworkProperties& networkProperties) +{ + return pRuntimeImpl->CreateAsyncNetwork(networkIdOut, std::move(network), errorMessage, networkProperties); +} + TensorInfo IRuntime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const { return pRuntimeImpl->GetInputTensorInfo(networkId, layerId); @@ -165,6 +173,43 @@ Status RuntimeImpl::LoadNetwork(NetworkId& networkIdOut, return Status::Success; } +std::unique_ptr RuntimeImpl::CreateAsyncNetwork(NetworkId& networkIdOut, + IOptimizedNetworkPtr network, + std::string&, + const INetworkProperties& networkProperties) +{ + IOptimizedNetwork* rawNetwork = network.release(); + + networkIdOut = GenerateNetworkId(); + + for (auto&& context : m_BackendContexts) + { + context.second->BeforeLoadNetwork(networkIdOut); + } + + unique_ptr asyncNetwork = std::make_unique( + std::unique_ptr(rawNetwork), + networkProperties, + m_ProfilingService); + + if (!asyncNetwork) + { + return nullptr; + } + + for (auto&& context : m_BackendContexts) + { + context.second->AfterLoadNetwork(networkIdOut); + } + + if (m_ProfilingService.IsProfilingEnabled()) + { + m_ProfilingService.IncrementCounterValue(armnn::profiling::NETWORK_LOADS); + } + + return asyncNetwork; +} + Status RuntimeImpl::UnloadNetwork(NetworkId networkId) { bool unloadOk = true; -- cgit v1.2.1