From d92a6e4c19567cb03de76963068c002353cea528 Mon Sep 17 00:00:00 2001 From: Kevin May Date: Thu, 4 Feb 2021 10:27:41 +0000 Subject: IVGCVSW-4873 Implement Pimpl Idiom for IRuntime Signed-off-by: Kevin May Change-Id: I52448938735b2aa678c47e0f3061c87fa0c693b1 --- src/armnn/Runtime.cpp | 112 +++++++++++++++++++++++++++++++--------- src/armnn/Runtime.hpp | 45 ++++++++-------- src/armnn/test/RuntimeTests.cpp | 8 +-- src/armnn/test/RuntimeTests.hpp | 2 +- src/armnn/test/TestUtils.cpp | 2 +- src/armnn/test/TestUtils.hpp | 2 +- 6 files changed, 116 insertions(+), 55 deletions(-) (limited to 'src/armnn') diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp index aeecbfedc1..8fdc4f1e0a 100644 --- a/src/armnn/Runtime.cpp +++ b/src/armnn/Runtime.cpp @@ -23,10 +23,15 @@ using namespace std; namespace armnn { +IRuntime::IRuntime() : pRuntimeImpl( new RuntimeImpl(armnn::IRuntime::CreationOptions())) {} + +IRuntime::IRuntime(const IRuntime::CreationOptions& options) : pRuntimeImpl(new RuntimeImpl(options)) {} + +IRuntime::~IRuntime() = default; IRuntime* IRuntime::CreateRaw(const CreationOptions& options) { - return new Runtime(options); + return new IRuntime(options); } IRuntimePtr IRuntime::Create(const CreationOptions& options) @@ -36,32 +41,89 @@ IRuntimePtr IRuntime::Create(const CreationOptions& options) void IRuntime::Destroy(IRuntime* runtime) { - delete PolymorphicDowncast(runtime); + delete runtime; +} + +Status IRuntime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr network) +{ + return pRuntimeImpl->LoadNetwork(networkIdOut, std::move(network)); +} + +Status IRuntime::LoadNetwork(NetworkId& networkIdOut, + IOptimizedNetworkPtr network, + std::string& errorMessage) +{ + return pRuntimeImpl->LoadNetwork(networkIdOut, std::move(network), errorMessage); +} + +Status IRuntime::LoadNetwork(NetworkId& networkIdOut, + IOptimizedNetworkPtr network, + std::string& errorMessage, + const INetworkProperties& networkProperties) +{ + return pRuntimeImpl->LoadNetwork(networkIdOut, std::move(network), errorMessage, networkProperties); +} + +TensorInfo IRuntime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const +{ + return pRuntimeImpl->GetInputTensorInfo(networkId, layerId); +} + +TensorInfo IRuntime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const +{ + return pRuntimeImpl->GetOutputTensorInfo(networkId, layerId); +} + +Status IRuntime::EnqueueWorkload(NetworkId networkId, + const InputTensors& inputTensors, + const OutputTensors& outputTensors) +{ + return pRuntimeImpl->EnqueueWorkload(networkId, inputTensors, outputTensors); +} + +Status IRuntime::UnloadNetwork(NetworkId networkId) +{ + return pRuntimeImpl->UnloadNetwork(networkId); +} + +const IDeviceSpec& IRuntime::GetDeviceSpec() const +{ + return pRuntimeImpl->GetDeviceSpec(); +} + +const std::shared_ptr IRuntime::GetProfiler(NetworkId networkId) const +{ + return pRuntimeImpl->GetProfiler(networkId); +} + +void IRuntime::RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func) +{ + return pRuntimeImpl->RegisterDebugCallback(networkId, func); } -int Runtime::GenerateNetworkId() +int RuntimeImpl::GenerateNetworkId() { return m_NetworkIdCounter++; } -Status Runtime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork) +Status RuntimeImpl::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork) { std::string ignoredErrorMessage; return LoadNetwork(networkIdOut, std::move(inNetwork), ignoredErrorMessage); } -Status Runtime::LoadNetwork(NetworkId& networkIdOut, - IOptimizedNetworkPtr inNetwork, - std::string& errorMessage) +Status RuntimeImpl::LoadNetwork(NetworkId& networkIdOut, + IOptimizedNetworkPtr inNetwork, + std::string& errorMessage) { INetworkProperties networkProperties; return LoadNetwork(networkIdOut, std::move(inNetwork), errorMessage, networkProperties); } -Status Runtime::LoadNetwork(NetworkId& networkIdOut, - IOptimizedNetworkPtr inNetwork, - std::string& errorMessage, - const INetworkProperties& networkProperties) +Status RuntimeImpl::LoadNetwork(NetworkId& networkIdOut, + IOptimizedNetworkPtr inNetwork, + std::string& errorMessage, + const INetworkProperties& networkProperties) { IOptimizedNetwork* rawNetwork = inNetwork.release(); @@ -103,7 +165,7 @@ Status Runtime::LoadNetwork(NetworkId& networkIdOut, return Status::Success; } -Status Runtime::UnloadNetwork(NetworkId networkId) +Status RuntimeImpl::UnloadNetwork(NetworkId networkId) { bool unloadOk = true; for (auto&& context : m_BackendContexts) @@ -113,7 +175,7 @@ Status Runtime::UnloadNetwork(NetworkId networkId) if (!unloadOk) { - ARMNN_LOG(warning) << "Runtime::UnloadNetwork(): failed to unload " + ARMNN_LOG(warning) << "RuntimeImpl::UnloadNetwork(): failed to unload " "network with ID:" << networkId << " because BeforeUnloadNetwork failed"; return Status::Failure; } @@ -136,7 +198,7 @@ Status Runtime::UnloadNetwork(NetworkId networkId) } if (m_LoadedNetworks.erase(networkId) == 0) { - ARMNN_LOG(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!"; + ARMNN_LOG(warning) << "WARNING: RuntimeImpl::UnloadNetwork(): " << networkId << " not found!"; return Status::Failure; } @@ -151,11 +213,11 @@ Status Runtime::UnloadNetwork(NetworkId networkId) context.second->AfterUnloadNetwork(networkId); } - ARMNN_LOG(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId; + ARMNN_LOG(debug) << "RuntimeImpl::UnloadNetwork(): Unloaded network with ID: " << networkId; return Status::Success; } -const std::shared_ptr Runtime::GetProfiler(NetworkId networkId) const +const std::shared_ptr RuntimeImpl::GetProfiler(NetworkId networkId) const { auto it = m_LoadedNetworks.find(networkId); if (it != m_LoadedNetworks.end()) @@ -167,7 +229,7 @@ const std::shared_ptr Runtime::GetProfiler(NetworkId networkId) const return nullptr; } -void Runtime::ReportStructure() // armnn::profiling::IProfilingService& profilingService as param +void RuntimeImpl::ReportStructure() // armnn::profiling::IProfilingService& profilingService as param { // No-op for the time being, but this may be useful in future to have the profilingService available // if (profilingService.IsProfilingEnabled()){} @@ -182,7 +244,7 @@ void Runtime::ReportStructure() // armnn::profiling::IProfilingService& profilin } } -Runtime::Runtime(const CreationOptions& options) +RuntimeImpl::RuntimeImpl(const IRuntime::CreationOptions& options) : m_NetworkIdCounter(0), m_ProfilingService(*this) { @@ -251,7 +313,7 @@ Runtime::Runtime(const CreationOptions& options) << std::fixed << armnn::GetTimeDuration(start_time).count() << " ms\n"; } -Runtime::~Runtime() +RuntimeImpl::~RuntimeImpl() { const auto start_time = armnn::GetTimeNow(); std::vector networkIDs; @@ -301,24 +363,24 @@ Runtime::~Runtime() << std::fixed << armnn::GetTimeDuration(start_time).count() << " ms\n"; } -LoadedNetwork* Runtime::GetLoadedNetworkPtr(NetworkId networkId) const +LoadedNetwork* RuntimeImpl::GetLoadedNetworkPtr(NetworkId networkId) const { std::lock_guard lockGuard(m_Mutex); return m_LoadedNetworks.at(networkId).get(); } -TensorInfo Runtime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const +TensorInfo RuntimeImpl::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const { return GetLoadedNetworkPtr(networkId)->GetInputTensorInfo(layerId); } -TensorInfo Runtime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const +TensorInfo RuntimeImpl::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const { return GetLoadedNetworkPtr(networkId)->GetOutputTensorInfo(layerId); } -Status Runtime::EnqueueWorkload(NetworkId networkId, +Status RuntimeImpl::EnqueueWorkload(NetworkId networkId, const InputTensors& inputTensors, const OutputTensors& outputTensors) { @@ -340,13 +402,13 @@ Status Runtime::EnqueueWorkload(NetworkId networkId, return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors); } -void Runtime::RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func) +void RuntimeImpl::RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func) { LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId); loadedNetwork->RegisterDebugCallback(func); } -void Runtime::LoadDynamicBackends(const std::string& overrideBackendPath) +void RuntimeImpl::LoadDynamicBackends(const std::string& overrideBackendPath) { // Get the paths where to load the dynamic backends from std::vector backendPaths = DynamicBackendUtils::GetBackendPaths(overrideBackendPath); diff --git a/src/armnn/Runtime.hpp b/src/armnn/Runtime.hpp index 3c90c51bb2..2c7e07f9fb 100644 --- a/src/armnn/Runtime.hpp +++ b/src/armnn/Runtime.hpp @@ -27,8 +27,7 @@ namespace armnn using LoadedNetworks = std::unordered_map>; using IReportStructure = profiling::IReportStructure; -class Runtime final : public IRuntime, - public IReportStructure +struct RuntimeImpl final : public IReportStructure { public: /// Loads a complete network into the Runtime. @@ -36,7 +35,7 @@ public: /// @param [in] network - Complete network to load into the Runtime. /// The runtime takes ownership of the network once passed in. /// @return armnn::Status - virtual Status LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr network) override; + Status LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr network); /// Load a complete network into the IRuntime. /// @param [out] networkIdOut Unique identifier for the network is returned in this reference. @@ -44,55 +43,55 @@ public: /// @param [out] errorMessage Error message if there were any errors. /// The runtime takes ownership of the network once passed in. /// @return armnn::Status - virtual Status LoadNetwork(NetworkId& networkIdOut, - IOptimizedNetworkPtr network, - std::string& errorMessage) override; + Status LoadNetwork(NetworkId& networkIdOut, + IOptimizedNetworkPtr network, + std::string& errorMessage); - virtual Status LoadNetwork(NetworkId& networkIdOut, - IOptimizedNetworkPtr network, - std::string& errorMessage, - const INetworkProperties& networkProperties) override; + Status LoadNetwork(NetworkId& networkIdOut, + IOptimizedNetworkPtr network, + std::string& errorMessage, + const INetworkProperties& networkProperties); - virtual TensorInfo GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override; - virtual TensorInfo GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override; + TensorInfo GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const; + TensorInfo GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const; // Evaluates network using input in inputTensors, outputs filled into outputTensors. - virtual Status EnqueueWorkload(NetworkId networkId, + Status EnqueueWorkload(NetworkId networkId, const InputTensors& inputTensors, - const OutputTensors& outputTensors) override; + const OutputTensors& outputTensors); /// Unloads a network from the Runtime. /// At the moment this only removes the network from the m_Impl->m_Network. /// This might need more work in the future to be AndroidNN compliant. /// @param [in] networkId Unique identifier for the network to be unloaded. Generated in LoadNetwork(). /// @return armnn::Status - virtual Status UnloadNetwork(NetworkId networkId) override; + Status UnloadNetwork(NetworkId networkId); - virtual const IDeviceSpec& GetDeviceSpec() const override { return m_DeviceSpec; } + const IDeviceSpec& GetDeviceSpec() const { return m_DeviceSpec; } /// Gets the profiler corresponding to the given network id. /// @param networkId The id of the network for which to get the profile. /// @return A pointer to the requested profiler, or nullptr if not found. - virtual const std::shared_ptr GetProfiler(NetworkId networkId) const override; + const std::shared_ptr GetProfiler(NetworkId networkId) const; /// Registers a callback function to debug layers performing custom computations on intermediate tensors. /// @param networkId The id of the network to register the callback. /// @param func callback function to pass to the debug layer. - virtual void RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func) override; + void RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func); /// Creates a runtime for workload execution. - Runtime(const CreationOptions& options); + RuntimeImpl(const IRuntime::CreationOptions& options); - ~Runtime(); + ~RuntimeImpl(); //NOTE: we won't need the profiling service reference but it is good to pass the service // in this way to facilitate other implementations down the road - virtual void ReportStructure() override; + void ReportStructure(); private: - friend void RuntimeLoadedNetworksReserve(armnn::Runtime* runtime); // See RuntimeTests.cpp + friend void RuntimeLoadedNetworksReserve(RuntimeImpl* runtime); // See RuntimeTests.cpp - friend profiling::ProfilingService& GetProfilingService(armnn::Runtime* runtime); // See RuntimeTests.cpp + friend profiling::ProfilingService& GetProfilingService(RuntimeImpl* runtime); // See RuntimeTests.cpp int GenerateNetworkId(); diff --git a/src/armnn/test/RuntimeTests.cpp b/src/armnn/test/RuntimeTests.cpp index b3a8bbd6a6..1d5960b2a4 100644 --- a/src/armnn/test/RuntimeTests.cpp +++ b/src/armnn/test/RuntimeTests.cpp @@ -27,7 +27,7 @@ namespace armnn { -void RuntimeLoadedNetworksReserve(armnn::Runtime* runtime) +void RuntimeLoadedNetworksReserve(armnn::RuntimeImpl* runtime) { runtime->m_LoadedNetworks.reserve(1); } @@ -129,7 +129,7 @@ BOOST_AUTO_TEST_CASE(RuntimeMemoryLeak) // ensure that runtime is large enough before checking for memory leaks // otherwise when loading the network it will automatically reserve memory that won't be released until destruction armnn::IRuntime::CreationOptions options; - armnn::Runtime runtime(options); + armnn::RuntimeImpl runtime(options); armnn::RuntimeLoadedNetworksReserve(&runtime); { @@ -333,7 +333,7 @@ BOOST_AUTO_TEST_CASE(ProfilingDisable) // Create runtime in which the test will run armnn::IRuntime::CreationOptions options; - armnn::Runtime runtime(options); + armnn::RuntimeImpl runtime(options); // build up the structure of the network INetworkPtr net(INetwork::Create()); @@ -378,7 +378,7 @@ BOOST_AUTO_TEST_CASE(ProfilingEnableCpuRef) options.m_ProfilingOptions.m_EnableProfiling = true; options.m_ProfilingOptions.m_TimelineEnabled = true; - armnn::Runtime runtime(options); + armnn::RuntimeImpl runtime(options); GetProfilingService(&runtime).ResetExternalProfilingOptions(options.m_ProfilingOptions, false); profiling::ProfilingServiceRuntimeHelper profilingServiceHelper(GetProfilingService(&runtime)); diff --git a/src/armnn/test/RuntimeTests.hpp b/src/armnn/test/RuntimeTests.hpp index 90aed5de1e..bca8324733 100644 --- a/src/armnn/test/RuntimeTests.hpp +++ b/src/armnn/test/RuntimeTests.hpp @@ -9,6 +9,6 @@ namespace armnn { -void RuntimeLoadedNetworksReserve(armnn::Runtime* runtime); +void RuntimeLoadedNetworksReserve(armnn::RuntimeImpl* runtime); } // namespace armnn diff --git a/src/armnn/test/TestUtils.cpp b/src/armnn/test/TestUtils.cpp index 6d7d02dcff..440d4e09f3 100644 --- a/src/armnn/test/TestUtils.cpp +++ b/src/armnn/test/TestUtils.cpp @@ -22,7 +22,7 @@ void Connect(armnn::IConnectableLayer* from, armnn::IConnectableLayer* to, const namespace armnn { -profiling::ProfilingService& GetProfilingService(armnn::Runtime* runtime) +profiling::ProfilingService& GetProfilingService(armnn::RuntimeImpl* runtime) { return runtime->m_ProfilingService; } diff --git a/src/armnn/test/TestUtils.hpp b/src/armnn/test/TestUtils.hpp index 9c5f672a5a..bf222b3c56 100644 --- a/src/armnn/test/TestUtils.hpp +++ b/src/armnn/test/TestUtils.hpp @@ -52,6 +52,6 @@ bool CheckRelatedLayers(armnn::Graph& graph, const std::list& testR namespace armnn { -profiling::ProfilingService& GetProfilingService(armnn::Runtime* runtime); +profiling::ProfilingService& GetProfilingService(RuntimeImpl* runtime); } // namespace armnn \ No newline at end of file -- cgit v1.2.1