diff options
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/LoadedNetwork.cpp | 37 | ||||
-rw-r--r-- | src/armnn/LoadedNetwork.hpp | 2 | ||||
-rw-r--r-- | src/armnn/Runtime.cpp | 23 | ||||
-rw-r--r-- | src/armnn/Runtime.hpp | 16 | ||||
-rw-r--r-- | src/armnn/test/RuntimeTests.cpp | 9 |
5 files changed, 83 insertions, 4 deletions
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp index f3d742c515..9d181e535a 100644 --- a/src/armnn/LoadedNetwork.cpp +++ b/src/armnn/LoadedNetwork.cpp @@ -263,6 +263,43 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<OptimizedNetwork> net, } } +void LoadedNetwork::SendNetworkStructure() +{ + Graph& order = m_OptimizedNetwork->GetGraph().TopologicalSort(); + ProfilingGuid networkGuid = m_OptimizedNetwork->GetGuid(); + + std::unique_ptr<TimelineUtilityMethods> timelineUtils = + TimelineUtilityMethods::GetTimelineUtils(m_ProfilingService); + + timelineUtils->CreateTypedEntity(networkGuid, LabelsAndEventClasses::NETWORK_GUID); + + for (auto&& layer : order) + { + // Add layer to the post-optimisation network structure + AddLayerStructure(timelineUtils, *layer, networkGuid); + switch (layer->GetType()) + { + case LayerType::Input: + case LayerType::Output: + { + // Inputs and outputs are treated in a special way - see EnqueueInput() and EnqueueOutput(). + break; + } + default: + { + for (auto& workload : m_WorkloadQueue) + { + // Add workload to the post-optimisation network structure + AddWorkloadStructure(timelineUtils, workload, *layer); + } + break; + } + } + } + // Commit to send the post-optimisation network structure + timelineUtils->Commit(); +} + TensorInfo LoadedNetwork::GetInputTensorInfo(LayerBindingId layerId) const { for (auto&& inputLayer : m_OptimizedNetwork->GetGraph().GetInputLayers()) diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp index 01e3442508..91379d78ed 100644 --- a/src/armnn/LoadedNetwork.hpp +++ b/src/armnn/LoadedNetwork.hpp @@ -56,6 +56,8 @@ public: void RegisterDebugCallback(const DebugCallbackFunction& func); + void SendNetworkStructure(); + private: void AllocateWorkingMemory(); diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp index 26636a81f7..dfcbf852e0 100644 --- a/src/armnn/Runtime.cpp +++ b/src/armnn/Runtime.cpp @@ -152,11 +152,32 @@ const std::shared_ptr<IProfiler> Runtime::GetProfiler(NetworkId networkId) const return nullptr; } +void Runtime::ReportStructure() // armnn::profiling::IProfilingService& profilingService as param +{ + // No-op for the time being, but this may be useful in future to have the profilingService available + // if (profilingService.IsProfilingEnabled()){} + + LoadedNetworks::iterator it = m_LoadedNetworks.begin(); + while (it != m_LoadedNetworks.end()) + { + auto& loadedNetwork = it->second; + loadedNetwork->SendNetworkStructure(); + // Increment the Iterator to point to next entry + it++; + } +} + Runtime::Runtime(const CreationOptions& options) - : m_NetworkIdCounter(0) + : m_NetworkIdCounter(0), + m_ProfilingService(*this) { ARMNN_LOG(info) << "ArmNN v" << ARMNN_VERSION << "\n"; + if ( options.m_ProfilingOptions.m_TimelineEnabled && !options.m_ProfilingOptions.m_EnableProfiling ) + { + throw RuntimeException("It is not possible to enable timeline reporting without profiling being enabled"); + } + // pass configuration info to the profiling service m_ProfilingService.ConfigureProfilingService(options.m_ProfilingOptions); diff --git a/src/armnn/Runtime.hpp b/src/armnn/Runtime.hpp index 477b1169b1..d4b6dcbd63 100644 --- a/src/armnn/Runtime.hpp +++ b/src/armnn/Runtime.hpp @@ -16,13 +16,19 @@ #include <ProfilingService.hpp> +#include <IProfilingService.hpp> +#include <IReportStructure.hpp> + #include <mutex> #include <unordered_map> namespace armnn { +using LoadedNetworks = std::unordered_map<NetworkId, std::unique_ptr<LoadedNetwork>>; +using IReportStructure = profiling::IReportStructure; -class Runtime final : public IRuntime +class Runtime final : public IRuntime, + public IReportStructure { public: /// Loads a complete network into the Runtime. @@ -79,6 +85,10 @@ public: ~Runtime(); + //NOTE: we won't need the profiling service reference but it is good to pass the service + // in this way to facilitate other implementations down the road + virtual void ReportStructure() override; + private: friend void RuntimeLoadedNetworksReserve(armnn::Runtime* runtime); // See RuntimeTests.cpp @@ -104,7 +114,9 @@ private: mutable std::mutex m_Mutex; - std::unordered_map<NetworkId, std::unique_ptr<LoadedNetwork>> m_LoadedNetworks; + /// Map of Loaded Networks with associated GUID as key + LoadedNetworks m_LoadedNetworks; + std::unordered_map<BackendId, IBackendInternal::IBackendContextPtr> m_BackendContexts; int m_NetworkIdCounter; diff --git a/src/armnn/test/RuntimeTests.cpp b/src/armnn/test/RuntimeTests.cpp index c1150231bd..2fc4b50a54 100644 --- a/src/armnn/test/RuntimeTests.cpp +++ b/src/armnn/test/RuntimeTests.cpp @@ -371,7 +371,15 @@ BOOST_AUTO_TEST_CASE(ProfilingEnableCpuRef) // Create runtime in which the test will run armnn::IRuntime::CreationOptions options; options.m_ProfilingOptions.m_EnableProfiling = true; + options.m_ProfilingOptions.m_TimelineEnabled = true; + armnn::Runtime runtime(options); + GetProfilingService(&runtime).ResetExternalProfilingOptions(options.m_ProfilingOptions, false); + + profiling::ProfilingServiceRuntimeHelper profilingServiceHelper(GetProfilingService(&runtime)); + profilingServiceHelper.ForceTransitionToState(ProfilingState::NotConnected); + profilingServiceHelper.ForceTransitionToState(ProfilingState::WaitingForAck); + profilingServiceHelper.ForceTransitionToState(ProfilingState::Active); // build up the structure of the network INetworkPtr net(INetwork::Create()); @@ -399,7 +407,6 @@ BOOST_AUTO_TEST_CASE(ProfilingEnableCpuRef) armnn::NetworkId netId; BOOST_TEST(runtime.LoadNetwork(netId, std::move(optNet)) == Status::Success); - profiling::ProfilingServiceRuntimeHelper profilingServiceHelper(GetProfilingService(&runtime)); profiling::BufferManager& bufferManager = profilingServiceHelper.GetProfilingBufferManager(); auto readableBuffer = bufferManager.GetReadableBuffer(); |