aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Runtime.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/Runtime.cpp')
-rw-r--r--src/armnn/Runtime.cpp112
1 files changed, 87 insertions, 25 deletions
diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp
index aeecbfedc1..8fdc4f1e0a 100644
--- a/src/armnn/Runtime.cpp
+++ b/src/armnn/Runtime.cpp
@@ -23,10 +23,15 @@ using namespace std;
namespace armnn
{
+IRuntime::IRuntime() : pRuntimeImpl( new RuntimeImpl(armnn::IRuntime::CreationOptions())) {}
+
+IRuntime::IRuntime(const IRuntime::CreationOptions& options) : pRuntimeImpl(new RuntimeImpl(options)) {}
+
+IRuntime::~IRuntime() = default;
IRuntime* IRuntime::CreateRaw(const CreationOptions& options)
{
- return new Runtime(options);
+ return new IRuntime(options);
}
IRuntimePtr IRuntime::Create(const CreationOptions& options)
@@ -36,32 +41,89 @@ IRuntimePtr IRuntime::Create(const CreationOptions& options)
void IRuntime::Destroy(IRuntime* runtime)
{
- delete PolymorphicDowncast<Runtime*>(runtime);
+ delete runtime;
+}
+
+Status IRuntime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr network)
+{
+ return pRuntimeImpl->LoadNetwork(networkIdOut, std::move(network));
+}
+
+Status IRuntime::LoadNetwork(NetworkId& networkIdOut,
+ IOptimizedNetworkPtr network,
+ std::string& errorMessage)
+{
+ return pRuntimeImpl->LoadNetwork(networkIdOut, std::move(network), errorMessage);
+}
+
+Status IRuntime::LoadNetwork(NetworkId& networkIdOut,
+ IOptimizedNetworkPtr network,
+ std::string& errorMessage,
+ const INetworkProperties& networkProperties)
+{
+ return pRuntimeImpl->LoadNetwork(networkIdOut, std::move(network), errorMessage, networkProperties);
+}
+
+TensorInfo IRuntime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
+{
+ return pRuntimeImpl->GetInputTensorInfo(networkId, layerId);
+}
+
+TensorInfo IRuntime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
+{
+ return pRuntimeImpl->GetOutputTensorInfo(networkId, layerId);
+}
+
+Status IRuntime::EnqueueWorkload(NetworkId networkId,
+ const InputTensors& inputTensors,
+ const OutputTensors& outputTensors)
+{
+ return pRuntimeImpl->EnqueueWorkload(networkId, inputTensors, outputTensors);
+}
+
+Status IRuntime::UnloadNetwork(NetworkId networkId)
+{
+ return pRuntimeImpl->UnloadNetwork(networkId);
+}
+
+const IDeviceSpec& IRuntime::GetDeviceSpec() const
+{
+ return pRuntimeImpl->GetDeviceSpec();
+}
+
+const std::shared_ptr<IProfiler> IRuntime::GetProfiler(NetworkId networkId) const
+{
+ return pRuntimeImpl->GetProfiler(networkId);
+}
+
+void IRuntime::RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func)
+{
+ return pRuntimeImpl->RegisterDebugCallback(networkId, func);
}
-int Runtime::GenerateNetworkId()
+int RuntimeImpl::GenerateNetworkId()
{
return m_NetworkIdCounter++;
}
-Status Runtime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork)
+Status RuntimeImpl::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork)
{
std::string ignoredErrorMessage;
return LoadNetwork(networkIdOut, std::move(inNetwork), ignoredErrorMessage);
}
-Status Runtime::LoadNetwork(NetworkId& networkIdOut,
- IOptimizedNetworkPtr inNetwork,
- std::string& errorMessage)
+Status RuntimeImpl::LoadNetwork(NetworkId& networkIdOut,
+ IOptimizedNetworkPtr inNetwork,
+ std::string& errorMessage)
{
INetworkProperties networkProperties;
return LoadNetwork(networkIdOut, std::move(inNetwork), errorMessage, networkProperties);
}
-Status Runtime::LoadNetwork(NetworkId& networkIdOut,
- IOptimizedNetworkPtr inNetwork,
- std::string& errorMessage,
- const INetworkProperties& networkProperties)
+Status RuntimeImpl::LoadNetwork(NetworkId& networkIdOut,
+ IOptimizedNetworkPtr inNetwork,
+ std::string& errorMessage,
+ const INetworkProperties& networkProperties)
{
IOptimizedNetwork* rawNetwork = inNetwork.release();
@@ -103,7 +165,7 @@ Status Runtime::LoadNetwork(NetworkId& networkIdOut,
return Status::Success;
}
-Status Runtime::UnloadNetwork(NetworkId networkId)
+Status RuntimeImpl::UnloadNetwork(NetworkId networkId)
{
bool unloadOk = true;
for (auto&& context : m_BackendContexts)
@@ -113,7 +175,7 @@ Status Runtime::UnloadNetwork(NetworkId networkId)
if (!unloadOk)
{
- ARMNN_LOG(warning) << "Runtime::UnloadNetwork(): failed to unload "
+ ARMNN_LOG(warning) << "RuntimeImpl::UnloadNetwork(): failed to unload "
"network with ID:" << networkId << " because BeforeUnloadNetwork failed";
return Status::Failure;
}
@@ -136,7 +198,7 @@ Status Runtime::UnloadNetwork(NetworkId networkId)
}
if (m_LoadedNetworks.erase(networkId) == 0)
{
- ARMNN_LOG(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!";
+ ARMNN_LOG(warning) << "WARNING: RuntimeImpl::UnloadNetwork(): " << networkId << " not found!";
return Status::Failure;
}
@@ -151,11 +213,11 @@ Status Runtime::UnloadNetwork(NetworkId networkId)
context.second->AfterUnloadNetwork(networkId);
}
- ARMNN_LOG(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId;
+ ARMNN_LOG(debug) << "RuntimeImpl::UnloadNetwork(): Unloaded network with ID: " << networkId;
return Status::Success;
}
-const std::shared_ptr<IProfiler> Runtime::GetProfiler(NetworkId networkId) const
+const std::shared_ptr<IProfiler> RuntimeImpl::GetProfiler(NetworkId networkId) const
{
auto it = m_LoadedNetworks.find(networkId);
if (it != m_LoadedNetworks.end())
@@ -167,7 +229,7 @@ const std::shared_ptr<IProfiler> Runtime::GetProfiler(NetworkId networkId) const
return nullptr;
}
-void Runtime::ReportStructure() // armnn::profiling::IProfilingService& profilingService as param
+void RuntimeImpl::ReportStructure() // armnn::profiling::IProfilingService& profilingService as param
{
// No-op for the time being, but this may be useful in future to have the profilingService available
// if (profilingService.IsProfilingEnabled()){}
@@ -182,7 +244,7 @@ void Runtime::ReportStructure() // armnn::profiling::IProfilingService& profilin
}
}
-Runtime::Runtime(const CreationOptions& options)
+RuntimeImpl::RuntimeImpl(const IRuntime::CreationOptions& options)
: m_NetworkIdCounter(0),
m_ProfilingService(*this)
{
@@ -251,7 +313,7 @@ Runtime::Runtime(const CreationOptions& options)
<< std::fixed << armnn::GetTimeDuration(start_time).count() << " ms\n";
}
-Runtime::~Runtime()
+RuntimeImpl::~RuntimeImpl()
{
const auto start_time = armnn::GetTimeNow();
std::vector<int> networkIDs;
@@ -301,24 +363,24 @@ Runtime::~Runtime()
<< std::fixed << armnn::GetTimeDuration(start_time).count() << " ms\n";
}
-LoadedNetwork* Runtime::GetLoadedNetworkPtr(NetworkId networkId) const
+LoadedNetwork* RuntimeImpl::GetLoadedNetworkPtr(NetworkId networkId) const
{
std::lock_guard<std::mutex> lockGuard(m_Mutex);
return m_LoadedNetworks.at(networkId).get();
}
-TensorInfo Runtime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
+TensorInfo RuntimeImpl::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
{
return GetLoadedNetworkPtr(networkId)->GetInputTensorInfo(layerId);
}
-TensorInfo Runtime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
+TensorInfo RuntimeImpl::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
{
return GetLoadedNetworkPtr(networkId)->GetOutputTensorInfo(layerId);
}
-Status Runtime::EnqueueWorkload(NetworkId networkId,
+Status RuntimeImpl::EnqueueWorkload(NetworkId networkId,
const InputTensors& inputTensors,
const OutputTensors& outputTensors)
{
@@ -340,13 +402,13 @@ Status Runtime::EnqueueWorkload(NetworkId networkId,
return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors);
}
-void Runtime::RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func)
+void RuntimeImpl::RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func)
{
LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
loadedNetwork->RegisterDebugCallback(func);
}
-void Runtime::LoadDynamicBackends(const std::string& overrideBackendPath)
+void RuntimeImpl::LoadDynamicBackends(const std::string& overrideBackendPath)
{
// Get the paths where to load the dynamic backends from
std::vector<std::string> backendPaths = DynamicBackendUtils::GetBackendPaths(overrideBackendPath);