aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Runtime.cpp
diff options
context:
space:
mode:
authorMike Kelly <mike.kelly@arm.com>2021-04-07 20:10:49 +0100
committerfinn.williams <finn.williams@arm.com>2021-04-08 11:23:47 +0000
commit55a8ffda24fff5515803df10fb4863d46a1effdf (patch)
treee314dea48f22ae88d452527b2decaca61df108ad /src/armnn/Runtime.cpp
parentb76eaed55a89330b3b448c4f4522b3fc94a4f38d (diff)
downloadarmnn-55a8ffda24fff5515803df10fb4863d46a1effdf.tar.gz
IVGCVSW-5823 Refactor Async Network API
* Moved IAsyncNetwork into IRuntime. * All LoadedNetworks can be executed Asynchronously. Signed-off-by: Mike Kelly <mike.kelly@arm.com> Change-Id: Ibbc901ab9110dc2f881425b75489bccf9ad54169
Diffstat (limited to 'src/armnn/Runtime.cpp')
-rw-r--r--src/armnn/Runtime.cpp135
1 files changed, 90 insertions, 45 deletions
diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp
index 57aaabd277..91a21d4b53 100644
--- a/src/armnn/Runtime.cpp
+++ b/src/armnn/Runtime.cpp
@@ -64,14 +64,6 @@ Status IRuntime::LoadNetwork(NetworkId& networkIdOut,
return pRuntimeImpl->LoadNetwork(networkIdOut, std::move(network), errorMessage, networkProperties);
}
-std::unique_ptr<IAsyncNetwork> IRuntime::CreateAsyncNetwork(NetworkId& networkIdOut,
- IOptimizedNetworkPtr network,
- std::string& errorMessage,
- const INetworkProperties& networkProperties)
-{
- return pRuntimeImpl->CreateAsyncNetwork(networkIdOut, std::move(network), errorMessage, networkProperties);
-}
-
TensorInfo IRuntime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
{
return pRuntimeImpl->GetInputTensorInfo(networkId, layerId);
@@ -89,6 +81,13 @@ Status IRuntime::EnqueueWorkload(NetworkId networkId,
return pRuntimeImpl->EnqueueWorkload(networkId, inputTensors, outputTensors);
}
+Status IRuntime::Execute(IWorkingMemHandle& workingMemHandle,
+ const InputTensors& inputTensors,
+ const OutputTensors& outputTensors)
+{
+ return pRuntimeImpl->Execute(workingMemHandle, inputTensors, outputTensors);
+}
+
Status IRuntime::UnloadNetwork(NetworkId networkId)
{
return pRuntimeImpl->UnloadNetwork(networkId);
@@ -99,6 +98,11 @@ const IDeviceSpec& IRuntime::GetDeviceSpec() const
return pRuntimeImpl->GetDeviceSpec();
}
+std::unique_ptr<IWorkingMemHandle> IRuntime::CreateWorkingMemHandle(NetworkId networkId)
+{
+ return pRuntimeImpl->CreateWorkingMemHandle(networkId);
+}
+
const std::shared_ptr<IProfiler> IRuntime::GetProfiler(NetworkId networkId) const
{
return pRuntimeImpl->GetProfiler(networkId);
@@ -173,43 +177,6 @@ Status RuntimeImpl::LoadNetwork(NetworkId& networkIdOut,
return Status::Success;
}
-std::unique_ptr<IAsyncNetwork> RuntimeImpl::CreateAsyncNetwork(NetworkId& networkIdOut,
- IOptimizedNetworkPtr network,
- std::string&,
- const INetworkProperties& networkProperties)
-{
- IOptimizedNetwork* rawNetwork = network.release();
-
- networkIdOut = GenerateNetworkId();
-
- for (auto&& context : m_BackendContexts)
- {
- context.second->BeforeLoadNetwork(networkIdOut);
- }
-
- unique_ptr<IAsyncNetwork> asyncNetwork = std::make_unique<IAsyncNetwork>(
- std::unique_ptr<IOptimizedNetwork>(rawNetwork),
- networkProperties,
- m_ProfilingService);
-
- if (!asyncNetwork)
- {
- return nullptr;
- }
-
- for (auto&& context : m_BackendContexts)
- {
- context.second->AfterLoadNetwork(networkIdOut);
- }
-
- if (m_ProfilingService.IsProfilingEnabled())
- {
- m_ProfilingService.IncrementCounterValue(armnn::profiling::NETWORK_LOADS);
- }
-
- return asyncNetwork;
-}
-
Status RuntimeImpl::UnloadNetwork(NetworkId networkId)
{
bool unloadOk = true;
@@ -430,6 +397,17 @@ Status RuntimeImpl::EnqueueWorkload(NetworkId networkId,
const OutputTensors& outputTensors)
{
LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
+
+ if (!loadedNetwork)
+ {
+ ARMNN_LOG(error) << "A Network with an id of " << networkId << " does not exist.\n";
+ return Status::Failure;
+ }
+ if (loadedNetwork->IsAsyncEnabled())
+ {
+ ARMNN_LOG(error) << "Network " << networkId << " is async enabled.\n";
+ return Status::Failure;
+ }
ProfilerManager::GetInstance().RegisterProfiler(loadedNetwork->GetProfiler().get());
ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "EnqueueWorkload");
@@ -447,6 +425,73 @@ Status RuntimeImpl::EnqueueWorkload(NetworkId networkId,
return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors);
}
+Status RuntimeImpl::Execute(IWorkingMemHandle& iWorkingMemHandle,
+ const InputTensors& inputTensors,
+ const OutputTensors& outputTensors)
+{
+ NetworkId networkId = iWorkingMemHandle.GetNetworkId();
+ LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
+
+ if (!loadedNetwork)
+ {
+ ARMNN_LOG(error) << "A Network with an id of " << networkId << " does not exist.\n";
+ return Status::Failure;
+ }
+ if (!loadedNetwork->IsAsyncEnabled())
+ {
+ ARMNN_LOG(error) << "Network " << networkId << " is not async enabled.\n";
+ return Status::Failure;
+ }
+ ProfilerManager::GetInstance().RegisterProfiler(loadedNetwork->GetProfiler().get());
+
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Execute");
+
+ static thread_local NetworkId lastId = networkId;
+ if (lastId != networkId)
+ {
+ LoadedNetworkFuncSafe(lastId, [](LoadedNetwork* network)
+ {
+ network->FreeWorkingMemory();
+ });
+ }
+ lastId=networkId;
+
+ return loadedNetwork->Execute(inputTensors, outputTensors, iWorkingMemHandle);
+}
+
+/// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have
+/// overlapped Execution by calling this function from different threads.
+std::unique_ptr<IWorkingMemHandle> RuntimeImpl::CreateWorkingMemHandle(NetworkId networkId)
+{
+ LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
+
+ if (!loadedNetwork)
+ {
+ ARMNN_LOG(error) << "A Network with an id of " << networkId << " does not exist.\n";
+ return nullptr;
+ }
+ if (!loadedNetwork->IsAsyncEnabled())
+ {
+ ARMNN_LOG(error) << "Network " << networkId << " is not async enabled.\n";
+ return nullptr;
+ }
+ ProfilerManager::GetInstance().RegisterProfiler(loadedNetwork->GetProfiler().get());
+
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "CreateWorkingMemHandle");
+
+ static thread_local NetworkId lastId = networkId;
+ if (lastId != networkId)
+ {
+ LoadedNetworkFuncSafe(lastId, [](LoadedNetwork* network)
+ {
+ network->FreeWorkingMemory();
+ });
+ }
+ lastId=networkId;
+
+ return loadedNetwork->CreateWorkingMemHandle(networkId);
+}
+
void RuntimeImpl::RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func)
{
LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);