diff options
Diffstat (limited to 'src/armnn/Runtime.cpp')
-rw-r--r-- | src/armnn/Runtime.cpp | 112 |
1 files changed, 87 insertions, 25 deletions
diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp index aeecbfedc1..8fdc4f1e0a 100644 --- a/src/armnn/Runtime.cpp +++ b/src/armnn/Runtime.cpp @@ -23,10 +23,15 @@ using namespace std; namespace armnn { +IRuntime::IRuntime() : pRuntimeImpl( new RuntimeImpl(armnn::IRuntime::CreationOptions())) {} + +IRuntime::IRuntime(const IRuntime::CreationOptions& options) : pRuntimeImpl(new RuntimeImpl(options)) {} + +IRuntime::~IRuntime() = default; IRuntime* IRuntime::CreateRaw(const CreationOptions& options) { - return new Runtime(options); + return new IRuntime(options); } IRuntimePtr IRuntime::Create(const CreationOptions& options) @@ -36,32 +41,89 @@ IRuntimePtr IRuntime::Create(const CreationOptions& options) void IRuntime::Destroy(IRuntime* runtime) { - delete PolymorphicDowncast<Runtime*>(runtime); + delete runtime; +} + +Status IRuntime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr network) +{ + return pRuntimeImpl->LoadNetwork(networkIdOut, std::move(network)); +} + +Status IRuntime::LoadNetwork(NetworkId& networkIdOut, + IOptimizedNetworkPtr network, + std::string& errorMessage) +{ + return pRuntimeImpl->LoadNetwork(networkIdOut, std::move(network), errorMessage); +} + +Status IRuntime::LoadNetwork(NetworkId& networkIdOut, + IOptimizedNetworkPtr network, + std::string& errorMessage, + const INetworkProperties& networkProperties) +{ + return pRuntimeImpl->LoadNetwork(networkIdOut, std::move(network), errorMessage, networkProperties); +} + +TensorInfo IRuntime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const +{ + return pRuntimeImpl->GetInputTensorInfo(networkId, layerId); +} + +TensorInfo IRuntime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const +{ + return pRuntimeImpl->GetOutputTensorInfo(networkId, layerId); +} + +Status IRuntime::EnqueueWorkload(NetworkId networkId, + const InputTensors& inputTensors, + const OutputTensors& outputTensors) +{ + return pRuntimeImpl->EnqueueWorkload(networkId, inputTensors, outputTensors); +} + +Status IRuntime::UnloadNetwork(NetworkId networkId) +{ + return pRuntimeImpl->UnloadNetwork(networkId); +} + +const IDeviceSpec& IRuntime::GetDeviceSpec() const +{ + return pRuntimeImpl->GetDeviceSpec(); +} + +const std::shared_ptr<IProfiler> IRuntime::GetProfiler(NetworkId networkId) const +{ + return pRuntimeImpl->GetProfiler(networkId); +} + +void IRuntime::RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func) +{ + return pRuntimeImpl->RegisterDebugCallback(networkId, func); } -int Runtime::GenerateNetworkId() +int RuntimeImpl::GenerateNetworkId() { return m_NetworkIdCounter++; } -Status Runtime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork) +Status RuntimeImpl::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork) { std::string ignoredErrorMessage; return LoadNetwork(networkIdOut, std::move(inNetwork), ignoredErrorMessage); } -Status Runtime::LoadNetwork(NetworkId& networkIdOut, - IOptimizedNetworkPtr inNetwork, - std::string& errorMessage) +Status RuntimeImpl::LoadNetwork(NetworkId& networkIdOut, + IOptimizedNetworkPtr inNetwork, + std::string& errorMessage) { INetworkProperties networkProperties; return LoadNetwork(networkIdOut, std::move(inNetwork), errorMessage, networkProperties); } -Status Runtime::LoadNetwork(NetworkId& networkIdOut, - IOptimizedNetworkPtr inNetwork, - std::string& errorMessage, - const INetworkProperties& networkProperties) +Status RuntimeImpl::LoadNetwork(NetworkId& networkIdOut, + IOptimizedNetworkPtr inNetwork, + std::string& errorMessage, + const INetworkProperties& networkProperties) { IOptimizedNetwork* rawNetwork = inNetwork.release(); @@ -103,7 +165,7 @@ Status Runtime::LoadNetwork(NetworkId& networkIdOut, return Status::Success; } -Status Runtime::UnloadNetwork(NetworkId networkId) +Status RuntimeImpl::UnloadNetwork(NetworkId networkId) { bool unloadOk = true; for (auto&& context : m_BackendContexts) @@ -113,7 +175,7 @@ Status Runtime::UnloadNetwork(NetworkId networkId) if (!unloadOk) { - ARMNN_LOG(warning) << "Runtime::UnloadNetwork(): failed to unload " + ARMNN_LOG(warning) << "RuntimeImpl::UnloadNetwork(): failed to unload " "network with ID:" << networkId << " because BeforeUnloadNetwork failed"; return Status::Failure; } @@ -136,7 +198,7 @@ Status Runtime::UnloadNetwork(NetworkId networkId) } if (m_LoadedNetworks.erase(networkId) == 0) { - ARMNN_LOG(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!"; + ARMNN_LOG(warning) << "WARNING: RuntimeImpl::UnloadNetwork(): " << networkId << " not found!"; return Status::Failure; } @@ -151,11 +213,11 @@ Status Runtime::UnloadNetwork(NetworkId networkId) context.second->AfterUnloadNetwork(networkId); } - ARMNN_LOG(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId; + ARMNN_LOG(debug) << "RuntimeImpl::UnloadNetwork(): Unloaded network with ID: " << networkId; return Status::Success; } -const std::shared_ptr<IProfiler> Runtime::GetProfiler(NetworkId networkId) const +const std::shared_ptr<IProfiler> RuntimeImpl::GetProfiler(NetworkId networkId) const { auto it = m_LoadedNetworks.find(networkId); if (it != m_LoadedNetworks.end()) @@ -167,7 +229,7 @@ const std::shared_ptr<IProfiler> Runtime::GetProfiler(NetworkId networkId) const return nullptr; } -void Runtime::ReportStructure() // armnn::profiling::IProfilingService& profilingService as param +void RuntimeImpl::ReportStructure() // armnn::profiling::IProfilingService& profilingService as param { // No-op for the time being, but this may be useful in future to have the profilingService available // if (profilingService.IsProfilingEnabled()){} @@ -182,7 +244,7 @@ void Runtime::ReportStructure() // armnn::profiling::IProfilingService& profilin } } -Runtime::Runtime(const CreationOptions& options) +RuntimeImpl::RuntimeImpl(const IRuntime::CreationOptions& options) : m_NetworkIdCounter(0), m_ProfilingService(*this) { @@ -251,7 +313,7 @@ Runtime::Runtime(const CreationOptions& options) << std::fixed << armnn::GetTimeDuration(start_time).count() << " ms\n"; } -Runtime::~Runtime() +RuntimeImpl::~RuntimeImpl() { const auto start_time = armnn::GetTimeNow(); std::vector<int> networkIDs; @@ -301,24 +363,24 @@ Runtime::~Runtime() << std::fixed << armnn::GetTimeDuration(start_time).count() << " ms\n"; } -LoadedNetwork* Runtime::GetLoadedNetworkPtr(NetworkId networkId) const +LoadedNetwork* RuntimeImpl::GetLoadedNetworkPtr(NetworkId networkId) const { std::lock_guard<std::mutex> lockGuard(m_Mutex); return m_LoadedNetworks.at(networkId).get(); } -TensorInfo Runtime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const +TensorInfo RuntimeImpl::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const { return GetLoadedNetworkPtr(networkId)->GetInputTensorInfo(layerId); } -TensorInfo Runtime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const +TensorInfo RuntimeImpl::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const { return GetLoadedNetworkPtr(networkId)->GetOutputTensorInfo(layerId); } -Status Runtime::EnqueueWorkload(NetworkId networkId, +Status RuntimeImpl::EnqueueWorkload(NetworkId networkId, const InputTensors& inputTensors, const OutputTensors& outputTensors) { @@ -340,13 +402,13 @@ Status Runtime::EnqueueWorkload(NetworkId networkId, return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors); } -void Runtime::RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func) +void RuntimeImpl::RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func) { LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId); loadedNetwork->RegisterDebugCallback(func); } -void Runtime::LoadDynamicBackends(const std::string& overrideBackendPath) +void RuntimeImpl::LoadDynamicBackends(const std::string& overrideBackendPath) { // Get the paths where to load the dynamic backends from std::vector<std::string> backendPaths = DynamicBackendUtils::GetBackendPaths(overrideBackendPath); |