diff options
Diffstat (limited to 'src/armnn/Runtime.cpp')
-rw-r--r-- | src/armnn/Runtime.cpp | 48 |
1 files changed, 38 insertions, 10 deletions
diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp index 1dd86a61ce..e04cf5ddaf 100644 --- a/src/armnn/Runtime.cpp +++ b/src/armnn/Runtime.cpp @@ -88,6 +88,15 @@ Status IRuntime::Execute(IWorkingMemHandle& workingMemHandle, return pRuntimeImpl->Execute(workingMemHandle, inputTensors, outputTensors); } +void IRuntime::Schedule(NetworkId networkId, + const InputTensors& inputTensors, + const OutputTensors& outputTensors, + const QosExecPriority priority, + std::shared_ptr<IAsyncExecutionCallback> cb) +{ + pRuntimeImpl->Schedule(networkId, inputTensors, outputTensors, priority, cb); +} + Status IRuntime::UnloadNetwork(NetworkId networkId) { return pRuntimeImpl->UnloadNetwork(networkId); @@ -150,7 +159,8 @@ Status RuntimeImpl::LoadNetwork(NetworkId& networkIdOut, std::unique_ptr<IOptimizedNetwork>(rawNetwork), errorMessage, networkProperties, - m_ProfilingService); + m_ProfilingService, + networkIdOut); if (!loadedNetwork) { @@ -439,24 +449,42 @@ Status RuntimeImpl::Execute(IWorkingMemHandle& iWorkingMemHandle, } if (!loadedNetwork->IsAsyncEnabled()) { - ARMNN_LOG(error) << "Network " << networkId << " is not async enabled.\n"; + ARMNN_LOG(error) << "Attempting execute " << networkId << " when it is not async enabled.\n"; return Status::Failure; } ProfilerManager::GetInstance().RegisterProfiler(loadedNetwork->GetProfiler().get()); ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Execute"); - static thread_local NetworkId lastId = networkId; - if (lastId != networkId) + return loadedNetwork->Execute(inputTensors, outputTensors, iWorkingMemHandle); +} + +void RuntimeImpl::Schedule(NetworkId networkId, + const InputTensors& inputTensors, + const OutputTensors& outputTensors, + const QosExecPriority priority, + std::shared_ptr<IAsyncExecutionCallback> callback) +{ + LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId); + + if (!loadedNetwork) { - LoadedNetworkFuncSafe(lastId, [](LoadedNetwork* network) - { - network->FreeWorkingMemory(); - }); + throw armnn::Exception( + "Network with ID of " + std::to_string(networkId) + " does not exist \n" + ); + } + if (!loadedNetwork->IsAsyncEnabled()) + { + throw armnn::Exception( + "Attempting to schedule Network " + std::to_string(networkId) + " when it is not async enabled \n" + ); } - lastId=networkId; - return loadedNetwork->Execute(inputTensors, outputTensors, iWorkingMemHandle); + ProfilerManager::GetInstance().RegisterProfiler(loadedNetwork->GetProfiler().get()); + + ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Schedule"); + + loadedNetwork->Schedule(inputTensors, outputTensors, priority, callback); } /// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have |