aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Runtime.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/Runtime.cpp')
-rw-r--r--src/armnn/Runtime.cpp48
1 files changed, 38 insertions, 10 deletions
diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp
index 1dd86a61ce..e04cf5ddaf 100644
--- a/src/armnn/Runtime.cpp
+++ b/src/armnn/Runtime.cpp
@@ -88,6 +88,15 @@ Status IRuntime::Execute(IWorkingMemHandle& workingMemHandle,
return pRuntimeImpl->Execute(workingMemHandle, inputTensors, outputTensors);
}
+void IRuntime::Schedule(NetworkId networkId,
+ const InputTensors& inputTensors,
+ const OutputTensors& outputTensors,
+ const QosExecPriority priority,
+ std::shared_ptr<IAsyncExecutionCallback> cb)
+{
+ pRuntimeImpl->Schedule(networkId, inputTensors, outputTensors, priority, cb);
+}
+
Status IRuntime::UnloadNetwork(NetworkId networkId)
{
return pRuntimeImpl->UnloadNetwork(networkId);
@@ -150,7 +159,8 @@ Status RuntimeImpl::LoadNetwork(NetworkId& networkIdOut,
std::unique_ptr<IOptimizedNetwork>(rawNetwork),
errorMessage,
networkProperties,
- m_ProfilingService);
+ m_ProfilingService,
+ networkIdOut);
if (!loadedNetwork)
{
@@ -439,24 +449,42 @@ Status RuntimeImpl::Execute(IWorkingMemHandle& iWorkingMemHandle,
}
if (!loadedNetwork->IsAsyncEnabled())
{
- ARMNN_LOG(error) << "Network " << networkId << " is not async enabled.\n";
+ ARMNN_LOG(error) << "Attempting execute " << networkId << " when it is not async enabled.\n";
return Status::Failure;
}
ProfilerManager::GetInstance().RegisterProfiler(loadedNetwork->GetProfiler().get());
ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Execute");
- static thread_local NetworkId lastId = networkId;
- if (lastId != networkId)
+ return loadedNetwork->Execute(inputTensors, outputTensors, iWorkingMemHandle);
+}
+
+void RuntimeImpl::Schedule(NetworkId networkId,
+ const InputTensors& inputTensors,
+ const OutputTensors& outputTensors,
+ const QosExecPriority priority,
+ std::shared_ptr<IAsyncExecutionCallback> callback)
+{
+ LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
+
+ if (!loadedNetwork)
{
- LoadedNetworkFuncSafe(lastId, [](LoadedNetwork* network)
- {
- network->FreeWorkingMemory();
- });
+ throw armnn::Exception(
+ "Network with ID of " + std::to_string(networkId) + " does not exist \n"
+ );
+ }
+ if (!loadedNetwork->IsAsyncEnabled())
+ {
+ throw armnn::Exception(
+ "Attempting to schedule Network " + std::to_string(networkId) + " when it is not async enabled \n"
+ );
}
- lastId=networkId;
- return loadedNetwork->Execute(inputTensors, outputTensors, iWorkingMemHandle);
+ ProfilerManager::GetInstance().RegisterProfiler(loadedNetwork->GetProfiler().get());
+
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Schedule");
+
+ loadedNetwork->Schedule(inputTensors, outputTensors, priority, callback);
}
/// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have