From e155bbf0e9be6b4d7974297585a59207cd89b00a Mon Sep 17 00:00:00 2001 From: Derek Lamberti Date: Wed, 13 Oct 2021 14:32:12 +0100 Subject: Refactor: Profiler moved to Graph * This is to enable later work to instrument the Optimizer. Signed-off-by: Derek Lamberti Change-Id: I2cf1fe022e0d100d6d8705adfbb8cab3ffc96a86 --- include/armnn/INetwork.hpp | 3 +++ src/armnn/Graph.cpp | 6 ++++++ src/armnn/Graph.hpp | 6 ++++++ src/armnn/LoadedNetwork.cpp | 10 +++++----- src/armnn/LoadedNetwork.hpp | 3 +-- src/armnn/Network.cpp | 5 +++++ src/armnn/Runtime.cpp | 7 +++++++ 7 files changed, 33 insertions(+), 7 deletions(-) diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp index 5027818623..ab92f05112 100644 --- a/include/armnn/INetwork.hpp +++ b/include/armnn/INetwork.hpp @@ -715,6 +715,7 @@ class WorkingMemHandle; struct BackendSettings; struct OptimizationResult; class OptimizedNetworkImpl; +class IProfiler; class IOptimizedNetwork { public: @@ -732,6 +733,8 @@ public: IOptimizedNetwork(std::unique_ptr impl); ~IOptimizedNetwork(); + const std::shared_ptr& GetProfiler() const; + protected: friend class LoadedNetwork; diff --git a/src/armnn/Graph.cpp b/src/armnn/Graph.cpp index ebfc829340..30639b12e8 100644 --- a/src/armnn/Graph.cpp +++ b/src/armnn/Graph.cpp @@ -27,6 +27,7 @@ namespace armnn Graph::Graph(const Graph& other) : m_LayersInOrder(other.m_LayersInOrder) +, m_Profiler(other.m_Profiler) { std::unordered_map otherToClonedMap; @@ -636,4 +637,9 @@ void Graph::ConstructErrorMessageForUnconnectedInputs(Layer* const layer, throw LayerValidationException(message.str()); } +const std::shared_ptr& Graph::GetProfiler() const +{ + return m_Profiler; +} + } // namespace armnn diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp index e2321bb0e4..74aefb23ee 100644 --- a/src/armnn/Graph.hpp +++ b/src/armnn/Graph.hpp @@ -6,6 +6,7 @@ #include "LayersFwd.hpp" #include "IGraphObservable.hpp" +#include "Profiling.hpp" #include #include @@ -96,6 +97,7 @@ public: : m_LayersInOrder(true) , m_ShapeInferenceMethod(shapeInferenceMethod ? ShapeInferenceMethod::InferAndValidate : ShapeInferenceMethod::ValidateOnly) + , m_Profiler(std::make_shared()) {} Graph(const Graph& other); @@ -113,6 +115,7 @@ public: m_OutputIds = std::move(other.m_OutputIds); m_LayersInOrder = std::move(other.m_LayersInOrder); m_Views = std::move(other.m_Views); + m_Profiler = std::move(other.m_Profiler); other.ForEachLayer([this](Layer* otherLayer) { @@ -220,6 +223,8 @@ public: /// Gets the position of a layer in the graph. Iterator GetPosInGraph(Layer& layer); + const std::shared_ptr& GetProfiler() const; + private: template class LayerInGraphBase; @@ -268,6 +273,7 @@ private: std::map> m_Views; ShapeInferenceMethod m_ShapeInferenceMethod; + std::shared_ptr m_Profiler; // Throws exception due to a layer input not being connected to an output slot. /// Also verifies weights and bias are set for FullyConnected layers. diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp index c161ed35d5..4688b6eea4 100644 --- a/src/armnn/LoadedNetwork.cpp +++ b/src/armnn/LoadedNetwork.cpp @@ -122,13 +122,13 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr net, m_TensorHandleFactoryRegistry(), m_ProfilingService(profilingService) { - // Create a profiler and register it for the current thread. - m_Profiler = std::make_shared(); - ProfilerManager::GetInstance().RegisterProfiler(m_Profiler.get()); + // Get the profiler and register it for the current thread. + const std::shared_ptr& profiler = m_OptimizedNetwork->GetProfiler(); + ProfilerManager::GetInstance().RegisterProfiler(profiler.get()); - m_Profiler->EnableProfiling(networkProperties.m_ProfilingEnabled); + profiler->EnableProfiling(networkProperties.m_ProfilingEnabled); - m_Profiler->EnableNetworkDetailsToStdOut(networkProperties.m_OutputNetworkDetailsMethod); + profiler->EnableNetworkDetailsToStdOut(networkProperties.m_OutputNetworkDetailsMethod); Graph& order = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().TopologicalSort(); //First create tensor handlers, backends and workload factories. diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp index 99dac556ae..71ceaa3938 100644 --- a/src/armnn/LoadedNetwork.hpp +++ b/src/armnn/LoadedNetwork.hpp @@ -73,7 +73,7 @@ public: // NOTE we return by reference as the purpose of this method is only to provide // access to the private m_Profiler and in theory we should not need to increment // the shared_ptr's reference counter - const std::shared_ptr& GetProfiler() const { return m_Profiler; } + const std::shared_ptr& GetProfiler() const { return m_OptimizedNetwork->GetProfiler(); } void FreeWorkingMemory(); @@ -126,7 +126,6 @@ private: WorkloadFactoryMap m_WorkloadFactories; std::unique_ptr m_OptimizedNetwork; - std::shared_ptr m_Profiler; WorkloadQueue m_InputQueue; WorkloadQueue m_WorkloadQueue; diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index 4298b05528..99d7b96ec2 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -516,6 +516,11 @@ Status IOptimizedNetwork::SerializeToDot(std::ostream& stream) const return pOptimizedNetworkImpl->SerializeToDot(stream); } +const std::shared_ptr& IOptimizedNetwork::GetProfiler() const +{ + return pOptimizedNetworkImpl->GetGraph().GetProfiler(); +} + profiling::ProfilingGuid IOptimizedNetwork::GetGuid() const { return pOptimizedNetworkImpl->GetGuid(); diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp index a54b71225d..ca28199aae 100644 --- a/src/armnn/Runtime.cpp +++ b/src/armnn/Runtime.cpp @@ -162,6 +162,10 @@ Status RuntimeImpl::LoadNetwork(NetworkId& networkIdOut, std::string& errorMessage, const INetworkProperties& networkProperties) { + // Register the profiler + auto profiler = inNetwork->GetProfiler(); + ProfilerManager::GetInstance().RegisterProfiler(profiler.get()); + IOptimizedNetwork* rawNetwork = inNetwork.release(); networkIdOut = GenerateNetworkId(); @@ -250,6 +254,9 @@ Status RuntimeImpl::UnloadNetwork(NetworkId networkId) context.second->AfterUnloadNetwork(networkId); } + // Unregister the profiler + ProfilerManager::GetInstance().RegisterProfiler(nullptr); + ARMNN_LOG(debug) << "RuntimeImpl::UnloadNetwork(): Unloaded network with ID: " << networkId; return Status::Success; } -- cgit v1.2.1