From 33ed221e2e8e3a77b5f543061e0cce07b259fc64 Mon Sep 17 00:00:00 2001 From: Keith Davis Date: Mon, 30 Mar 2020 10:43:41 +0100 Subject: IVGCVSW-4455 Add an Activate and Deactivate control packet to the protocol * Add Activate/Deactivate command handlers * Add IReportStructure, INotifyBackends single function interfaces * Add overrided mechanism to report structure in Runtime.cpp * Add overrided mechanism to notify backends in ProfilingService.cpp * Add optional IReportStructure argument to ProfilingService constructor for use in ActivateTimelineReportingCommandHandler * Refactoring and tidying up indentation * Removal of unused code in ProfilingUtils.cpp and ProfilingService.cpp * Added GatordMock end to end test * Fixed an issue with SendCounterPacket sending duplicate packets * Fixed an issue with DirectoryCaptureCommandHandler handling of Optional Signed-off-by: Keith Davis Signed-off-by: Finn Williams Change-Id: I5ef1b74171459bfc649861dedf99921d22c9e63f --- Android.mk | 2 + CMakeLists.txt | 6 + include/armnn/IRuntime.hpp | 2 + .../profiling/IBackendProfilingContext.hpp | 1 + src/armnn/LoadedNetwork.cpp | 37 ++++ src/armnn/LoadedNetwork.hpp | 2 + src/armnn/Runtime.cpp | 23 ++- src/armnn/Runtime.hpp | 16 +- src/armnn/test/RuntimeTests.cpp | 9 +- src/backends/backendsCommon/test/MockBackend.cpp | 1 - src/backends/backendsCommon/test/MockBackend.hpp | 13 ++ .../ActivateTimelineReportingCommandHandler.cpp | 63 ++++++ .../ActivateTimelineReportingCommandHandler.hpp | 54 +++++ src/profiling/CommandHandler.hpp | 30 +-- .../DeactivateTimelineReportingCommandHandler.cpp | 53 +++++ .../DeactivateTimelineReportingCommandHandler.hpp | 45 ++++ src/profiling/DirectoryCaptureCommandHandler.cpp | 3 +- src/profiling/DirectoryCaptureCommandHandler.hpp | 2 +- src/profiling/INotifyBackends.hpp | 24 +++ src/profiling/IReportStructure.hpp | 24 +++ src/profiling/PacketVersionResolver.cpp | 13 +- src/profiling/PacketVersionResolver.hpp | 3 + src/profiling/ProfilingService.cpp | 16 +- src/profiling/ProfilingService.hpp | 104 +++++++--- src/profiling/ProfilingUtils.cpp | 13 +- src/profiling/ProfilingUtils.hpp | 2 - src/profiling/SendCounterPacket.cpp | 41 +++- src/profiling/TimelineUtilityMethods.cpp | 2 +- src/profiling/test/ProfilingMocks.hpp | 15 +- src/profiling/test/ProfilingTestUtils.cpp | 8 +- src/profiling/test/ProfilingTestUtils.hpp | 6 + src/profiling/test/ProfilingTests.cpp | 130 +++++++++++- src/profiling/test/ProfilingTests.hpp | 2 +- src/profiling/test/SendCounterPacketTests.cpp | 15 +- src/profiling/test/SendTimelinePacketTests.cpp | 11 +- tests/profiling/gatordmock/GatordMockMain.cpp | 44 +--- tests/profiling/gatordmock/GatordMockService.cpp | 27 ++- tests/profiling/gatordmock/GatordMockService.hpp | 75 ++++++- .../profiling/gatordmock/tests/GatordMockTests.cpp | 228 ++++++++++++--------- 39 files changed, 918 insertions(+), 247 deletions(-) create mode 100644 src/profiling/ActivateTimelineReportingCommandHandler.cpp create mode 100644 src/profiling/ActivateTimelineReportingCommandHandler.hpp create mode 100644 src/profiling/DeactivateTimelineReportingCommandHandler.cpp create mode 100644 src/profiling/DeactivateTimelineReportingCommandHandler.hpp create mode 100644 src/profiling/INotifyBackends.hpp create mode 100644 src/profiling/IReportStructure.hpp diff --git a/Android.mk b/Android.mk index 87b1f9ac1a..6723debbe3 100644 --- a/Android.mk +++ b/Android.mk @@ -184,6 +184,7 @@ LOCAL_SRC_FILES := \ src/armnn/layers/SwitchLayer.cpp \ src/armnn/layers/TransposeConvolution2dLayer.cpp \ src/armnn/layers/TransposeLayer.cpp \ + src/profiling/ActivateTimelineReportingCommandHandler.cpp \ src/profiling/BufferManager.cpp \ src/profiling/CommandHandler.cpp \ src/profiling/CommandHandlerFunctor.cpp \ @@ -193,6 +194,7 @@ LOCAL_SRC_FILES := \ src/profiling/CounterDirectory.cpp \ src/profiling/CounterIdMap.cpp \ src/profiling/DirectoryCaptureCommandHandler.cpp \ + src/profiling/DeactivateTimelineReportingCommandHandler.cpp \ src/profiling/FileOnlyProfilingConnection.cpp \ src/profiling/Holder.cpp \ src/profiling/LabelsAndEventClasses.cpp \ diff --git a/CMakeLists.txt b/CMakeLists.txt index c093344acd..aac18e99cb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -462,6 +462,8 @@ list(APPEND armnn_sources src/armnn/optimizations/PermuteAndBatchToSpaceAsDepthToSpace.hpp src/armnn/optimizations/PermuteAsReshape.hpp src/armnn/optimizations/SquashEqualSiblings.hpp + src/profiling/ActivateTimelineReportingCommandHandler.cpp + src/profiling/ActivateTimelineReportingCommandHandler.hpp src/profiling/BufferManager.cpp src/profiling/BufferManager.hpp src/profiling/CommandHandler.cpp @@ -478,6 +480,8 @@ list(APPEND armnn_sources src/profiling/CounterDirectory.hpp src/profiling/CounterIdMap.cpp src/profiling/CounterIdMap.hpp + src/profiling/DeactivateTimelineReportingCommandHandler.cpp + src/profiling/DeactivateTimelineReportingCommandHandler.hpp src/profiling/DirectoryCaptureCommandHandler.cpp src/profiling/DirectoryCaptureCommandHandler.hpp src/profiling/EncodeVersion.hpp @@ -490,6 +494,8 @@ list(APPEND armnn_sources src/profiling/ICounterDirectory.hpp src/profiling/ICounterRegistry.hpp src/profiling/ICounterValues.hpp + src/profiling/INotifyBackends.hpp + src/profiling/IReportStructure.hpp src/profiling/ISendCounterPacket.hpp src/profiling/ISendThread.hpp src/profiling/IPacketBuffer.hpp diff --git a/include/armnn/IRuntime.hpp b/include/armnn/IRuntime.hpp index 8391ed3b15..06d249ea8c 100644 --- a/include/armnn/IRuntime.hpp +++ b/include/armnn/IRuntime.hpp @@ -66,6 +66,7 @@ public: , m_FileOnly(false) , m_CapturePeriod(LOWEST_CAPTURE_PERIOD) , m_FileFormat("binary") + , m_TimelineEnabled(false) {} bool m_EnableProfiling; @@ -74,6 +75,7 @@ public: bool m_FileOnly; uint32_t m_CapturePeriod; std::string m_FileFormat; + bool m_TimelineEnabled; }; ExternalProfilingOptions m_ProfilingOptions; diff --git a/include/armnn/backends/profiling/IBackendProfilingContext.hpp b/include/armnn/backends/profiling/IBackendProfilingContext.hpp index 063ebc946d..77959e959b 100644 --- a/include/armnn/backends/profiling/IBackendProfilingContext.hpp +++ b/include/armnn/backends/profiling/IBackendProfilingContext.hpp @@ -22,6 +22,7 @@ public: virtual Optional ActivateCounters(uint32_t capturePeriod, const std::vector& counterIds) = 0; virtual std::vector ReportCounterValues() = 0; virtual bool EnableProfiling(bool flag) = 0; + virtual bool EnableTimelineReporting(bool flag) = 0; }; using IBackendProfilingContextUniquePtr = std::unique_ptr; diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp index f3d742c515..9d181e535a 100644 --- a/src/armnn/LoadedNetwork.cpp +++ b/src/armnn/LoadedNetwork.cpp @@ -263,6 +263,43 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr net, } } +void LoadedNetwork::SendNetworkStructure() +{ + Graph& order = m_OptimizedNetwork->GetGraph().TopologicalSort(); + ProfilingGuid networkGuid = m_OptimizedNetwork->GetGuid(); + + std::unique_ptr timelineUtils = + TimelineUtilityMethods::GetTimelineUtils(m_ProfilingService); + + timelineUtils->CreateTypedEntity(networkGuid, LabelsAndEventClasses::NETWORK_GUID); + + for (auto&& layer : order) + { + // Add layer to the post-optimisation network structure + AddLayerStructure(timelineUtils, *layer, networkGuid); + switch (layer->GetType()) + { + case LayerType::Input: + case LayerType::Output: + { + // Inputs and outputs are treated in a special way - see EnqueueInput() and EnqueueOutput(). + break; + } + default: + { + for (auto& workload : m_WorkloadQueue) + { + // Add workload to the post-optimisation network structure + AddWorkloadStructure(timelineUtils, workload, *layer); + } + break; + } + } + } + // Commit to send the post-optimisation network structure + timelineUtils->Commit(); +} + TensorInfo LoadedNetwork::GetInputTensorInfo(LayerBindingId layerId) const { for (auto&& inputLayer : m_OptimizedNetwork->GetGraph().GetInputLayers()) diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp index 01e3442508..91379d78ed 100644 --- a/src/armnn/LoadedNetwork.hpp +++ b/src/armnn/LoadedNetwork.hpp @@ -56,6 +56,8 @@ public: void RegisterDebugCallback(const DebugCallbackFunction& func); + void SendNetworkStructure(); + private: void AllocateWorkingMemory(); diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp index 26636a81f7..dfcbf852e0 100644 --- a/src/armnn/Runtime.cpp +++ b/src/armnn/Runtime.cpp @@ -152,11 +152,32 @@ const std::shared_ptr Runtime::GetProfiler(NetworkId networkId) const return nullptr; } +void Runtime::ReportStructure() // armnn::profiling::IProfilingService& profilingService as param +{ + // No-op for the time being, but this may be useful in future to have the profilingService available + // if (profilingService.IsProfilingEnabled()){} + + LoadedNetworks::iterator it = m_LoadedNetworks.begin(); + while (it != m_LoadedNetworks.end()) + { + auto& loadedNetwork = it->second; + loadedNetwork->SendNetworkStructure(); + // Increment the Iterator to point to next entry + it++; + } +} + Runtime::Runtime(const CreationOptions& options) - : m_NetworkIdCounter(0) + : m_NetworkIdCounter(0), + m_ProfilingService(*this) { ARMNN_LOG(info) << "ArmNN v" << ARMNN_VERSION << "\n"; + if ( options.m_ProfilingOptions.m_TimelineEnabled && !options.m_ProfilingOptions.m_EnableProfiling ) + { + throw RuntimeException("It is not possible to enable timeline reporting without profiling being enabled"); + } + // pass configuration info to the profiling service m_ProfilingService.ConfigureProfilingService(options.m_ProfilingOptions); diff --git a/src/armnn/Runtime.hpp b/src/armnn/Runtime.hpp index 477b1169b1..d4b6dcbd63 100644 --- a/src/armnn/Runtime.hpp +++ b/src/armnn/Runtime.hpp @@ -16,13 +16,19 @@ #include +#include +#include + #include #include namespace armnn { +using LoadedNetworks = std::unordered_map>; +using IReportStructure = profiling::IReportStructure; -class Runtime final : public IRuntime +class Runtime final : public IRuntime, + public IReportStructure { public: /// Loads a complete network into the Runtime. @@ -79,6 +85,10 @@ public: ~Runtime(); + //NOTE: we won't need the profiling service reference but it is good to pass the service + // in this way to facilitate other implementations down the road + virtual void ReportStructure() override; + private: friend void RuntimeLoadedNetworksReserve(armnn::Runtime* runtime); // See RuntimeTests.cpp @@ -104,7 +114,9 @@ private: mutable std::mutex m_Mutex; - std::unordered_map> m_LoadedNetworks; + /// Map of Loaded Networks with associated GUID as key + LoadedNetworks m_LoadedNetworks; + std::unordered_map m_BackendContexts; int m_NetworkIdCounter; diff --git a/src/armnn/test/RuntimeTests.cpp b/src/armnn/test/RuntimeTests.cpp index c1150231bd..2fc4b50a54 100644 --- a/src/armnn/test/RuntimeTests.cpp +++ b/src/armnn/test/RuntimeTests.cpp @@ -371,7 +371,15 @@ BOOST_AUTO_TEST_CASE(ProfilingEnableCpuRef) // Create runtime in which the test will run armnn::IRuntime::CreationOptions options; options.m_ProfilingOptions.m_EnableProfiling = true; + options.m_ProfilingOptions.m_TimelineEnabled = true; + armnn::Runtime runtime(options); + GetProfilingService(&runtime).ResetExternalProfilingOptions(options.m_ProfilingOptions, false); + + profiling::ProfilingServiceRuntimeHelper profilingServiceHelper(GetProfilingService(&runtime)); + profilingServiceHelper.ForceTransitionToState(ProfilingState::NotConnected); + profilingServiceHelper.ForceTransitionToState(ProfilingState::WaitingForAck); + profilingServiceHelper.ForceTransitionToState(ProfilingState::Active); // build up the structure of the network INetworkPtr net(INetwork::Create()); @@ -399,7 +407,6 @@ BOOST_AUTO_TEST_CASE(ProfilingEnableCpuRef) armnn::NetworkId netId; BOOST_TEST(runtime.LoadNetwork(netId, std::move(optNet)) == Status::Success); - profiling::ProfilingServiceRuntimeHelper profilingServiceHelper(GetProfilingService(&runtime)); profiling::BufferManager& bufferManager = profilingServiceHelper.GetProfilingBufferManager(); auto readableBuffer = bufferManager.GetReadableBuffer(); diff --git a/src/backends/backendsCommon/test/MockBackend.cpp b/src/backends/backendsCommon/test/MockBackend.cpp index 8d40117741..116bf77c63 100644 --- a/src/backends/backendsCommon/test/MockBackend.cpp +++ b/src/backends/backendsCommon/test/MockBackend.cpp @@ -5,7 +5,6 @@ #include "MockBackend.hpp" #include "MockBackendId.hpp" -#include "armnn/backends/profiling/IBackendProfilingContext.hpp" #include diff --git a/src/backends/backendsCommon/test/MockBackend.hpp b/src/backends/backendsCommon/test/MockBackend.hpp index 6e415b9b52..e1570ff920 100644 --- a/src/backends/backendsCommon/test/MockBackend.hpp +++ b/src/backends/backendsCommon/test/MockBackend.hpp @@ -32,6 +32,7 @@ public: MockBackendProfilingContext(IBackendInternal::IBackendProfilingPtr& backendProfiling) : m_BackendProfiling(std::move(backendProfiling)) , m_CapturePeriod(0) + , m_IsTimelineEnabled(true) {} ~MockBackendProfilingContext() = default; @@ -93,10 +94,22 @@ public: return true; } + bool EnableTimelineReporting(bool isEnabled) + { + m_IsTimelineEnabled = isEnabled; + return isEnabled; + } + + bool TimelineReportingEnabled() + { + return m_IsTimelineEnabled; + } + private: IBackendInternal::IBackendProfilingPtr m_BackendProfiling; uint32_t m_CapturePeriod; std::vector m_ActiveCounters; + bool m_IsTimelineEnabled; }; class MockBackendProfilingService diff --git a/src/profiling/ActivateTimelineReportingCommandHandler.cpp b/src/profiling/ActivateTimelineReportingCommandHandler.cpp new file mode 100644 index 0000000000..d762efc277 --- /dev/null +++ b/src/profiling/ActivateTimelineReportingCommandHandler.cpp @@ -0,0 +1,63 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "ActivateTimelineReportingCommandHandler.hpp" +#include "TimelineUtilityMethods.hpp" + +#include +#include + +namespace armnn +{ + +namespace profiling +{ + +void ActivateTimelineReportingCommandHandler::operator()(const Packet& packet) +{ + ProfilingState currentState = m_StateMachine.GetCurrentState(); + + if (!m_ReportStructure.has_value()) + { + throw armnn::Exception(std::string("Profiling Service constructor must be initialised with an " + "IReportStructure argument in order to run timeline reporting")); + } + + switch ( currentState ) + { + case ProfilingState::Uninitialised: + case ProfilingState::NotConnected: + case ProfilingState::WaitingForAck: + throw RuntimeException(boost::str( + boost::format("Activate Timeline Reporting Command Handler invoked while in a wrong state: %1%") + % GetProfilingStateName(currentState))); + case ProfilingState::Active: + if ( !( packet.GetPacketFamily() == 0u && packet.GetPacketId() == 6u )) + { + throw armnn::Exception(std::string("Expected Packet family = 0, id = 6 but received family =") + + std::to_string(packet.GetPacketFamily()) + + " id = " + std::to_string(packet.GetPacketId())); + } + + m_SendTimelinePacket.SendTimelineMessageDirectoryPackage(); + + TimelineUtilityMethods::SendWellKnownLabelsAndEventClasses(m_SendTimelinePacket); + + m_TimelineReporting = true; + + m_ReportStructure.value().ReportStructure(); + + m_BackendNotifier.NotifyBackendsForTimelineReporting(); + + break; + default: + throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1%") + % static_cast(currentState))); + } +} + +} // namespace profiling + +} // namespace armnn \ No newline at end of file diff --git a/src/profiling/ActivateTimelineReportingCommandHandler.hpp b/src/profiling/ActivateTimelineReportingCommandHandler.hpp new file mode 100644 index 0000000000..ac11b46379 --- /dev/null +++ b/src/profiling/ActivateTimelineReportingCommandHandler.hpp @@ -0,0 +1,54 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "CommandHandlerFunctor.hpp" +#include "ProfilingStateMachine.hpp" +#include "Packet.hpp" +#include "SendTimelinePacket.hpp" +#include "IReportStructure.hpp" +#include "armnn/Optional.hpp" +#include "INotifyBackends.hpp" + +namespace armnn +{ + +namespace profiling +{ + +class ActivateTimelineReportingCommandHandler : public CommandHandlerFunctor +{ +public: + ActivateTimelineReportingCommandHandler(uint32_t familyId, + uint32_t packetId, + uint32_t version, + SendTimelinePacket& sendTimelinePacket, + ProfilingStateMachine& profilingStateMachine, + Optional reportStructure, + std::atomic& timelineReporting, + INotifyBackends& notifyBackends) + : CommandHandlerFunctor(familyId, packetId, version), + m_SendTimelinePacket(sendTimelinePacket), + m_StateMachine(profilingStateMachine), + m_TimelineReporting(timelineReporting), + m_BackendNotifier(notifyBackends), + m_ReportStructure(reportStructure) + {} + + void operator()(const Packet& packet) override; + +private: + SendTimelinePacket& m_SendTimelinePacket; + ProfilingStateMachine& m_StateMachine; + std::atomic& m_TimelineReporting; + INotifyBackends& m_BackendNotifier; + + Optional m_ReportStructure; +}; + +} // namespace profiling + +} // namespace armnn \ No newline at end of file diff --git a/src/profiling/CommandHandler.hpp b/src/profiling/CommandHandler.hpp index 0cc23429cd..4bf820c5db 100644 --- a/src/profiling/CommandHandler.hpp +++ b/src/profiling/CommandHandler.hpp @@ -22,16 +22,16 @@ class CommandHandler { public: CommandHandler(uint32_t timeout, - bool stopAfterTimeout, - CommandHandlerRegistry& commandHandlerRegistry, - PacketVersionResolver& packetVersionResolver) - : m_Timeout(timeout) - , m_StopAfterTimeout(stopAfterTimeout) - , m_IsRunning(false) - , m_KeepRunning(false) - , m_CommandThread() - , m_CommandHandlerRegistry(commandHandlerRegistry) - , m_PacketVersionResolver(packetVersionResolver) + bool stopAfterTimeout, + CommandHandlerRegistry& commandHandlerRegistry, + PacketVersionResolver& packetVersionResolver) + : m_Timeout(timeout), + m_StopAfterTimeout(stopAfterTimeout), + m_IsRunning(false), + m_KeepRunning(false), + m_CommandThread(), + m_CommandHandlerRegistry(commandHandlerRegistry), + m_PacketVersionResolver(packetVersionResolver) {} ~CommandHandler() { Stop(); } @@ -46,13 +46,13 @@ private: void HandleCommands(IProfilingConnection& profilingConnection); std::atomic m_Timeout; - std::atomic m_StopAfterTimeout; - std::atomic m_IsRunning; - std::atomic m_KeepRunning; - std::thread m_CommandThread; + std::atomic m_StopAfterTimeout; + std::atomic m_IsRunning; + std::atomic m_KeepRunning; + std::thread m_CommandThread; CommandHandlerRegistry& m_CommandHandlerRegistry; - PacketVersionResolver& m_PacketVersionResolver; + PacketVersionResolver& m_PacketVersionResolver; }; } // namespace profiling diff --git a/src/profiling/DeactivateTimelineReportingCommandHandler.cpp b/src/profiling/DeactivateTimelineReportingCommandHandler.cpp new file mode 100644 index 0000000000..dbfb053b3d --- /dev/null +++ b/src/profiling/DeactivateTimelineReportingCommandHandler.cpp @@ -0,0 +1,53 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "DeactivateTimelineReportingCommandHandler.hpp" + +#include +#include + + +namespace armnn +{ + +namespace profiling +{ + +void DeactivateTimelineReportingCommandHandler::operator()(const Packet& packet) +{ + ProfilingState currentState = m_StateMachine.GetCurrentState(); + + switch ( currentState ) + { + case ProfilingState::Uninitialised: + case ProfilingState::NotConnected: + case ProfilingState::WaitingForAck: + throw RuntimeException(boost::str( + boost::format("Deactivate Timeline Reporting Command Handler invoked while in a wrong state: %1%") + % GetProfilingStateName(currentState))); + case ProfilingState::Active: + if (!(packet.GetPacketFamily() == 0u && packet.GetPacketId() == 7u)) + { + throw armnn::Exception(std::string("Expected Packet family = 0, id = 7 but received family =") + + std::to_string(packet.GetPacketFamily()) + +" id = " + std::to_string(packet.GetPacketId())); + } + + m_TimelineReporting.store(false); + + // Notify Backends + m_BackendNotifier.NotifyBackendsForTimelineReporting(); + + break; + default: + throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1%") + % static_cast(currentState))); + } +} + +} // namespace profiling + +} // namespace armnn + diff --git a/src/profiling/DeactivateTimelineReportingCommandHandler.hpp b/src/profiling/DeactivateTimelineReportingCommandHandler.hpp new file mode 100644 index 0000000000..e06bae836f --- /dev/null +++ b/src/profiling/DeactivateTimelineReportingCommandHandler.hpp @@ -0,0 +1,45 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "CommandHandlerFunctor.hpp" +#include "Packet.hpp" +#include "ProfilingStateMachine.hpp" +#include "INotifyBackends.hpp" + +namespace armnn +{ + +namespace profiling +{ + +class DeactivateTimelineReportingCommandHandler : public CommandHandlerFunctor +{ + +public: + DeactivateTimelineReportingCommandHandler(uint32_t familyId, + uint32_t packetId, + uint32_t version, + std::atomic& timelineReporting, + ProfilingStateMachine& profilingStateMachine, + INotifyBackends& notifyBackends) + : CommandHandlerFunctor(familyId, packetId, version) + , m_TimelineReporting(timelineReporting) + , m_StateMachine(profilingStateMachine) + , m_BackendNotifier(notifyBackends) + {} + + void operator()(const Packet& packet) override; + +private: + std::atomic& m_TimelineReporting; + ProfilingStateMachine& m_StateMachine; + INotifyBackends& m_BackendNotifier; +}; + +} // namespace profiling + +} // namespace armnn \ No newline at end of file diff --git a/src/profiling/DirectoryCaptureCommandHandler.cpp b/src/profiling/DirectoryCaptureCommandHandler.cpp index 65cac848ae..93cdde736e 100644 --- a/src/profiling/DirectoryCaptureCommandHandler.cpp +++ b/src/profiling/DirectoryCaptureCommandHandler.cpp @@ -281,7 +281,8 @@ std::vector DirectoryCaptureCommandHandler::ReadEve eventRecords[i].m_CounterDescription = GetStringNameFromBuffer(data, eventRecordOffset + descriptionOffset); - eventRecords[i].m_CounterUnits = GetStringNameFromBuffer(data, eventRecordOffset + unitsOffset); + eventRecords[i].m_CounterUnits = unitsOffset == 0 ? Optional() : + GetStringNameFromBuffer(data, eventRecordOffset + unitsOffset); } return eventRecords; diff --git a/src/profiling/DirectoryCaptureCommandHandler.hpp b/src/profiling/DirectoryCaptureCommandHandler.hpp index 03bbb1eb09..6b25714168 100644 --- a/src/profiling/DirectoryCaptureCommandHandler.hpp +++ b/src/profiling/DirectoryCaptureCommandHandler.hpp @@ -25,7 +25,7 @@ struct CounterDirectoryEventRecord std::string m_CounterName; uint16_t m_CounterSetUid; uint16_t m_CounterUid; - std::string m_CounterUnits; + Optional m_CounterUnits; uint16_t m_DeviceUid; uint16_t m_MaxCounterUid; }; diff --git a/src/profiling/INotifyBackends.hpp b/src/profiling/INotifyBackends.hpp new file mode 100644 index 0000000000..217ebdef1c --- /dev/null +++ b/src/profiling/INotifyBackends.hpp @@ -0,0 +1,24 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +namespace armnn +{ + +namespace profiling +{ + +class INotifyBackends +{ +public: + virtual ~INotifyBackends() {} + virtual void NotifyBackendsForTimelineReporting() = 0; +}; + +} // namespace profiling + +} // namespace armnn + diff --git a/src/profiling/IReportStructure.hpp b/src/profiling/IReportStructure.hpp new file mode 100644 index 0000000000..1ae049733f --- /dev/null +++ b/src/profiling/IReportStructure.hpp @@ -0,0 +1,24 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +namespace armnn +{ + +namespace profiling +{ + +class IReportStructure +{ +public: + virtual ~IReportStructure() {} + virtual void ReportStructure() = 0; +}; + +} // namespace profiling + +} // namespace armnn + diff --git a/src/profiling/PacketVersionResolver.cpp b/src/profiling/PacketVersionResolver.cpp index 2c75067487..4178ace166 100644 --- a/src/profiling/PacketVersionResolver.cpp +++ b/src/profiling/PacketVersionResolver.cpp @@ -54,8 +54,17 @@ bool PacketKey::operator!=(const PacketKey& rhs) const Version PacketVersionResolver::ResolvePacketVersion(uint32_t familyId, uint32_t packetId) const { - IgnoreUnused(familyId, packetId); - // NOTE: For now every packet specification is at version 1.0.0 + const PacketKey packetKey(familyId, packetId); + + if( packetKey == ActivateTimeLinePacket ) + { + return Version(1, 1, 0); + } + if( packetKey == DectivateTimeLinePacket ) + { + return Version(1, 1, 0); + } + return Version(1, 0, 0); } diff --git a/src/profiling/PacketVersionResolver.hpp b/src/profiling/PacketVersionResolver.hpp index e959ed548e..6222eb02e8 100644 --- a/src/profiling/PacketVersionResolver.hpp +++ b/src/profiling/PacketVersionResolver.hpp @@ -33,6 +33,9 @@ private: uint32_t m_PacketId; }; +static const PacketKey ActivateTimeLinePacket(0 , 6); +static const PacketKey DectivateTimeLinePacket(0 , 7); + class PacketVersionResolver final { public: diff --git a/src/profiling/ProfilingService.cpp b/src/profiling/ProfilingService.cpp index e42ef134dc..3a8f3f83a3 100644 --- a/src/profiling/ProfilingService.cpp +++ b/src/profiling/ProfilingService.cpp @@ -34,6 +34,7 @@ void ProfilingService::ResetExternalProfilingOptions(const ExternalProfilingOpti { // Update the profiling options m_Options = options; + m_TimelineReporting = options.m_TimelineEnabled; // Check if the profiling service needs to be reset if (resetProfilingService) @@ -431,7 +432,7 @@ void ProfilingService::Reset() // ...finally reset the profiling state machine m_StateMachine.Reset(); m_BackendProfilingContexts.clear(); - m_MaxGlobalCounterId = armnn::profiling::INFERENCES_RUN; + m_MaxGlobalCounterId = armnn::profiling::MAX_ARMNN_COUNTER; } void ProfilingService::Stop() @@ -463,11 +464,22 @@ inline void ProfilingService::CheckCounterUid(uint16_t counterUid) const } } +void ProfilingService::NotifyBackendsForTimelineReporting() +{ + BackendProfilingContext::iterator it = m_BackendProfilingContexts.begin(); + while (it != m_BackendProfilingContexts.end()) + { + auto& backendProfilingContext = it->second; + backendProfilingContext->EnableTimelineReporting(m_TimelineReporting); + // Increment the Iterator to point to next entry + it++; + } +} + ProfilingService::~ProfilingService() { Stop(); } - } // namespace profiling } // namespace armnn diff --git a/src/profiling/ProfilingService.hpp b/src/profiling/ProfilingService.hpp index f6e409daeb..df7bd8f857 100644 --- a/src/profiling/ProfilingService.hpp +++ b/src/profiling/ProfilingService.hpp @@ -5,14 +5,17 @@ #pragma once +#include "ActivateTimelineReportingCommandHandler.hpp" #include "BufferManager.hpp" #include "CommandHandler.hpp" #include "ConnectionAcknowledgedCommandHandler.hpp" #include "CounterDirectory.hpp" #include "CounterIdMap.hpp" +#include "DeactivateTimelineReportingCommandHandler.hpp" #include "ICounterRegistry.hpp" #include "ICounterValues.hpp" #include "IProfilingService.hpp" +#include "IReportStructure.hpp" #include "PeriodicCounterCapture.hpp" #include "PeriodicCounterSelectionCommandHandler.hpp" #include "PerJobCounterSelectionCommandHandler.hpp" @@ -24,6 +27,7 @@ #include "SendThread.hpp" #include "SendTimelinePacket.hpp" #include "TimelinePacketWriterFactory.hpp" +#include "INotifyBackends.hpp" #include namespace armnn @@ -32,14 +36,14 @@ namespace armnn namespace profiling { // Static constants describing ArmNN's counter UID's -static const uint16_t NETWORK_LOADS = 0; -static const uint16_t NETWORK_UNLOADS = 1; -static const uint16_t REGISTERED_BACKENDS = 2; -static const uint16_t UNREGISTERED_BACKENDS = 3; -static const uint16_t INFERENCES_RUN = 4; -static const uint16_t MAX_ARMNN_COUNTER = INFERENCES_RUN; - -class ProfilingService : public IReadWriteCounterValues, public IProfilingService +static const uint16_t NETWORK_LOADS = 0; +static const uint16_t NETWORK_UNLOADS = 1; +static const uint16_t REGISTERED_BACKENDS = 2; +static const uint16_t UNREGISTERED_BACKENDS = 3; +static const uint16_t INFERENCES_RUN = 4; +static const uint16_t MAX_ARMNN_COUNTER = INFERENCES_RUN; + +class ProfilingService : public IReadWriteCounterValues, public IProfilingService, public INotifyBackends { public: using ExternalProfilingOptions = IRuntime::CreationOptions::ExternalProfilingOptions; @@ -47,10 +51,12 @@ public: using IProfilingConnectionPtr = std::unique_ptr; using CounterIndices = std::vector*>; using CounterValues = std::list>; + using BackendProfilingContext = std::unordered_map>; - // Default constructor/destructor kept protected for testing - ProfilingService() + ProfilingService(Optional reportStructure = EmptyOptional()) : m_Options() + , m_TimelineReporting(false) , m_CounterDirectory() , m_ProfilingConnectionFactory(new ProfilingConnectionFactory()) , m_ProfilingConnection() @@ -97,6 +103,22 @@ public: 5, m_PacketVersionResolver.ResolvePacketVersion(0, 5).GetEncodedValue(), m_StateMachine) + , m_ActivateTimelineReportingCommandHandler(0, + 6, + m_PacketVersionResolver.ResolvePacketVersion(0, 6) + .GetEncodedValue(), + m_SendTimelinePacket, + m_StateMachine, + reportStructure, + m_TimelineReporting, + *this) + , m_DeactivateTimelineReportingCommandHandler(0, + 7, + m_PacketVersionResolver.ResolvePacketVersion(0, 7) + .GetEncodedValue(), + m_TimelineReporting, + m_StateMachine, + *this) , m_TimelinePacketWriterFactory(m_BufferManager) , m_MaxGlobalCounterId(armnn::profiling::INFERENCES_RUN) { @@ -111,6 +133,10 @@ public: // Register the "Per-Job Counter Selection" command handler m_CommandHandlerRegistry.RegisterFunctor(&m_PerJobCounterSelectionCommandHandler); + + m_CommandHandlerRegistry.RegisterFunctor(&m_ActivateTimelineReportingCommandHandler); + + m_CommandHandlerRegistry.RegisterFunctor(&m_DeactivateTimelineReportingCommandHandler); } ~ProfilingService(); @@ -131,6 +157,9 @@ public: void AddBackendProfilingContext(const BackendId backendId, std::shared_ptr profilingContext); + // Enable the recording of timeline events and entities + void NotifyBackendsForTimelineReporting() override; + const ICounterDirectory& GetCounterDirectory() const; ICounterRegistry& GetCounterRegistry(); ProfilingState GetCurrentState() const; @@ -168,13 +197,15 @@ public: return m_SendCounterPacket; } - /// Check if the profiling is enabled - bool IsEnabled() { return m_Options.m_EnableProfiling; } - static ProfilingDynamicGuid GetNextGuid(); static ProfilingStaticGuid GetStaticId(const std::string& str); + bool IsTimelineReportingEnabled() + { + return m_TimelineReporting; + } + private: //Copy/move constructors/destructors and copy/move assignment operators are deleted ProfilingService(const ProfilingService&) = delete; @@ -192,32 +223,37 @@ private: void CheckCounterUid(uint16_t counterUid) const; // Profiling service components - ExternalProfilingOptions m_Options; - CounterDirectory m_CounterDirectory; - CounterIdMap m_CounterIdMap; + ExternalProfilingOptions m_Options; + std::atomic m_TimelineReporting; + CounterDirectory m_CounterDirectory; + CounterIdMap m_CounterIdMap; IProfilingConnectionFactoryPtr m_ProfilingConnectionFactory; - IProfilingConnectionPtr m_ProfilingConnection; - ProfilingStateMachine m_StateMachine; - CounterIndices m_CounterIndex; - CounterValues m_CounterValues; - CommandHandlerRegistry m_CommandHandlerRegistry; - PacketVersionResolver m_PacketVersionResolver; - CommandHandler m_CommandHandler; - BufferManager m_BufferManager; - SendCounterPacket m_SendCounterPacket; - SendThread m_SendThread; - SendTimelinePacket m_SendTimelinePacket; + IProfilingConnectionPtr m_ProfilingConnection; + ProfilingStateMachine m_StateMachine; + CounterIndices m_CounterIndex; + CounterValues m_CounterValues; + CommandHandlerRegistry m_CommandHandlerRegistry; + PacketVersionResolver m_PacketVersionResolver; + CommandHandler m_CommandHandler; + BufferManager m_BufferManager; + SendCounterPacket m_SendCounterPacket; + SendThread m_SendThread; + SendTimelinePacket m_SendTimelinePacket; + Holder m_Holder; + PeriodicCounterCapture m_PeriodicCounterCapture; - ConnectionAcknowledgedCommandHandler m_ConnectionAcknowledgedCommandHandler; - RequestCounterDirectoryCommandHandler m_RequestCounterDirectoryCommandHandler; - PeriodicCounterSelectionCommandHandler m_PeriodicCounterSelectionCommandHandler; - PerJobCounterSelectionCommandHandler m_PerJobCounterSelectionCommandHandler; + + ConnectionAcknowledgedCommandHandler m_ConnectionAcknowledgedCommandHandler; + RequestCounterDirectoryCommandHandler m_RequestCounterDirectoryCommandHandler; + PeriodicCounterSelectionCommandHandler m_PeriodicCounterSelectionCommandHandler; + PerJobCounterSelectionCommandHandler m_PerJobCounterSelectionCommandHandler; + ActivateTimelineReportingCommandHandler m_ActivateTimelineReportingCommandHandler; + DeactivateTimelineReportingCommandHandler m_DeactivateTimelineReportingCommandHandler; TimelinePacketWriterFactory m_TimelinePacketWriterFactory; - std::unordered_map> m_BackendProfilingContexts; - uint16_t m_MaxGlobalCounterId; + BackendProfilingContext m_BackendProfilingContexts; + uint16_t m_MaxGlobalCounterId; static ProfilingGuidGenerator m_GuidGenerator; diff --git a/src/profiling/ProfilingUtils.cpp b/src/profiling/ProfilingUtils.cpp index 002eeb9616..e419769600 100644 --- a/src/profiling/ProfilingUtils.cpp +++ b/src/profiling/ProfilingUtils.cpp @@ -96,17 +96,8 @@ void WriteBytes(const IPacketBufferPtr& packetBuffer, unsigned int offset, cons uint32_t ConstructHeader(uint32_t packetFamily, uint32_t packetId) { - return ((packetFamily & 0x3F) << 26)| - ((packetId & 0x3FF) << 16); -} - -uint32_t ConstructHeader(uint32_t packetFamily, - uint32_t packetClass, - uint32_t packetType) -{ - return ((packetFamily & 0x3F) << 26)| - ((packetClass & 0x3FF) << 19)| - ((packetType & 0x3FFF) << 16); + return (( packetFamily & 0x0000003F ) << 26 )| + (( packetId & 0x000003FF ) << 16 ); } void WriteUint64(const std::unique_ptr& packetBuffer, unsigned int offset, uint64_t value) diff --git a/src/profiling/ProfilingUtils.hpp b/src/profiling/ProfilingUtils.hpp index 37ab88cb6f..5888ef0b8c 100644 --- a/src/profiling/ProfilingUtils.hpp +++ b/src/profiling/ProfilingUtils.hpp @@ -146,8 +146,6 @@ void WriteBytes(const IPacketBuffer& packetBuffer, unsigned int offset, const vo uint32_t ConstructHeader(uint32_t packetFamily, uint32_t packetId); -uint32_t ConstructHeader(uint32_t packetFamily, uint32_t packetClass, uint32_t packetType); - void WriteUint64(const IPacketBufferPtr& packetBuffer, unsigned int offset, uint64_t value); void WriteUint32(const IPacketBufferPtr& packetBuffer, unsigned int offset, uint32_t value); diff --git a/src/profiling/SendCounterPacket.cpp b/src/profiling/SendCounterPacket.cpp index 942ccc7b59..ae4bab91e7 100644 --- a/src/profiling/SendCounterPacket.cpp +++ b/src/profiling/SendCounterPacket.cpp @@ -181,10 +181,34 @@ bool SendCounterPacket::CreateCategoryRecord(const CategoryPtr& category, BOOST_ASSERT(category); const std::string& categoryName = category->m_Name; - const std::vector categoryCounters = category->m_Counters; - BOOST_ASSERT(!categoryName.empty()); + // Remove any duplicate counters + std::vector categoryCounters; + for (size_t counterIndex = 0; counterIndex < category->m_Counters.size(); ++counterIndex) + { + uint16_t counterUid = category->m_Counters.at(counterIndex); + auto it = counters.find(counterUid); + if (it == counters.end()) + { + errorMessage = boost::str(boost::format("Counter (%1%) not found in category (%2%)") + % counterUid % category->m_Name ); + return false; + } + + const CounterPtr& counter = it->second; + + if (counterUid == counter->m_MaxCounterUid) + { + categoryCounters.emplace_back(counterUid); + } + } + if (categoryCounters.empty()) + { + errorMessage = boost::str(boost::format("No valid counters found in category (%1%)")% categoryName); + return false; + } + // Utils size_t uint32_t_size = sizeof(uint32_t); @@ -203,7 +227,7 @@ bool SendCounterPacket::CreateCategoryRecord(const CategoryPtr& category, std::vector categoryNameBuffer; if (!StringToSwTraceString(categoryName, categoryNameBuffer)) { - errorMessage = boost::str(boost::format("Cannot convert the name of category \"%1%\" to an SWTrace namestring") + errorMessage = boost::str(boost::format("Cannot convert the name of category (%1%) to an SWTrace namestring") % categoryName); return false; } @@ -221,7 +245,6 @@ bool SendCounterPacket::CreateCategoryRecord(const CategoryPtr& category, { uint16_t counterUid = categoryCounters.at(counterIndex); auto it = counters.find(counterUid); - BOOST_ASSERT(it != counters.end()); const CounterPtr& counter = it->second; EventRecord& eventRecord = eventRecords.at(eventRecordIndex); @@ -299,7 +322,7 @@ bool SendCounterPacket::CreateDeviceRecord(const DevicePtr& device, std::vector deviceNameBuffer; if (!StringToSwTraceString(deviceName, deviceNameBuffer)) { - errorMessage = boost::str(boost::format("Cannot convert the name of device %1% (\"%2%\") to an SWTrace string") + errorMessage = boost::str(boost::format("Cannot convert the name of device %1% (%2%) to an SWTrace string") % deviceUid % deviceName); return false; @@ -349,7 +372,7 @@ bool SendCounterPacket::CreateCounterSetRecord(const CounterSetPtr& counterSet, std::vector counterSetNameBuffer; if (!StringToSwTraceString(counterSet->m_Name, counterSetNameBuffer)) { - errorMessage = boost::str(boost::format("Cannot convert the name of counter set %1% (\"%2%\") to " + errorMessage = boost::str(boost::format("Cannot convert the name of counter set %1% (%2%) to " "an SWTrace namestring") % counterSetUid % counterSetName); @@ -441,7 +464,7 @@ bool SendCounterPacket::CreateEventRecord(const CounterPtr& counter, std::vector counterNameBuffer; if (!StringToSwTraceString(counterName, counterNameBuffer)) { - errorMessage = boost::str(boost::format("Cannot convert the name of counter %1% (name: \"%2%\") " + errorMessage = boost::str(boost::format("Cannot convert the name of counter %1% (name: %2%) " "to an SWTrace string") % counterUid % counterName); @@ -457,7 +480,7 @@ bool SendCounterPacket::CreateEventRecord(const CounterPtr& counter, std::vector counterDescriptionBuffer; if (!StringToSwTraceString(counterDescription, counterDescriptionBuffer)) { - errorMessage = boost::str(boost::format("Cannot convert the description of counter %1% (description: \"%2%\") " + errorMessage = boost::str(boost::format("Cannot convert the description of counter %1% (description: %2%) " "to an SWTrace string") % counterUid % counterName); @@ -481,7 +504,7 @@ bool SendCounterPacket::CreateEventRecord(const CounterPtr& counter, // Convert the counter units into a SWTrace namestring if (!StringToSwTraceString(counterUnits, counterUnitsBuffer)) { - errorMessage = boost::str(boost::format("Cannot convert the units of counter %1% (units: \"%2%\") " + errorMessage = boost::str(boost::format("Cannot convert the units of counter %1% (units: %2%) " "to an SWTrace string") % counterUid % counterName); diff --git a/src/profiling/TimelineUtilityMethods.cpp b/src/profiling/TimelineUtilityMethods.cpp index 8a39ea7c4a..de30b4d4ef 100644 --- a/src/profiling/TimelineUtilityMethods.cpp +++ b/src/profiling/TimelineUtilityMethods.cpp @@ -14,7 +14,7 @@ namespace profiling std::unique_ptr TimelineUtilityMethods::GetTimelineUtils(ProfilingService& profilingService) { - if (profilingService.IsProfilingEnabled()) + if (profilingService.GetCurrentState() == ProfilingState::Active && profilingService.IsTimelineReportingEnabled()) { std::unique_ptr sendTimelinepacket = profilingService.GetSendTimelinePacket(); return std::make_unique(sendTimelinepacket); diff --git a/src/profiling/test/ProfilingMocks.hpp b/src/profiling/test/ProfilingMocks.hpp index eeb641e878..ada55d8dff 100644 --- a/src/profiling/test/ProfilingMocks.hpp +++ b/src/profiling/test/ProfilingMocks.hpp @@ -51,6 +51,8 @@ public: PerJobCounterSelection, TimelineMessageDirectory, PeriodicCounterCapture, + ActivateTimelineReporting, + DeactivateTimelineReporting, Unknown }; @@ -85,7 +87,7 @@ public: switch (packetFamily) { case 0: - packetType = packetId < 6 ? PacketType(packetId) : PacketType::Unknown; + packetType = packetId < 8 ? PacketType(packetId) : PacketType::Unknown; break; case 1: packetType = packetId == 0 ? PacketType::TimelineMessageDirectory : PacketType::Unknown; @@ -628,7 +630,8 @@ public: const CaptureData& captureData) : m_SendCounterPacket(mockBufferManager), m_IsProfilingEnabled(isProfilingEnabled), - m_CaptureData(captureData) {} + m_CaptureData(captureData) + {} /// Return the next random Guid in the sequence ProfilingDynamicGuid NextGuid() override @@ -682,10 +685,10 @@ public: private: ProfilingGuidGenerator m_GuidGenerator; - CounterIdMap m_CounterMapping; - SendCounterPacket m_SendCounterPacket; - bool m_IsProfilingEnabled; - CaptureData m_CaptureData; + CounterIdMap m_CounterMapping; + SendCounterPacket m_SendCounterPacket; + bool m_IsProfilingEnabled; + CaptureData m_CaptureData; }; } // namespace profiling diff --git a/src/profiling/test/ProfilingTestUtils.cpp b/src/profiling/test/ProfilingTestUtils.cpp index 244051c785..8de69f14ec 100644 --- a/src/profiling/test/ProfilingTestUtils.cpp +++ b/src/profiling/test/ProfilingTestUtils.cpp @@ -297,7 +297,14 @@ void VerifyPostOptimisationStructureTestImpl(armnn::BackendId backendId) // Create runtime in which test will run armnn::IRuntime::CreationOptions options; options.m_ProfilingOptions.m_EnableProfiling = true; + options.m_ProfilingOptions.m_TimelineEnabled = true; armnn::Runtime runtime(options); + GetProfilingService(&runtime).ResetExternalProfilingOptions(options.m_ProfilingOptions, false); + + profiling::ProfilingServiceRuntimeHelper profilingServiceHelper(GetProfilingService(&runtime)); + profilingServiceHelper.ForceTransitionToState(ProfilingState::NotConnected); + profilingServiceHelper.ForceTransitionToState(ProfilingState::WaitingForAck); + profilingServiceHelper.ForceTransitionToState(ProfilingState::Active); // build up the structure of the network INetworkPtr net(INetwork::Create()); @@ -363,7 +370,6 @@ void VerifyPostOptimisationStructureTestImpl(armnn::BackendId backendId) armnn::NetworkId netId; BOOST_TEST(runtime.LoadNetwork(netId, std::move(optNet)) == Status::Success); - profiling::ProfilingServiceRuntimeHelper profilingServiceHelper(GetProfilingService(&runtime)); profiling::BufferManager& bufferManager = profilingServiceHelper.GetProfilingBufferManager(); auto readableBuffer = bufferManager.GetReadableBuffer(); diff --git a/src/profiling/test/ProfilingTestUtils.hpp b/src/profiling/test/ProfilingTestUtils.hpp index 459d62435b..816ffd3dc6 100644 --- a/src/profiling/test/ProfilingTestUtils.hpp +++ b/src/profiling/test/ProfilingTestUtils.hpp @@ -6,6 +6,7 @@ #pragma once #include "ProfilingUtils.hpp" +#include "Runtime.hpp" #include #include @@ -68,6 +69,11 @@ public: return GetBufferManager(m_ProfilingService); } armnn::profiling::ProfilingService& m_ProfilingService; + + void ForceTransitionToState(ProfilingState newState) + { + TransitionToState(m_ProfilingService, newState); + } }; } // namespace profiling diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index 74b93d7c4f..f252579022 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -4,6 +4,7 @@ // #include "ProfilingTests.hpp" +#include "ProfilingTestUtils.hpp" #include #include @@ -1823,6 +1824,132 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData) BOOST_TEST(period == armnn::LOWEST_CAPTURE_PERIOD); // capture period } +BOOST_AUTO_TEST_CASE(CheckTimelineActivationAndDeactivation) +{ + class TestReportStructure : public IReportStructure + { + public: + virtual void ReportStructure() override + { + m_ReportStructureCalled = true; + } + + bool m_ReportStructureCalled = false; + }; + + class TestNotifyBackends : public INotifyBackends + { + public: + TestNotifyBackends() : m_timelineReporting(false) {} + virtual void NotifyBackendsForTimelineReporting() override + { + m_TestNotifyBackendsCalled = m_timelineReporting.load(); + } + + bool m_TestNotifyBackendsCalled = false; + std::atomic m_timelineReporting; + }; + + PacketVersionResolver packetVersionResolver; + + BufferManager bufferManager(512); + SendTimelinePacket sendTimelinePacket(bufferManager); + ProfilingStateMachine stateMachine; + TestReportStructure testReportStructure; + TestNotifyBackends testNotifyBackends; + + profiling::ActivateTimelineReportingCommandHandler activateTimelineReportingCommandHandler(0, + 6, + packetVersionResolver.ResolvePacketVersion(0, 6) + .GetEncodedValue(), + sendTimelinePacket, + stateMachine, + testReportStructure, + testNotifyBackends.m_timelineReporting, + testNotifyBackends); + + // Write an "ActivateTimelineReporting" packet into the mock profiling connection, to simulate an input from an + // external profiling service + const uint32_t packetFamily1 = 0; + const uint32_t packetId1 = 6; + uint32_t packetHeader1 = ConstructHeader(packetFamily1, packetId1); + + // Create the ActivateTimelineReportingPacket + Packet ActivateTimelineReportingPacket(packetHeader1); // Length == 0 + + BOOST_CHECK_THROW( + activateTimelineReportingCommandHandler.operator()(ActivateTimelineReportingPacket), armnn::Exception); + + stateMachine.TransitionToState(ProfilingState::NotConnected); + BOOST_CHECK_THROW( + activateTimelineReportingCommandHandler.operator()(ActivateTimelineReportingPacket), armnn::Exception); + + stateMachine.TransitionToState(ProfilingState::WaitingForAck); + BOOST_CHECK_THROW( + activateTimelineReportingCommandHandler.operator()(ActivateTimelineReportingPacket), armnn::Exception); + + stateMachine.TransitionToState(ProfilingState::Active); + activateTimelineReportingCommandHandler.operator()(ActivateTimelineReportingPacket); + + BOOST_CHECK(testReportStructure.m_ReportStructureCalled); + BOOST_CHECK(testNotifyBackends.m_TestNotifyBackendsCalled); + BOOST_CHECK(testNotifyBackends.m_timelineReporting.load()); + + DeactivateTimelineReportingCommandHandler deactivateTimelineReportingCommandHandler(0, + 7, + packetVersionResolver.ResolvePacketVersion(0, 7).GetEncodedValue(), + testNotifyBackends.m_timelineReporting, + stateMachine, + testNotifyBackends); + + const uint32_t packetFamily2 = 0; + const uint32_t packetId2 = 7; + uint32_t packetHeader2 = ConstructHeader(packetFamily2, packetId2); + + // Create the DeactivateTimelineReportingPacket + Packet deactivateTimelineReportingPacket(packetHeader2); // Length == 0 + + stateMachine.Reset(); + BOOST_CHECK_THROW( + deactivateTimelineReportingCommandHandler.operator()(deactivateTimelineReportingPacket), armnn::Exception); + + stateMachine.TransitionToState(ProfilingState::NotConnected); + BOOST_CHECK_THROW( + deactivateTimelineReportingCommandHandler.operator()(deactivateTimelineReportingPacket), armnn::Exception); + + stateMachine.TransitionToState(ProfilingState::WaitingForAck); + BOOST_CHECK_THROW( + deactivateTimelineReportingCommandHandler.operator()(deactivateTimelineReportingPacket), armnn::Exception); + + stateMachine.TransitionToState(ProfilingState::Active); + deactivateTimelineReportingCommandHandler.operator()(deactivateTimelineReportingPacket); + + BOOST_CHECK(!testNotifyBackends.m_TestNotifyBackendsCalled); + BOOST_CHECK(!testNotifyBackends.m_timelineReporting.load()); +} + +BOOST_AUTO_TEST_CASE(CheckProfilingServiceNotActive) +{ + using namespace armnn; + using namespace armnn::profiling; + + // Create runtime in which the test will run + armnn::IRuntime::CreationOptions options; + options.m_ProfilingOptions.m_EnableProfiling = true; + + armnn::Runtime runtime(options); + profiling::ProfilingServiceRuntimeHelper profilingServiceHelper(GetProfilingService(&runtime)); + profilingServiceHelper.ForceTransitionToState(ProfilingState::NotConnected); + profilingServiceHelper.ForceTransitionToState(ProfilingState::WaitingForAck); + profilingServiceHelper.ForceTransitionToState(ProfilingState::Active); + + profiling::BufferManager& bufferManager = profilingServiceHelper.GetProfilingBufferManager(); + auto readableBuffer = bufferManager.GetReadableBuffer(); + + // Profiling is enabled, the post-optimisation structure should be created + BOOST_CHECK(readableBuffer == nullptr); +} + BOOST_AUTO_TEST_CASE(CheckConnectionAcknowledged) { using boost::numeric_cast; @@ -3395,8 +3522,7 @@ BOOST_AUTO_TEST_CASE(CheckRegisterCounters) MockBufferManager mockBuffer(1024); CaptureData captureData; - MockProfilingService mockProfilingService( - mockBuffer, options.m_ProfilingOptions.m_EnableProfiling, captureData); + MockProfilingService mockProfilingService(mockBuffer, options.m_ProfilingOptions.m_EnableProfiling, captureData); armnn::BackendId cpuRefId(armnn::Compute::CpuRef); mockProfilingService.RegisterMapping(6, 0, cpuRefId); diff --git a/src/profiling/test/ProfilingTests.hpp b/src/profiling/test/ProfilingTests.hpp index a5971e0a4b..d1052cea97 100644 --- a/src/profiling/test/ProfilingTests.hpp +++ b/src/profiling/test/ProfilingTests.hpp @@ -242,7 +242,7 @@ public: uint32_t length = 0, uint32_t timeout = 1000) { - long packetCount = mockProfilingConnection->CheckForPacket({packetType, length}); + long packetCount = mockProfilingConnection->CheckForPacket({ packetType, length }); // The first packet we receive may not be the one we are looking for, so keep looping until till we find it, // or until WaitForPacketsSent times out while(packetCount == 0 && timeout != 0) diff --git a/src/profiling/test/SendCounterPacketTests.cpp b/src/profiling/test/SendCounterPacketTests.cpp index d7dc7e2d9e..51f049ddc6 100644 --- a/src/profiling/test/SendCounterPacketTests.cpp +++ b/src/profiling/test/SendCounterPacketTests.cpp @@ -1172,7 +1172,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) BOOST_CHECK(counterDirectory.GetCategoryCount() == 2); BOOST_CHECK(category2); - uint16_t numberOfCores = 3; + uint16_t numberOfCores = 4; // Register a counter associated to "category1" const Counter* counter1 = nullptr; @@ -1186,7 +1186,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) "counter1description", std::string("counter1units"), numberOfCores)); - BOOST_CHECK(counterDirectory.GetCounterCount() == 3); + BOOST_CHECK(counterDirectory.GetCounterCount() == 4); BOOST_CHECK(counter1); // Register a counter associated to "category1" @@ -1203,7 +1203,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) armnn::EmptyOptional(), device2->m_Uid, 0)); - BOOST_CHECK(counterDirectory.GetCounterCount() == 4); + BOOST_CHECK(counterDirectory.GetCounterCount() == 5); BOOST_CHECK(counter2); // Register a counter associated to "category2" @@ -1217,7 +1217,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) "counter3", "counter3description", armnn::EmptyOptional(), - 5, + numberOfCores, device2->m_Uid, counterSet1->m_Uid)); BOOST_CHECK(counterDirectory.GetCounterCount() == 9); @@ -1236,7 +1236,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) uint32_t packetHeaderWord1 = ReadUint32(readBuffer, 4); BOOST_TEST(((packetHeaderWord0 >> 26) & 0x3F) == 0); // packet_family BOOST_TEST(((packetHeaderWord0 >> 16) & 0x3FF) == 2); // packet_id - BOOST_TEST(packetHeaderWord1 == 928); // data_length + BOOST_TEST(packetHeaderWord1 == 432); // data_length // Check the body header uint32_t bodyHeaderWord0 = ReadUint32(readBuffer, 8); @@ -1269,7 +1269,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) uint32_t categoryRecordOffset0 = ReadUint32(readBuffer, 44); uint32_t categoryRecordOffset1 = ReadUint32(readBuffer, 48); BOOST_TEST(categoryRecordOffset0 == 64); // Category record offset for "category1" - BOOST_TEST(categoryRecordOffset1 == 472); // Category record offset for "category2" + BOOST_TEST(categoryRecordOffset1 == 168); // Category record offset for "category2" // Get the device record pool offset uint32_t uint32_t_size = sizeof(uint32_t); @@ -1584,7 +1584,8 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) const Category* category = counterDirectory.GetCategory(categoryRecord.name); BOOST_CHECK(category); BOOST_CHECK(category->m_Name == categoryRecord.name); - BOOST_CHECK(category->m_Counters.size() == categoryRecord.event_count); + BOOST_CHECK(category->m_Counters.size() == categoryRecord.event_count + static_cast(numberOfCores) -1); + BOOST_CHECK(category->m_Counters.size() == categoryRecord.event_count + static_cast(numberOfCores) -1); // Check that the event records are correct for (const EventRecord& eventRecord : categoryRecord.event_records) diff --git a/src/profiling/test/SendTimelinePacketTests.cpp b/src/profiling/test/SendTimelinePacketTests.cpp index 98b161f65e..4a13ebf824 100644 --- a/src/profiling/test/SendTimelinePacketTests.cpp +++ b/src/profiling/test/SendTimelinePacketTests.cpp @@ -15,6 +15,7 @@ #include #include +#include using namespace armnn::profiling; @@ -395,10 +396,12 @@ BOOST_AUTO_TEST_CASE(SendTimelinePacketTests3) BOOST_AUTO_TEST_CASE(GetGuidsFromProfilingService) { - armnn::IRuntime::CreationOptions::ExternalProfilingOptions options; - options.m_EnableProfiling = true; - armnn::profiling::ProfilingService profilingService; - profilingService.ResetExternalProfilingOptions(options, true); + armnn::IRuntime::CreationOptions options; + options.m_ProfilingOptions.m_EnableProfiling = true; + armnn::Runtime runtime(options); + armnn::profiling::ProfilingService profilingService(runtime); + + profilingService.ResetExternalProfilingOptions(options.m_ProfilingOptions, true); ProfilingStaticGuid staticGuid = profilingService.GetStaticId("dummy"); std::hash hasher; uint64_t hash = static_cast(hasher("dummy")); diff --git a/tests/profiling/gatordmock/GatordMockMain.cpp b/tests/profiling/gatordmock/GatordMockMain.cpp index e19461f6cb..029c58f5e8 100644 --- a/tests/profiling/gatordmock/GatordMockMain.cpp +++ b/tests/profiling/gatordmock/GatordMockMain.cpp @@ -3,16 +3,10 @@ // SPDX-License-Identifier: MIT // -#include "PacketVersionResolver.hpp" #include "CommandFileParser.hpp" #include "CommandLineProcessor.hpp" -#include "DirectoryCaptureCommandHandler.hpp" #include "GatordMockService.hpp" -#include "PeriodicCounterCaptureCommandHandler.hpp" -#include "PeriodicCounterSelectionResponseHandler.hpp" #include -#include -#include #include #include @@ -32,38 +26,7 @@ void exit_capture(int signum) bool CreateMockService(armnnUtils::Sockets::Socket clientConnection, std::string commandFile, bool isEchoEnabled) { - profiling::PacketVersionResolver packetVersionResolver; - // Create the Command Handler Registry - profiling::CommandHandlerRegistry registry; - - timelinedecoder::TimelineDecoder timelineDecoder; - timelineDecoder.SetDefaultCallbacks(); - - // This functor will receive back the selection response packet. - PeriodicCounterSelectionResponseHandler periodicCounterSelectionResponseHandler( - 0, 4, packetVersionResolver.ResolvePacketVersion(0, 4).GetEncodedValue()); - // This functor will receive the counter data. - PeriodicCounterCaptureCommandHandler counterCaptureCommandHandler( - 3, 0, packetVersionResolver.ResolvePacketVersion(3, 0).GetEncodedValue()); - - profiling::DirectoryCaptureCommandHandler directoryCaptureCommandHandler( - 0, 2, packetVersionResolver.ResolvePacketVersion(0, 2).GetEncodedValue(), false); - - timelinedecoder::TimelineCaptureCommandHandler timelineCaptureCommandHandler( - 1, 1, packetVersionResolver.ResolvePacketVersion(1, 1).GetEncodedValue(), timelineDecoder); - - timelinedecoder::TimelineDirectoryCaptureCommandHandler timelineDirectoryCaptureCommandHandler( - 1, 0, packetVersionResolver.ResolvePacketVersion(1, 0).GetEncodedValue(), - timelineCaptureCommandHandler, false); - - // Register different derived functors - registry.RegisterFunctor(&periodicCounterSelectionResponseHandler); - registry.RegisterFunctor(&counterCaptureCommandHandler); - registry.RegisterFunctor(&directoryCaptureCommandHandler); - registry.RegisterFunctor(&timelineDirectoryCaptureCommandHandler); - registry.RegisterFunctor(&timelineCaptureCommandHandler); - - GatordMockService mockService(clientConnection, registry, isEchoEnabled); + GatordMockService mockService(clientConnection, isEchoEnabled); // Send receive the strweam metadata and send connection ack. if (!mockService.WaitForStreamMetaData()) @@ -82,11 +45,6 @@ bool CreateMockService(armnnUtils::Sockets::Socket clientConnection, std::string // Once we've finished processing the file wait for the receiving thread to close. mockService.WaitForReceivingThread(); - if(isEchoEnabled) - { - timelineDecoder.print(); - } - return EXIT_SUCCESS; } diff --git a/tests/profiling/gatordmock/GatordMockService.cpp b/tests/profiling/gatordmock/GatordMockService.cpp index a3f732cb55..3e19c25b6c 100644 --- a/tests/profiling/gatordmock/GatordMockService.cpp +++ b/tests/profiling/gatordmock/GatordMockService.cpp @@ -131,10 +131,30 @@ void GatordMockService::SendRequestCounterDir() { std::cout << "Sending connection acknowledgement." << std::endl; } - // The connection ack packet is an empty data packet with packetId == 1. + // The request counter directory packet is an empty data packet with packetId == 3. SendPacket(0, 3, nullptr, 0); } +void GatordMockService::SendActivateTimelinePacket() +{ + if (m_EchoPackets) + { + std::cout << "Sending activate timeline packet." << std::endl; + } + // The activate timeline packet is an empty data packet with packetId == 6. + SendPacket(0, 6, nullptr, 0); +} + +void GatordMockService::SendDeactivateTimelinePacket() +{ + if (m_EchoPackets) + { + std::cout << "Sending deactivate timeline packet." << std::endl; + } + // The deactivate timeline packet is an empty data packet with packetId == 7. + SendPacket(0, 7, nullptr, 0); +} + bool GatordMockService::LaunchReceivingThread() { if (m_EchoPackets) @@ -165,6 +185,11 @@ void GatordMockService::WaitForReceivingThread() // Wait for the receiving thread to complete operations m_ListeningThread.join(); } + + if(m_EchoPackets) + { + m_TimelineDecoder.print(); + } } void GatordMockService::SendPeriodicCounterSelectionList(uint32_t period, std::vector counters) diff --git a/tests/profiling/gatordmock/GatordMockService.hpp b/tests/profiling/gatordmock/GatordMockService.hpp index c00685fff2..2ff93c9de6 100644 --- a/tests/profiling/gatordmock/GatordMockService.hpp +++ b/tests/profiling/gatordmock/GatordMockService.hpp @@ -13,6 +13,15 @@ #include #include +#include +#include +#include +#include +#include "PeriodicCounterCaptureCommandHandler.hpp" +#include "StreamMetadataCommandHandler.hpp" + +#include "PacketVersionResolver.hpp" + namespace armnn { @@ -39,15 +48,33 @@ class GatordMockService public: /// @param registry reference to a command handler registry. /// @param echoPackets if true the raw packets will be printed to stdout. - GatordMockService(armnnUtils::Sockets::Socket clientConnection, - armnn::profiling::CommandHandlerRegistry& registry, - bool echoPackets) + GatordMockService(armnnUtils::Sockets::Socket clientConnection, bool echoPackets) : m_ClientConnection(clientConnection) - , m_HandlerRegistry(registry) + , m_PacketsReceivedCount(0) , m_EchoPackets(echoPackets) , m_CloseReceivingThread(false) + , m_PacketVersionResolver() + , m_HandlerRegistry() + , m_TimelineDecoder() + , m_StreamMetadataCommandHandler( + 0, 0, m_PacketVersionResolver.ResolvePacketVersion(0, 0).GetEncodedValue(), true) + , m_CounterCaptureCommandHandler( + 0, 4, m_PacketVersionResolver.ResolvePacketVersion(0, 4).GetEncodedValue(), true) + , m_DirectoryCaptureCommandHandler( + 0, 2, m_PacketVersionResolver.ResolvePacketVersion(0, 2).GetEncodedValue(), true) + , m_TimelineCaptureCommandHandler( + 1, 1, m_PacketVersionResolver.ResolvePacketVersion(1, 1).GetEncodedValue(), m_TimelineDecoder) + , m_TimelineDirectoryCaptureCommandHandler( + 1, 0, m_PacketVersionResolver.ResolvePacketVersion(1, 0).GetEncodedValue(), + m_TimelineCaptureCommandHandler, true) { - m_PacketsReceivedCount.store(0, std::memory_order_relaxed); + m_TimelineDecoder.SetDefaultCallbacks(); + + m_HandlerRegistry.RegisterFunctor(&m_StreamMetadataCommandHandler); + m_HandlerRegistry.RegisterFunctor(&m_CounterCaptureCommandHandler); + m_HandlerRegistry.RegisterFunctor(&m_DirectoryCaptureCommandHandler); + m_HandlerRegistry.RegisterFunctor(&m_TimelineDirectoryCaptureCommandHandler); + m_HandlerRegistry.RegisterFunctor(&m_TimelineCaptureCommandHandler); } ~GatordMockService() @@ -74,6 +101,12 @@ public: /// Send a request counter directory packet back to the client. void SendRequestCounterDir(); + /// Send a activate timeline packet back to the client. + void SendActivateTimelinePacket(); + + /// Send a deactivate timeline packet back to the client. + void SendDeactivateTimelinePacket(); + /// Start the thread that will receive all packets and print them nicely to stdout. bool LaunchReceivingThread(); @@ -115,6 +148,22 @@ public: return m_StreamMetaDataPid; } + profiling::DirectoryCaptureCommandHandler& GetDirectoryCaptureCommandHandler() + { + return m_DirectoryCaptureCommandHandler; + } + + timelinedecoder::TimelineDecoder& GetTimelineDecoder() + { + return m_TimelineDecoder; + } + + timelinedecoder::TimelineDirectoryCaptureCommandHandler& GetTimelineDirectoryCaptureCommandHandler() + { + return m_TimelineDirectoryCaptureCommandHandler; + } + + private: void ReceiveLoop(GatordMockService& mockService); @@ -141,18 +190,30 @@ private: static const uint32_t PIPE_MAGIC = 0x45495434; - std::atomic m_PacketsReceivedCount; TargetEndianness m_Endianness; uint32_t m_StreamMetaDataVersion; uint32_t m_StreamMetaDataMaxDataLen; uint32_t m_StreamMetaDataPid; armnnUtils::Sockets::Socket m_ClientConnection; - armnn::profiling::CommandHandlerRegistry& m_HandlerRegistry; + std::atomic m_PacketsReceivedCount; bool m_EchoPackets; std::thread m_ListeningThread; std::atomic m_CloseReceivingThread; + + profiling::PacketVersionResolver m_PacketVersionResolver; + profiling::CommandHandlerRegistry m_HandlerRegistry; + + timelinedecoder::TimelineDecoder m_TimelineDecoder; + + gatordmock::StreamMetadataCommandHandler m_StreamMetadataCommandHandler; + gatordmock::PeriodicCounterCaptureCommandHandler m_CounterCaptureCommandHandler; + + profiling::DirectoryCaptureCommandHandler m_DirectoryCaptureCommandHandler; + + timelinedecoder::TimelineCaptureCommandHandler m_TimelineCaptureCommandHandler; + timelinedecoder::TimelineDirectoryCaptureCommandHandler m_TimelineDirectoryCaptureCommandHandler; }; } // namespace gatordmock diff --git a/tests/profiling/gatordmock/tests/GatordMockTests.cpp b/tests/profiling/gatordmock/tests/GatordMockTests.cpp index 7d938bd404..7417946844 100644 --- a/tests/profiling/gatordmock/tests/GatordMockTests.cpp +++ b/tests/profiling/gatordmock/tests/GatordMockTests.cpp @@ -9,12 +9,14 @@ #include #include #include -#include #include #include #include +#include +#include "../../src/backends/backendsCommon/test/MockBackend.hpp" + #include #include #include @@ -104,6 +106,19 @@ BOOST_AUTO_TEST_CASE(CounterCaptureHandlingTest) } } +void WaitFor(std::function predicate, std::string errorMsg, uint32_t timeout = 2000, uint32_t sleepTime = 50) +{ + uint32_t timeSlept = 0; + while (!predicate()) + { + if (timeSlept >= timeout) + { + BOOST_FAIL("Timeout: " + errorMsg); + } + std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime)); + timeSlept += sleepTime; + } +} void CheckTimelineDirectory(timelinedecoder::TimelineDirectoryCaptureCommandHandler& commandHandler) { @@ -211,43 +226,6 @@ BOOST_AUTO_TEST_CASE(GatorDMockEndToEnd) // The purpose of this test is to setup both sides of the profiling service and get to the point of receiving // performance data. - //These variables are used to wait for the profiling service - uint32_t timeout = 2000; - uint32_t sleepTime = 50; - uint32_t timeSlept = 0; - - profiling::PacketVersionResolver packetVersionResolver; - - // Create the Command Handler Registry - profiling::CommandHandlerRegistry registry; - - timelinedecoder::TimelineDecoder timelineDecoder; - timelineDecoder.SetDefaultCallbacks(); - - // Update with derived functors - gatordmock::StreamMetadataCommandHandler streamMetadataCommandHandler( - 0, 0, packetVersionResolver.ResolvePacketVersion(0, 0).GetEncodedValue(), true); - - gatordmock::PeriodicCounterCaptureCommandHandler counterCaptureCommandHandler( - 0, 4, packetVersionResolver.ResolvePacketVersion(0, 4).GetEncodedValue(), true); - - profiling::DirectoryCaptureCommandHandler directoryCaptureCommandHandler( - 0, 2, packetVersionResolver.ResolvePacketVersion(0, 2).GetEncodedValue(), true); - - timelinedecoder::TimelineCaptureCommandHandler timelineCaptureCommandHandler( - 1, 1, packetVersionResolver.ResolvePacketVersion(1, 1).GetEncodedValue(), timelineDecoder); - - timelinedecoder::TimelineDirectoryCaptureCommandHandler timelineDirectoryCaptureCommandHandler( - 1, 0, packetVersionResolver.ResolvePacketVersion(1, 0).GetEncodedValue(), - timelineCaptureCommandHandler, true); - - // Register different derived functors - registry.RegisterFunctor(&streamMetadataCommandHandler); - registry.RegisterFunctor(&counterCaptureCommandHandler); - registry.RegisterFunctor(&directoryCaptureCommandHandler); - registry.RegisterFunctor(&timelineDirectoryCaptureCommandHandler); - registry.RegisterFunctor(&timelineCaptureCommandHandler); - // Setup the mock service to bind to the UDS. std::string udsNamespace = "gatord_namespace"; @@ -279,18 +257,15 @@ BOOST_AUTO_TEST_CASE(GatorDMockEndToEnd) BOOST_FAIL("Failed to connect client"); } - gatordmock::GatordMockService mockService(clientSocket, registry, false); + gatordmock::GatordMockService mockService(clientSocket, false); + + timelinedecoder::TimelineDecoder& timelineDecoder = mockService.GetTimelineDecoder(); + profiling::DirectoryCaptureCommandHandler& directoryCaptureCommandHandler = + mockService.GetDirectoryCaptureCommandHandler(); // Give the profiling service sending thread time start executing and send the stream metadata. - while (profilingService.GetCurrentState() != profiling::ProfilingState::WaitingForAck) - { - if (timeSlept >= timeout) - { - BOOST_FAIL("Timeout: Profiling service did not switch to WaitingForAck state"); - } - std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime)); - timeSlept += sleepTime; - } + WaitFor([&](){return profilingService.GetCurrentState() == profiling::ProfilingState::WaitingForAck;}, + "Profiling service did not switch to WaitingForAck state"); profilingService.Update(); // Read the stream metadata on the mock side. @@ -300,55 +275,21 @@ BOOST_AUTO_TEST_CASE(GatorDMockEndToEnd) } // Send Ack from GatorD mockService.SendConnectionAck(); + // And start to listen for packets + mockService.LaunchReceivingThread(); - timeSlept = 0; - while (profilingService.GetCurrentState() != profiling::ProfilingState::Active) - { - if (timeSlept >= timeout) - { - BOOST_FAIL("Timeout: Profiling service did not switch to Active state"); - } - std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime)); - timeSlept += sleepTime; - } + WaitFor([&](){return profilingService.GetCurrentState() == profiling::ProfilingState::Active;}, + "Profiling service did not switch to Active state"); - mockService.LaunchReceivingThread(); // As part of the default startup of the profiling service a counter directory packet will be sent. - timeSlept = 0; - while (!directoryCaptureCommandHandler.ParsedCounterDirectory()) - { - if (timeSlept >= timeout) - { - BOOST_FAIL("Timeout: MockGatord did not receive counter directory packet"); - } - std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime)); - timeSlept += sleepTime; - } + WaitFor([&](){return directoryCaptureCommandHandler.ParsedCounterDirectory();}, + "MockGatord did not receive counter directory packet"); - // As part of the default startup of the profiling service a counter directory packet will be sent. - timeSlept = 0; - while (!directoryCaptureCommandHandler.ParsedCounterDirectory()) - { - if (timeSlept >= timeout) - { - BOOST_FAIL("Timeout: MockGatord did not receive counter directory packet"); - } - std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime)); - timeSlept += sleepTime; - } + // Following that we will receive a collection of well known timeline labels and event classes + WaitFor([&](){return timelineDecoder.GetModel().m_EventClasses.size() >= 2;}, + "MockGatord did not receive well known timeline labels and event classes"); - timeSlept = 0; - while (timelineDecoder.GetModel().m_EventClasses.size() < 2) - { - if (timeSlept >= timeout) - { - BOOST_FAIL("Timeout: MockGatord did not receive well known timeline labels"); - } - std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime)); - timeSlept += sleepTime; - } - - CheckTimelineDirectory(timelineDirectoryCaptureCommandHandler); + CheckTimelineDirectory(mockService.GetTimelineDirectoryCaptureCommandHandler()); // Verify the commonly used timeline packets sent when the profiling service enters the active state CheckTimelinePackets(timelineDecoder); @@ -439,4 +380,107 @@ BOOST_AUTO_TEST_CASE(GatorDMockEndToEnd) // PeriodicCounterCapture data received. These are yet to be integrated. } +BOOST_AUTO_TEST_CASE(GatorDMockTimeLineActivation) +{ + armnn::MockBackendInitialiser initialiser; + // Setup the mock service to bind to the UDS. + std::string udsNamespace = "gatord_namespace"; + + armnnUtils::Sockets::Initialize(); + armnnUtils::Sockets::Socket listeningSocket = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0); + + if (!gatordmock::GatordMockService::OpenListeningSocket(listeningSocket, udsNamespace)) + { + BOOST_FAIL("Failed to open Listening Socket"); + } + + armnn::IRuntime::CreationOptions options; + options.m_ProfilingOptions.m_EnableProfiling = true; + armnn::Runtime runtime(options); + + armnnUtils::Sockets::Socket clientConnection; + clientConnection = armnnUtils::Sockets::Accept(listeningSocket, nullptr, nullptr, SOCK_CLOEXEC); + gatordmock::GatordMockService mockService(clientConnection, false); + + // Read the stream metadata on the mock side. + if (!mockService.WaitForStreamMetaData()) + { + BOOST_FAIL("Failed to receive StreamMetaData"); + } + + armnn::MockBackendProfilingService mockProfilingService = armnn::MockBackendProfilingService::Instance(); + armnn::MockBackendProfilingContext *mockBackEndProfilingContext = mockProfilingService.GetContext(); + + // Send Ack from GatorD + mockService.SendConnectionAck(); + // And start to listen for packets + mockService.LaunchReceivingThread(); + + // Build and optimize a simple network while we wait + INetworkPtr net(INetwork::Create()); + + IConnectableLayer* input = net->AddInputLayer(0, "input"); + + NormalizationDescriptor descriptor; + IConnectableLayer* normalize = net->AddNormalizationLayer(descriptor, "normalization"); + + IConnectableLayer* output = net->AddOutputLayer(0, "output"); + + input->GetOutputSlot(0).Connect(normalize->GetInputSlot(0)); + normalize->GetOutputSlot(0).Connect(output->GetInputSlot(0)); + + input->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 4, 4 }, DataType::Float32)); + normalize->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 4, 4 }, DataType::Float32)); + + std::vector backends = { armnn::Compute::CpuRef }; + IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime.GetDeviceSpec()); + + WaitFor([&](){return mockService.GetDirectoryCaptureCommandHandler().ParsedCounterDirectory();}, + "MockGatord did not receive counter directory packet"); + + timelinedecoder::TimelineDecoder& timelineDecoder = mockService.GetTimelineDecoder(); + + WaitFor([&](){return timelineDecoder.GetModel().m_EventClasses.size() >= 2;}, + "MockGatord did not receive well known timeline labels"); + + // Packets we expect from SendWellKnownLabelsAndEventClassesTest + BOOST_CHECK(timelineDecoder.GetModel().m_Entities.size() == 0); + BOOST_CHECK(timelineDecoder.GetModel().m_EventClasses.size() == 2); + BOOST_CHECK(timelineDecoder.GetModel().m_Labels.size() == 10); + BOOST_CHECK(timelineDecoder.GetModel().m_Relationships.size() == 0); + BOOST_CHECK(timelineDecoder.GetModel().m_Events.size() == 0); + + mockService.SendDeactivateTimelinePacket(); + + WaitFor([&](){return !mockBackEndProfilingContext->TimelineReportingEnabled();}, + "Timeline packets were not deactivated"); + + // Load the network into runtime now that timeline reporting is disabled + armnn::NetworkId netId; + runtime.LoadNetwork(netId, std::move(optNet)); + + // Now activate timeline packets + mockService.SendActivateTimelinePacket(); + + WaitFor([&](){return mockBackEndProfilingContext->TimelineReportingEnabled();}, + "Timeline packets were not activated"); + + // Once timeline packets have been reactivated the ActivateTimelineReportingCommandHandler will resend the + // SendWellKnownLabelsAndEventClasses and then send the structure of any loaded networks + WaitFor([&](){return timelineDecoder.GetModel().m_Labels.size() >= 24;}, + "MockGatord did not receive well known timeline labels"); + + // Packets we expect from SendWellKnownLabelsAndEventClassesTest * 2 and the loaded model + BOOST_CHECK(timelineDecoder.GetModel().m_Entities.size() == 5); + BOOST_CHECK(timelineDecoder.GetModel().m_EventClasses.size() == 4); + BOOST_CHECK(timelineDecoder.GetModel().m_Labels.size() == 24); + BOOST_CHECK(timelineDecoder.GetModel().m_Relationships.size() == 28); + BOOST_CHECK(timelineDecoder.GetModel().m_Events.size() == 0); + + mockService.WaitForReceivingThread(); + armnnUtils::Sockets::Close(listeningSocket); + + GetProfilingService(&runtime).Disconnect(); +} + BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1