aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Runtime.cpp
diff options
context:
space:
mode:
authorMike Kelly <mike.kelly@arm.com>2021-03-29 15:04:50 +0100
committermike.kelly <mike.kelly@arm.com>2021-03-29 14:03:30 +0000
commit386ff1a721cdca3689b009ba31f2d3ac8bea2fae (patch)
treee2f5c26ab2601fd0be8c1223111f55cf1ff94e6e /src/armnn/Runtime.cpp
parent23dbe3d3ff51c2b297ce5bf6360da6552f1c3bf5 (diff)
downloadarmnn-386ff1a721cdca3689b009ba31f2d3ac8bea2fae.tar.gz
IVGCVSW-5790 Merge async prototype
* Added thread safe execution mechanism for armnn * Removed duplicate function bool Compare(T a, T b, float tolerance) * Added StridedSliceAsyncEndToEndTest * Fixed memory leak Signed-off-by: Mike Kelly <mike.kelly@arm.com> Change-Id: I2d367fc77ee7c01b8953138543e76af5e691211f
Diffstat (limited to 'src/armnn/Runtime.cpp')
-rw-r--r--src/armnn/Runtime.cpp45
1 files changed, 45 insertions, 0 deletions
diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp
index 9cc7b2cb81..5dc1ef9cc5 100644
--- a/src/armnn/Runtime.cpp
+++ b/src/armnn/Runtime.cpp
@@ -64,6 +64,14 @@ Status IRuntime::LoadNetwork(NetworkId& networkIdOut,
return pRuntimeImpl->LoadNetwork(networkIdOut, std::move(network), errorMessage, networkProperties);
}
+std::unique_ptr<IAsyncNetwork> IRuntime::CreateAsyncNetwork(NetworkId& networkIdOut,
+ IOptimizedNetworkPtr network,
+ std::string& errorMessage,
+ const INetworkProperties& networkProperties)
+{
+ return pRuntimeImpl->CreateAsyncNetwork(networkIdOut, std::move(network), errorMessage, networkProperties);
+}
+
TensorInfo IRuntime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
{
return pRuntimeImpl->GetInputTensorInfo(networkId, layerId);
@@ -165,6 +173,43 @@ Status RuntimeImpl::LoadNetwork(NetworkId& networkIdOut,
return Status::Success;
}
+std::unique_ptr<IAsyncNetwork> RuntimeImpl::CreateAsyncNetwork(NetworkId& networkIdOut,
+ IOptimizedNetworkPtr network,
+ std::string&,
+ const INetworkProperties& networkProperties)
+{
+ IOptimizedNetwork* rawNetwork = network.release();
+
+ networkIdOut = GenerateNetworkId();
+
+ for (auto&& context : m_BackendContexts)
+ {
+ context.second->BeforeLoadNetwork(networkIdOut);
+ }
+
+ unique_ptr<AsyncNetwork> asyncNetwork = std::make_unique<AsyncNetwork>(
+ std::unique_ptr<IOptimizedNetwork>(rawNetwork),
+ networkProperties,
+ m_ProfilingService);
+
+ if (!asyncNetwork)
+ {
+ return nullptr;
+ }
+
+ for (auto&& context : m_BackendContexts)
+ {
+ context.second->AfterLoadNetwork(networkIdOut);
+ }
+
+ if (m_ProfilingService.IsProfilingEnabled())
+ {
+ m_ProfilingService.IncrementCounterValue(armnn::profiling::NETWORK_LOADS);
+ }
+
+ return asyncNetwork;
+}
+
Status RuntimeImpl::UnloadNetwork(NetworkId networkId)
{
bool unloadOk = true;