diff options
author | Mike Kelly <mike.kelly@arm.com> | 2021-04-07 20:10:49 +0100 |
---|---|---|
committer | finn.williams <finn.williams@arm.com> | 2021-04-08 11:23:47 +0000 |
commit | 55a8ffda24fff5515803df10fb4863d46a1effdf (patch) | |
tree | e314dea48f22ae88d452527b2decaca61df108ad /src/armnn/Runtime.cpp | |
parent | b76eaed55a89330b3b448c4f4522b3fc94a4f38d (diff) | |
download | armnn-55a8ffda24fff5515803df10fb4863d46a1effdf.tar.gz |
IVGCVSW-5823 Refactor Async Network API
* Moved IAsyncNetwork into IRuntime.
* All LoadedNetworks can be executed Asynchronously.
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: Ibbc901ab9110dc2f881425b75489bccf9ad54169
Diffstat (limited to 'src/armnn/Runtime.cpp')
-rw-r--r-- | src/armnn/Runtime.cpp | 135 |
1 files changed, 90 insertions, 45 deletions
diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp index 57aaabd277..91a21d4b53 100644 --- a/src/armnn/Runtime.cpp +++ b/src/armnn/Runtime.cpp @@ -64,14 +64,6 @@ Status IRuntime::LoadNetwork(NetworkId& networkIdOut, return pRuntimeImpl->LoadNetwork(networkIdOut, std::move(network), errorMessage, networkProperties); } -std::unique_ptr<IAsyncNetwork> IRuntime::CreateAsyncNetwork(NetworkId& networkIdOut, - IOptimizedNetworkPtr network, - std::string& errorMessage, - const INetworkProperties& networkProperties) -{ - return pRuntimeImpl->CreateAsyncNetwork(networkIdOut, std::move(network), errorMessage, networkProperties); -} - TensorInfo IRuntime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const { return pRuntimeImpl->GetInputTensorInfo(networkId, layerId); @@ -89,6 +81,13 @@ Status IRuntime::EnqueueWorkload(NetworkId networkId, return pRuntimeImpl->EnqueueWorkload(networkId, inputTensors, outputTensors); } +Status IRuntime::Execute(IWorkingMemHandle& workingMemHandle, + const InputTensors& inputTensors, + const OutputTensors& outputTensors) +{ + return pRuntimeImpl->Execute(workingMemHandle, inputTensors, outputTensors); +} + Status IRuntime::UnloadNetwork(NetworkId networkId) { return pRuntimeImpl->UnloadNetwork(networkId); @@ -99,6 +98,11 @@ const IDeviceSpec& IRuntime::GetDeviceSpec() const return pRuntimeImpl->GetDeviceSpec(); } +std::unique_ptr<IWorkingMemHandle> IRuntime::CreateWorkingMemHandle(NetworkId networkId) +{ + return pRuntimeImpl->CreateWorkingMemHandle(networkId); +} + const std::shared_ptr<IProfiler> IRuntime::GetProfiler(NetworkId networkId) const { return pRuntimeImpl->GetProfiler(networkId); @@ -173,43 +177,6 @@ Status RuntimeImpl::LoadNetwork(NetworkId& networkIdOut, return Status::Success; } -std::unique_ptr<IAsyncNetwork> RuntimeImpl::CreateAsyncNetwork(NetworkId& networkIdOut, - IOptimizedNetworkPtr network, - std::string&, - const INetworkProperties& networkProperties) -{ - IOptimizedNetwork* rawNetwork = network.release(); - - networkIdOut = GenerateNetworkId(); - - for (auto&& context : m_BackendContexts) - { - context.second->BeforeLoadNetwork(networkIdOut); - } - - unique_ptr<IAsyncNetwork> asyncNetwork = std::make_unique<IAsyncNetwork>( - std::unique_ptr<IOptimizedNetwork>(rawNetwork), - networkProperties, - m_ProfilingService); - - if (!asyncNetwork) - { - return nullptr; - } - - for (auto&& context : m_BackendContexts) - { - context.second->AfterLoadNetwork(networkIdOut); - } - - if (m_ProfilingService.IsProfilingEnabled()) - { - m_ProfilingService.IncrementCounterValue(armnn::profiling::NETWORK_LOADS); - } - - return asyncNetwork; -} - Status RuntimeImpl::UnloadNetwork(NetworkId networkId) { bool unloadOk = true; @@ -430,6 +397,17 @@ Status RuntimeImpl::EnqueueWorkload(NetworkId networkId, const OutputTensors& outputTensors) { LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId); + + if (!loadedNetwork) + { + ARMNN_LOG(error) << "A Network with an id of " << networkId << " does not exist.\n"; + return Status::Failure; + } + if (loadedNetwork->IsAsyncEnabled()) + { + ARMNN_LOG(error) << "Network " << networkId << " is async enabled.\n"; + return Status::Failure; + } ProfilerManager::GetInstance().RegisterProfiler(loadedNetwork->GetProfiler().get()); ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "EnqueueWorkload"); @@ -447,6 +425,73 @@ Status RuntimeImpl::EnqueueWorkload(NetworkId networkId, return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors); } +Status RuntimeImpl::Execute(IWorkingMemHandle& iWorkingMemHandle, + const InputTensors& inputTensors, + const OutputTensors& outputTensors) +{ + NetworkId networkId = iWorkingMemHandle.GetNetworkId(); + LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId); + + if (!loadedNetwork) + { + ARMNN_LOG(error) << "A Network with an id of " << networkId << " does not exist.\n"; + return Status::Failure; + } + if (!loadedNetwork->IsAsyncEnabled()) + { + ARMNN_LOG(error) << "Network " << networkId << " is not async enabled.\n"; + return Status::Failure; + } + ProfilerManager::GetInstance().RegisterProfiler(loadedNetwork->GetProfiler().get()); + + ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Execute"); + + static thread_local NetworkId lastId = networkId; + if (lastId != networkId) + { + LoadedNetworkFuncSafe(lastId, [](LoadedNetwork* network) + { + network->FreeWorkingMemory(); + }); + } + lastId=networkId; + + return loadedNetwork->Execute(inputTensors, outputTensors, iWorkingMemHandle); +} + +/// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have +/// overlapped Execution by calling this function from different threads. +std::unique_ptr<IWorkingMemHandle> RuntimeImpl::CreateWorkingMemHandle(NetworkId networkId) +{ + LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId); + + if (!loadedNetwork) + { + ARMNN_LOG(error) << "A Network with an id of " << networkId << " does not exist.\n"; + return nullptr; + } + if (!loadedNetwork->IsAsyncEnabled()) + { + ARMNN_LOG(error) << "Network " << networkId << " is not async enabled.\n"; + return nullptr; + } + ProfilerManager::GetInstance().RegisterProfiler(loadedNetwork->GetProfiler().get()); + + ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "CreateWorkingMemHandle"); + + static thread_local NetworkId lastId = networkId; + if (lastId != networkId) + { + LoadedNetworkFuncSafe(lastId, [](LoadedNetwork* network) + { + network->FreeWorkingMemory(); + }); + } + lastId=networkId; + + return loadedNetwork->CreateWorkingMemHandle(networkId); +} + void RuntimeImpl::RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func) { LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId); |