diff options
Diffstat (limited to 'src')
31 files changed, 676 insertions, 104 deletions
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp index f3d742c515..9d181e535a 100644 --- a/src/armnn/LoadedNetwork.cpp +++ b/src/armnn/LoadedNetwork.cpp @@ -263,6 +263,43 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<OptimizedNetwork> net, } } +void LoadedNetwork::SendNetworkStructure() +{ + Graph& order = m_OptimizedNetwork->GetGraph().TopologicalSort(); + ProfilingGuid networkGuid = m_OptimizedNetwork->GetGuid(); + + std::unique_ptr<TimelineUtilityMethods> timelineUtils = + TimelineUtilityMethods::GetTimelineUtils(m_ProfilingService); + + timelineUtils->CreateTypedEntity(networkGuid, LabelsAndEventClasses::NETWORK_GUID); + + for (auto&& layer : order) + { + // Add layer to the post-optimisation network structure + AddLayerStructure(timelineUtils, *layer, networkGuid); + switch (layer->GetType()) + { + case LayerType::Input: + case LayerType::Output: + { + // Inputs and outputs are treated in a special way - see EnqueueInput() and EnqueueOutput(). + break; + } + default: + { + for (auto& workload : m_WorkloadQueue) + { + // Add workload to the post-optimisation network structure + AddWorkloadStructure(timelineUtils, workload, *layer); + } + break; + } + } + } + // Commit to send the post-optimisation network structure + timelineUtils->Commit(); +} + TensorInfo LoadedNetwork::GetInputTensorInfo(LayerBindingId layerId) const { for (auto&& inputLayer : m_OptimizedNetwork->GetGraph().GetInputLayers()) diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp index 01e3442508..91379d78ed 100644 --- a/src/armnn/LoadedNetwork.hpp +++ b/src/armnn/LoadedNetwork.hpp @@ -56,6 +56,8 @@ public: void RegisterDebugCallback(const DebugCallbackFunction& func); + void SendNetworkStructure(); + private: void AllocateWorkingMemory(); diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp index 26636a81f7..dfcbf852e0 100644 --- a/src/armnn/Runtime.cpp +++ b/src/armnn/Runtime.cpp @@ -152,11 +152,32 @@ const std::shared_ptr<IProfiler> Runtime::GetProfiler(NetworkId networkId) const return nullptr; } +void Runtime::ReportStructure() // armnn::profiling::IProfilingService& profilingService as param +{ + // No-op for the time being, but this may be useful in future to have the profilingService available + // if (profilingService.IsProfilingEnabled()){} + + LoadedNetworks::iterator it = m_LoadedNetworks.begin(); + while (it != m_LoadedNetworks.end()) + { + auto& loadedNetwork = it->second; + loadedNetwork->SendNetworkStructure(); + // Increment the Iterator to point to next entry + it++; + } +} + Runtime::Runtime(const CreationOptions& options) - : m_NetworkIdCounter(0) + : m_NetworkIdCounter(0), + m_ProfilingService(*this) { ARMNN_LOG(info) << "ArmNN v" << ARMNN_VERSION << "\n"; + if ( options.m_ProfilingOptions.m_TimelineEnabled && !options.m_ProfilingOptions.m_EnableProfiling ) + { + throw RuntimeException("It is not possible to enable timeline reporting without profiling being enabled"); + } + // pass configuration info to the profiling service m_ProfilingService.ConfigureProfilingService(options.m_ProfilingOptions); diff --git a/src/armnn/Runtime.hpp b/src/armnn/Runtime.hpp index 477b1169b1..d4b6dcbd63 100644 --- a/src/armnn/Runtime.hpp +++ b/src/armnn/Runtime.hpp @@ -16,13 +16,19 @@ #include <ProfilingService.hpp> +#include <IProfilingService.hpp> +#include <IReportStructure.hpp> + #include <mutex> #include <unordered_map> namespace armnn { +using LoadedNetworks = std::unordered_map<NetworkId, std::unique_ptr<LoadedNetwork>>; +using IReportStructure = profiling::IReportStructure; -class Runtime final : public IRuntime +class Runtime final : public IRuntime, + public IReportStructure { public: /// Loads a complete network into the Runtime. @@ -79,6 +85,10 @@ public: ~Runtime(); + //NOTE: we won't need the profiling service reference but it is good to pass the service + // in this way to facilitate other implementations down the road + virtual void ReportStructure() override; + private: friend void RuntimeLoadedNetworksReserve(armnn::Runtime* runtime); // See RuntimeTests.cpp @@ -104,7 +114,9 @@ private: mutable std::mutex m_Mutex; - std::unordered_map<NetworkId, std::unique_ptr<LoadedNetwork>> m_LoadedNetworks; + /// Map of Loaded Networks with associated GUID as key + LoadedNetworks m_LoadedNetworks; + std::unordered_map<BackendId, IBackendInternal::IBackendContextPtr> m_BackendContexts; int m_NetworkIdCounter; diff --git a/src/armnn/test/RuntimeTests.cpp b/src/armnn/test/RuntimeTests.cpp index c1150231bd..2fc4b50a54 100644 --- a/src/armnn/test/RuntimeTests.cpp +++ b/src/armnn/test/RuntimeTests.cpp @@ -371,7 +371,15 @@ BOOST_AUTO_TEST_CASE(ProfilingEnableCpuRef) // Create runtime in which the test will run armnn::IRuntime::CreationOptions options; options.m_ProfilingOptions.m_EnableProfiling = true; + options.m_ProfilingOptions.m_TimelineEnabled = true; + armnn::Runtime runtime(options); + GetProfilingService(&runtime).ResetExternalProfilingOptions(options.m_ProfilingOptions, false); + + profiling::ProfilingServiceRuntimeHelper profilingServiceHelper(GetProfilingService(&runtime)); + profilingServiceHelper.ForceTransitionToState(ProfilingState::NotConnected); + profilingServiceHelper.ForceTransitionToState(ProfilingState::WaitingForAck); + profilingServiceHelper.ForceTransitionToState(ProfilingState::Active); // build up the structure of the network INetworkPtr net(INetwork::Create()); @@ -399,7 +407,6 @@ BOOST_AUTO_TEST_CASE(ProfilingEnableCpuRef) armnn::NetworkId netId; BOOST_TEST(runtime.LoadNetwork(netId, std::move(optNet)) == Status::Success); - profiling::ProfilingServiceRuntimeHelper profilingServiceHelper(GetProfilingService(&runtime)); profiling::BufferManager& bufferManager = profilingServiceHelper.GetProfilingBufferManager(); auto readableBuffer = bufferManager.GetReadableBuffer(); diff --git a/src/backends/backendsCommon/test/MockBackend.cpp b/src/backends/backendsCommon/test/MockBackend.cpp index 8d40117741..116bf77c63 100644 --- a/src/backends/backendsCommon/test/MockBackend.cpp +++ b/src/backends/backendsCommon/test/MockBackend.cpp @@ -5,7 +5,6 @@ #include "MockBackend.hpp" #include "MockBackendId.hpp" -#include "armnn/backends/profiling/IBackendProfilingContext.hpp" #include <armnn/BackendRegistry.hpp> diff --git a/src/backends/backendsCommon/test/MockBackend.hpp b/src/backends/backendsCommon/test/MockBackend.hpp index 6e415b9b52..e1570ff920 100644 --- a/src/backends/backendsCommon/test/MockBackend.hpp +++ b/src/backends/backendsCommon/test/MockBackend.hpp @@ -32,6 +32,7 @@ public: MockBackendProfilingContext(IBackendInternal::IBackendProfilingPtr& backendProfiling) : m_BackendProfiling(std::move(backendProfiling)) , m_CapturePeriod(0) + , m_IsTimelineEnabled(true) {} ~MockBackendProfilingContext() = default; @@ -93,10 +94,22 @@ public: return true; } + bool EnableTimelineReporting(bool isEnabled) + { + m_IsTimelineEnabled = isEnabled; + return isEnabled; + } + + bool TimelineReportingEnabled() + { + return m_IsTimelineEnabled; + } + private: IBackendInternal::IBackendProfilingPtr m_BackendProfiling; uint32_t m_CapturePeriod; std::vector<uint16_t> m_ActiveCounters; + bool m_IsTimelineEnabled; }; class MockBackendProfilingService diff --git a/src/profiling/ActivateTimelineReportingCommandHandler.cpp b/src/profiling/ActivateTimelineReportingCommandHandler.cpp new file mode 100644 index 0000000000..d762efc277 --- /dev/null +++ b/src/profiling/ActivateTimelineReportingCommandHandler.cpp @@ -0,0 +1,63 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "ActivateTimelineReportingCommandHandler.hpp" +#include "TimelineUtilityMethods.hpp" + +#include <armnn/Exceptions.hpp> +#include <boost/format.hpp> + +namespace armnn +{ + +namespace profiling +{ + +void ActivateTimelineReportingCommandHandler::operator()(const Packet& packet) +{ + ProfilingState currentState = m_StateMachine.GetCurrentState(); + + if (!m_ReportStructure.has_value()) + { + throw armnn::Exception(std::string("Profiling Service constructor must be initialised with an " + "IReportStructure argument in order to run timeline reporting")); + } + + switch ( currentState ) + { + case ProfilingState::Uninitialised: + case ProfilingState::NotConnected: + case ProfilingState::WaitingForAck: + throw RuntimeException(boost::str( + boost::format("Activate Timeline Reporting Command Handler invoked while in a wrong state: %1%") + % GetProfilingStateName(currentState))); + case ProfilingState::Active: + if ( !( packet.GetPacketFamily() == 0u && packet.GetPacketId() == 6u )) + { + throw armnn::Exception(std::string("Expected Packet family = 0, id = 6 but received family =") + + std::to_string(packet.GetPacketFamily()) + + " id = " + std::to_string(packet.GetPacketId())); + } + + m_SendTimelinePacket.SendTimelineMessageDirectoryPackage(); + + TimelineUtilityMethods::SendWellKnownLabelsAndEventClasses(m_SendTimelinePacket); + + m_TimelineReporting = true; + + m_ReportStructure.value().ReportStructure(); + + m_BackendNotifier.NotifyBackendsForTimelineReporting(); + + break; + default: + throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1%") + % static_cast<int>(currentState))); + } +} + +} // namespace profiling + +} // namespace armnn
\ No newline at end of file diff --git a/src/profiling/ActivateTimelineReportingCommandHandler.hpp b/src/profiling/ActivateTimelineReportingCommandHandler.hpp new file mode 100644 index 0000000000..ac11b46379 --- /dev/null +++ b/src/profiling/ActivateTimelineReportingCommandHandler.hpp @@ -0,0 +1,54 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "CommandHandlerFunctor.hpp" +#include "ProfilingStateMachine.hpp" +#include "Packet.hpp" +#include "SendTimelinePacket.hpp" +#include "IReportStructure.hpp" +#include "armnn/Optional.hpp" +#include "INotifyBackends.hpp" + +namespace armnn +{ + +namespace profiling +{ + +class ActivateTimelineReportingCommandHandler : public CommandHandlerFunctor +{ +public: + ActivateTimelineReportingCommandHandler(uint32_t familyId, + uint32_t packetId, + uint32_t version, + SendTimelinePacket& sendTimelinePacket, + ProfilingStateMachine& profilingStateMachine, + Optional<IReportStructure&> reportStructure, + std::atomic<bool>& timelineReporting, + INotifyBackends& notifyBackends) + : CommandHandlerFunctor(familyId, packetId, version), + m_SendTimelinePacket(sendTimelinePacket), + m_StateMachine(profilingStateMachine), + m_TimelineReporting(timelineReporting), + m_BackendNotifier(notifyBackends), + m_ReportStructure(reportStructure) + {} + + void operator()(const Packet& packet) override; + +private: + SendTimelinePacket& m_SendTimelinePacket; + ProfilingStateMachine& m_StateMachine; + std::atomic<bool>& m_TimelineReporting; + INotifyBackends& m_BackendNotifier; + + Optional<IReportStructure&> m_ReportStructure; +}; + +} // namespace profiling + +} // namespace armnn
\ No newline at end of file diff --git a/src/profiling/CommandHandler.hpp b/src/profiling/CommandHandler.hpp index 0cc23429cd..4bf820c5db 100644 --- a/src/profiling/CommandHandler.hpp +++ b/src/profiling/CommandHandler.hpp @@ -22,16 +22,16 @@ class CommandHandler { public: CommandHandler(uint32_t timeout, - bool stopAfterTimeout, - CommandHandlerRegistry& commandHandlerRegistry, - PacketVersionResolver& packetVersionResolver) - : m_Timeout(timeout) - , m_StopAfterTimeout(stopAfterTimeout) - , m_IsRunning(false) - , m_KeepRunning(false) - , m_CommandThread() - , m_CommandHandlerRegistry(commandHandlerRegistry) - , m_PacketVersionResolver(packetVersionResolver) + bool stopAfterTimeout, + CommandHandlerRegistry& commandHandlerRegistry, + PacketVersionResolver& packetVersionResolver) + : m_Timeout(timeout), + m_StopAfterTimeout(stopAfterTimeout), + m_IsRunning(false), + m_KeepRunning(false), + m_CommandThread(), + m_CommandHandlerRegistry(commandHandlerRegistry), + m_PacketVersionResolver(packetVersionResolver) {} ~CommandHandler() { Stop(); } @@ -46,13 +46,13 @@ private: void HandleCommands(IProfilingConnection& profilingConnection); std::atomic<uint32_t> m_Timeout; - std::atomic<bool> m_StopAfterTimeout; - std::atomic<bool> m_IsRunning; - std::atomic<bool> m_KeepRunning; - std::thread m_CommandThread; + std::atomic<bool> m_StopAfterTimeout; + std::atomic<bool> m_IsRunning; + std::atomic<bool> m_KeepRunning; + std::thread m_CommandThread; CommandHandlerRegistry& m_CommandHandlerRegistry; - PacketVersionResolver& m_PacketVersionResolver; + PacketVersionResolver& m_PacketVersionResolver; }; } // namespace profiling diff --git a/src/profiling/DeactivateTimelineReportingCommandHandler.cpp b/src/profiling/DeactivateTimelineReportingCommandHandler.cpp new file mode 100644 index 0000000000..dbfb053b3d --- /dev/null +++ b/src/profiling/DeactivateTimelineReportingCommandHandler.cpp @@ -0,0 +1,53 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "DeactivateTimelineReportingCommandHandler.hpp" + +#include <armnn/Exceptions.hpp> +#include <boost/format.hpp> + + +namespace armnn +{ + +namespace profiling +{ + +void DeactivateTimelineReportingCommandHandler::operator()(const Packet& packet) +{ + ProfilingState currentState = m_StateMachine.GetCurrentState(); + + switch ( currentState ) + { + case ProfilingState::Uninitialised: + case ProfilingState::NotConnected: + case ProfilingState::WaitingForAck: + throw RuntimeException(boost::str( + boost::format("Deactivate Timeline Reporting Command Handler invoked while in a wrong state: %1%") + % GetProfilingStateName(currentState))); + case ProfilingState::Active: + if (!(packet.GetPacketFamily() == 0u && packet.GetPacketId() == 7u)) + { + throw armnn::Exception(std::string("Expected Packet family = 0, id = 7 but received family =") + + std::to_string(packet.GetPacketFamily()) + +" id = " + std::to_string(packet.GetPacketId())); + } + + m_TimelineReporting.store(false); + + // Notify Backends + m_BackendNotifier.NotifyBackendsForTimelineReporting(); + + break; + default: + throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1%") + % static_cast<int>(currentState))); + } +} + +} // namespace profiling + +} // namespace armnn + diff --git a/src/profiling/DeactivateTimelineReportingCommandHandler.hpp b/src/profiling/DeactivateTimelineReportingCommandHandler.hpp new file mode 100644 index 0000000000..e06bae836f --- /dev/null +++ b/src/profiling/DeactivateTimelineReportingCommandHandler.hpp @@ -0,0 +1,45 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "CommandHandlerFunctor.hpp" +#include "Packet.hpp" +#include "ProfilingStateMachine.hpp" +#include "INotifyBackends.hpp" + +namespace armnn +{ + +namespace profiling +{ + +class DeactivateTimelineReportingCommandHandler : public CommandHandlerFunctor +{ + +public: + DeactivateTimelineReportingCommandHandler(uint32_t familyId, + uint32_t packetId, + uint32_t version, + std::atomic<bool>& timelineReporting, + ProfilingStateMachine& profilingStateMachine, + INotifyBackends& notifyBackends) + : CommandHandlerFunctor(familyId, packetId, version) + , m_TimelineReporting(timelineReporting) + , m_StateMachine(profilingStateMachine) + , m_BackendNotifier(notifyBackends) + {} + + void operator()(const Packet& packet) override; + +private: + std::atomic<bool>& m_TimelineReporting; + ProfilingStateMachine& m_StateMachine; + INotifyBackends& m_BackendNotifier; +}; + +} // namespace profiling + +} // namespace armnn
\ No newline at end of file diff --git a/src/profiling/DirectoryCaptureCommandHandler.cpp b/src/profiling/DirectoryCaptureCommandHandler.cpp index 65cac848ae..93cdde736e 100644 --- a/src/profiling/DirectoryCaptureCommandHandler.cpp +++ b/src/profiling/DirectoryCaptureCommandHandler.cpp @@ -281,7 +281,8 @@ std::vector<CounterDirectoryEventRecord> DirectoryCaptureCommandHandler::ReadEve eventRecords[i].m_CounterDescription = GetStringNameFromBuffer(data, eventRecordOffset + descriptionOffset); - eventRecords[i].m_CounterUnits = GetStringNameFromBuffer(data, eventRecordOffset + unitsOffset); + eventRecords[i].m_CounterUnits = unitsOffset == 0 ? Optional<std::string>() : + GetStringNameFromBuffer(data, eventRecordOffset + unitsOffset); } return eventRecords; diff --git a/src/profiling/DirectoryCaptureCommandHandler.hpp b/src/profiling/DirectoryCaptureCommandHandler.hpp index 03bbb1eb09..6b25714168 100644 --- a/src/profiling/DirectoryCaptureCommandHandler.hpp +++ b/src/profiling/DirectoryCaptureCommandHandler.hpp @@ -25,7 +25,7 @@ struct CounterDirectoryEventRecord std::string m_CounterName; uint16_t m_CounterSetUid; uint16_t m_CounterUid; - std::string m_CounterUnits; + Optional<std::string> m_CounterUnits; uint16_t m_DeviceUid; uint16_t m_MaxCounterUid; }; diff --git a/src/profiling/INotifyBackends.hpp b/src/profiling/INotifyBackends.hpp new file mode 100644 index 0000000000..217ebdef1c --- /dev/null +++ b/src/profiling/INotifyBackends.hpp @@ -0,0 +1,24 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +namespace armnn +{ + +namespace profiling +{ + +class INotifyBackends +{ +public: + virtual ~INotifyBackends() {} + virtual void NotifyBackendsForTimelineReporting() = 0; +}; + +} // namespace profiling + +} // namespace armnn + diff --git a/src/profiling/IReportStructure.hpp b/src/profiling/IReportStructure.hpp new file mode 100644 index 0000000000..1ae049733f --- /dev/null +++ b/src/profiling/IReportStructure.hpp @@ -0,0 +1,24 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +namespace armnn +{ + +namespace profiling +{ + +class IReportStructure +{ +public: + virtual ~IReportStructure() {} + virtual void ReportStructure() = 0; +}; + +} // namespace profiling + +} // namespace armnn + diff --git a/src/profiling/PacketVersionResolver.cpp b/src/profiling/PacketVersionResolver.cpp index 2c75067487..4178ace166 100644 --- a/src/profiling/PacketVersionResolver.cpp +++ b/src/profiling/PacketVersionResolver.cpp @@ -54,8 +54,17 @@ bool PacketKey::operator!=(const PacketKey& rhs) const Version PacketVersionResolver::ResolvePacketVersion(uint32_t familyId, uint32_t packetId) const { - IgnoreUnused(familyId, packetId); - // NOTE: For now every packet specification is at version 1.0.0 + const PacketKey packetKey(familyId, packetId); + + if( packetKey == ActivateTimeLinePacket ) + { + return Version(1, 1, 0); + } + if( packetKey == DectivateTimeLinePacket ) + { + return Version(1, 1, 0); + } + return Version(1, 0, 0); } diff --git a/src/profiling/PacketVersionResolver.hpp b/src/profiling/PacketVersionResolver.hpp index e959ed548e..6222eb02e8 100644 --- a/src/profiling/PacketVersionResolver.hpp +++ b/src/profiling/PacketVersionResolver.hpp @@ -33,6 +33,9 @@ private: uint32_t m_PacketId; }; +static const PacketKey ActivateTimeLinePacket(0 , 6); +static const PacketKey DectivateTimeLinePacket(0 , 7); + class PacketVersionResolver final { public: diff --git a/src/profiling/ProfilingService.cpp b/src/profiling/ProfilingService.cpp index e42ef134dc..3a8f3f83a3 100644 --- a/src/profiling/ProfilingService.cpp +++ b/src/profiling/ProfilingService.cpp @@ -34,6 +34,7 @@ void ProfilingService::ResetExternalProfilingOptions(const ExternalProfilingOpti { // Update the profiling options m_Options = options; + m_TimelineReporting = options.m_TimelineEnabled; // Check if the profiling service needs to be reset if (resetProfilingService) @@ -431,7 +432,7 @@ void ProfilingService::Reset() // ...finally reset the profiling state machine m_StateMachine.Reset(); m_BackendProfilingContexts.clear(); - m_MaxGlobalCounterId = armnn::profiling::INFERENCES_RUN; + m_MaxGlobalCounterId = armnn::profiling::MAX_ARMNN_COUNTER; } void ProfilingService::Stop() @@ -463,11 +464,22 @@ inline void ProfilingService::CheckCounterUid(uint16_t counterUid) const } } +void ProfilingService::NotifyBackendsForTimelineReporting() +{ + BackendProfilingContext::iterator it = m_BackendProfilingContexts.begin(); + while (it != m_BackendProfilingContexts.end()) + { + auto& backendProfilingContext = it->second; + backendProfilingContext->EnableTimelineReporting(m_TimelineReporting); + // Increment the Iterator to point to next entry + it++; + } +} + ProfilingService::~ProfilingService() { Stop(); } - } // namespace profiling } // namespace armnn diff --git a/src/profiling/ProfilingService.hpp b/src/profiling/ProfilingService.hpp index f6e409daeb..df7bd8f857 100644 --- a/src/profiling/ProfilingService.hpp +++ b/src/profiling/ProfilingService.hpp @@ -5,14 +5,17 @@ #pragma once +#include "ActivateTimelineReportingCommandHandler.hpp" #include "BufferManager.hpp" #include "CommandHandler.hpp" #include "ConnectionAcknowledgedCommandHandler.hpp" #include "CounterDirectory.hpp" #include "CounterIdMap.hpp" +#include "DeactivateTimelineReportingCommandHandler.hpp" #include "ICounterRegistry.hpp" #include "ICounterValues.hpp" #include "IProfilingService.hpp" +#include "IReportStructure.hpp" #include "PeriodicCounterCapture.hpp" #include "PeriodicCounterSelectionCommandHandler.hpp" #include "PerJobCounterSelectionCommandHandler.hpp" @@ -24,6 +27,7 @@ #include "SendThread.hpp" #include "SendTimelinePacket.hpp" #include "TimelinePacketWriterFactory.hpp" +#include "INotifyBackends.hpp" #include <armnn/backends/profiling/IBackendProfilingContext.hpp> namespace armnn @@ -32,14 +36,14 @@ namespace armnn namespace profiling { // Static constants describing ArmNN's counter UID's -static const uint16_t NETWORK_LOADS = 0; -static const uint16_t NETWORK_UNLOADS = 1; -static const uint16_t REGISTERED_BACKENDS = 2; -static const uint16_t UNREGISTERED_BACKENDS = 3; -static const uint16_t INFERENCES_RUN = 4; -static const uint16_t MAX_ARMNN_COUNTER = INFERENCES_RUN; - -class ProfilingService : public IReadWriteCounterValues, public IProfilingService +static const uint16_t NETWORK_LOADS = 0; +static const uint16_t NETWORK_UNLOADS = 1; +static const uint16_t REGISTERED_BACKENDS = 2; +static const uint16_t UNREGISTERED_BACKENDS = 3; +static const uint16_t INFERENCES_RUN = 4; +static const uint16_t MAX_ARMNN_COUNTER = INFERENCES_RUN; + +class ProfilingService : public IReadWriteCounterValues, public IProfilingService, public INotifyBackends { public: using ExternalProfilingOptions = IRuntime::CreationOptions::ExternalProfilingOptions; @@ -47,10 +51,12 @@ public: using IProfilingConnectionPtr = std::unique_ptr<IProfilingConnection>; using CounterIndices = std::vector<std::atomic<uint32_t>*>; using CounterValues = std::list<std::atomic<uint32_t>>; + using BackendProfilingContext = std::unordered_map<BackendId, + std::shared_ptr<armnn::profiling::IBackendProfilingContext>>; - // Default constructor/destructor kept protected for testing - ProfilingService() + ProfilingService(Optional<IReportStructure&> reportStructure = EmptyOptional()) : m_Options() + , m_TimelineReporting(false) , m_CounterDirectory() , m_ProfilingConnectionFactory(new ProfilingConnectionFactory()) , m_ProfilingConnection() @@ -97,6 +103,22 @@ public: 5, m_PacketVersionResolver.ResolvePacketVersion(0, 5).GetEncodedValue(), m_StateMachine) + , m_ActivateTimelineReportingCommandHandler(0, + 6, + m_PacketVersionResolver.ResolvePacketVersion(0, 6) + .GetEncodedValue(), + m_SendTimelinePacket, + m_StateMachine, + reportStructure, + m_TimelineReporting, + *this) + , m_DeactivateTimelineReportingCommandHandler(0, + 7, + m_PacketVersionResolver.ResolvePacketVersion(0, 7) + .GetEncodedValue(), + m_TimelineReporting, + m_StateMachine, + *this) , m_TimelinePacketWriterFactory(m_BufferManager) , m_MaxGlobalCounterId(armnn::profiling::INFERENCES_RUN) { @@ -111,6 +133,10 @@ public: // Register the "Per-Job Counter Selection" command handler m_CommandHandlerRegistry.RegisterFunctor(&m_PerJobCounterSelectionCommandHandler); + + m_CommandHandlerRegistry.RegisterFunctor(&m_ActivateTimelineReportingCommandHandler); + + m_CommandHandlerRegistry.RegisterFunctor(&m_DeactivateTimelineReportingCommandHandler); } ~ProfilingService(); @@ -131,6 +157,9 @@ public: void AddBackendProfilingContext(const BackendId backendId, std::shared_ptr<armnn::profiling::IBackendProfilingContext> profilingContext); + // Enable the recording of timeline events and entities + void NotifyBackendsForTimelineReporting() override; + const ICounterDirectory& GetCounterDirectory() const; ICounterRegistry& GetCounterRegistry(); ProfilingState GetCurrentState() const; @@ -168,13 +197,15 @@ public: return m_SendCounterPacket; } - /// Check if the profiling is enabled - bool IsEnabled() { return m_Options.m_EnableProfiling; } - static ProfilingDynamicGuid GetNextGuid(); static ProfilingStaticGuid GetStaticId(const std::string& str); + bool IsTimelineReportingEnabled() + { + return m_TimelineReporting; + } + private: //Copy/move constructors/destructors and copy/move assignment operators are deleted ProfilingService(const ProfilingService&) = delete; @@ -192,32 +223,37 @@ private: void CheckCounterUid(uint16_t counterUid) const; // Profiling service components - ExternalProfilingOptions m_Options; - CounterDirectory m_CounterDirectory; - CounterIdMap m_CounterIdMap; + ExternalProfilingOptions m_Options; + std::atomic<bool> m_TimelineReporting; + CounterDirectory m_CounterDirectory; + CounterIdMap m_CounterIdMap; IProfilingConnectionFactoryPtr m_ProfilingConnectionFactory; - IProfilingConnectionPtr m_ProfilingConnection; - ProfilingStateMachine m_StateMachine; - CounterIndices m_CounterIndex; - CounterValues m_CounterValues; - CommandHandlerRegistry m_CommandHandlerRegistry; - PacketVersionResolver m_PacketVersionResolver; - CommandHandler m_CommandHandler; - BufferManager m_BufferManager; - SendCounterPacket m_SendCounterPacket; - SendThread m_SendThread; - SendTimelinePacket m_SendTimelinePacket; + IProfilingConnectionPtr m_ProfilingConnection; + ProfilingStateMachine m_StateMachine; + CounterIndices m_CounterIndex; + CounterValues m_CounterValues; + CommandHandlerRegistry m_CommandHandlerRegistry; + PacketVersionResolver m_PacketVersionResolver; + CommandHandler m_CommandHandler; + BufferManager m_BufferManager; + SendCounterPacket m_SendCounterPacket; + SendThread m_SendThread; + SendTimelinePacket m_SendTimelinePacket; + Holder m_Holder; + PeriodicCounterCapture m_PeriodicCounterCapture; - ConnectionAcknowledgedCommandHandler m_ConnectionAcknowledgedCommandHandler; - RequestCounterDirectoryCommandHandler m_RequestCounterDirectoryCommandHandler; - PeriodicCounterSelectionCommandHandler m_PeriodicCounterSelectionCommandHandler; - PerJobCounterSelectionCommandHandler m_PerJobCounterSelectionCommandHandler; + + ConnectionAcknowledgedCommandHandler m_ConnectionAcknowledgedCommandHandler; + RequestCounterDirectoryCommandHandler m_RequestCounterDirectoryCommandHandler; + PeriodicCounterSelectionCommandHandler m_PeriodicCounterSelectionCommandHandler; + PerJobCounterSelectionCommandHandler m_PerJobCounterSelectionCommandHandler; + ActivateTimelineReportingCommandHandler m_ActivateTimelineReportingCommandHandler; + DeactivateTimelineReportingCommandHandler m_DeactivateTimelineReportingCommandHandler; TimelinePacketWriterFactory m_TimelinePacketWriterFactory; - std::unordered_map<BackendId, - std::shared_ptr<armnn::profiling::IBackendProfilingContext>> m_BackendProfilingContexts; - uint16_t m_MaxGlobalCounterId; + BackendProfilingContext m_BackendProfilingContexts; + uint16_t m_MaxGlobalCounterId; static ProfilingGuidGenerator m_GuidGenerator; diff --git a/src/profiling/ProfilingUtils.cpp b/src/profiling/ProfilingUtils.cpp index 002eeb9616..e419769600 100644 --- a/src/profiling/ProfilingUtils.cpp +++ b/src/profiling/ProfilingUtils.cpp @@ -96,17 +96,8 @@ void WriteBytes(const IPacketBufferPtr& packetBuffer, unsigned int offset, cons uint32_t ConstructHeader(uint32_t packetFamily, uint32_t packetId) { - return ((packetFamily & 0x3F) << 26)| - ((packetId & 0x3FF) << 16); -} - -uint32_t ConstructHeader(uint32_t packetFamily, - uint32_t packetClass, - uint32_t packetType) -{ - return ((packetFamily & 0x3F) << 26)| - ((packetClass & 0x3FF) << 19)| - ((packetType & 0x3FFF) << 16); + return (( packetFamily & 0x0000003F ) << 26 )| + (( packetId & 0x000003FF ) << 16 ); } void WriteUint64(const std::unique_ptr<IPacketBuffer>& packetBuffer, unsigned int offset, uint64_t value) diff --git a/src/profiling/ProfilingUtils.hpp b/src/profiling/ProfilingUtils.hpp index 37ab88cb6f..5888ef0b8c 100644 --- a/src/profiling/ProfilingUtils.hpp +++ b/src/profiling/ProfilingUtils.hpp @@ -146,8 +146,6 @@ void WriteBytes(const IPacketBuffer& packetBuffer, unsigned int offset, const vo uint32_t ConstructHeader(uint32_t packetFamily, uint32_t packetId); -uint32_t ConstructHeader(uint32_t packetFamily, uint32_t packetClass, uint32_t packetType); - void WriteUint64(const IPacketBufferPtr& packetBuffer, unsigned int offset, uint64_t value); void WriteUint32(const IPacketBufferPtr& packetBuffer, unsigned int offset, uint32_t value); diff --git a/src/profiling/SendCounterPacket.cpp b/src/profiling/SendCounterPacket.cpp index 942ccc7b59..ae4bab91e7 100644 --- a/src/profiling/SendCounterPacket.cpp +++ b/src/profiling/SendCounterPacket.cpp @@ -181,10 +181,34 @@ bool SendCounterPacket::CreateCategoryRecord(const CategoryPtr& category, BOOST_ASSERT(category); const std::string& categoryName = category->m_Name; - const std::vector<uint16_t> categoryCounters = category->m_Counters; - BOOST_ASSERT(!categoryName.empty()); + // Remove any duplicate counters + std::vector<uint16_t> categoryCounters; + for (size_t counterIndex = 0; counterIndex < category->m_Counters.size(); ++counterIndex) + { + uint16_t counterUid = category->m_Counters.at(counterIndex); + auto it = counters.find(counterUid); + if (it == counters.end()) + { + errorMessage = boost::str(boost::format("Counter (%1%) not found in category (%2%)") + % counterUid % category->m_Name ); + return false; + } + + const CounterPtr& counter = it->second; + + if (counterUid == counter->m_MaxCounterUid) + { + categoryCounters.emplace_back(counterUid); + } + } + if (categoryCounters.empty()) + { + errorMessage = boost::str(boost::format("No valid counters found in category (%1%)")% categoryName); + return false; + } + // Utils size_t uint32_t_size = sizeof(uint32_t); @@ -203,7 +227,7 @@ bool SendCounterPacket::CreateCategoryRecord(const CategoryPtr& category, std::vector<uint32_t> categoryNameBuffer; if (!StringToSwTraceString<SwTraceNameCharPolicy>(categoryName, categoryNameBuffer)) { - errorMessage = boost::str(boost::format("Cannot convert the name of category \"%1%\" to an SWTrace namestring") + errorMessage = boost::str(boost::format("Cannot convert the name of category (%1%) to an SWTrace namestring") % categoryName); return false; } @@ -221,7 +245,6 @@ bool SendCounterPacket::CreateCategoryRecord(const CategoryPtr& category, { uint16_t counterUid = categoryCounters.at(counterIndex); auto it = counters.find(counterUid); - BOOST_ASSERT(it != counters.end()); const CounterPtr& counter = it->second; EventRecord& eventRecord = eventRecords.at(eventRecordIndex); @@ -299,7 +322,7 @@ bool SendCounterPacket::CreateDeviceRecord(const DevicePtr& device, std::vector<uint32_t> deviceNameBuffer; if (!StringToSwTraceString<SwTraceCharPolicy>(deviceName, deviceNameBuffer)) { - errorMessage = boost::str(boost::format("Cannot convert the name of device %1% (\"%2%\") to an SWTrace string") + errorMessage = boost::str(boost::format("Cannot convert the name of device %1% (%2%) to an SWTrace string") % deviceUid % deviceName); return false; @@ -349,7 +372,7 @@ bool SendCounterPacket::CreateCounterSetRecord(const CounterSetPtr& counterSet, std::vector<uint32_t> counterSetNameBuffer; if (!StringToSwTraceString<SwTraceNameCharPolicy>(counterSet->m_Name, counterSetNameBuffer)) { - errorMessage = boost::str(boost::format("Cannot convert the name of counter set %1% (\"%2%\") to " + errorMessage = boost::str(boost::format("Cannot convert the name of counter set %1% (%2%) to " "an SWTrace namestring") % counterSetUid % counterSetName); @@ -441,7 +464,7 @@ bool SendCounterPacket::CreateEventRecord(const CounterPtr& counter, std::vector<uint32_t> counterNameBuffer; if (!StringToSwTraceString<SwTraceCharPolicy>(counterName, counterNameBuffer)) { - errorMessage = boost::str(boost::format("Cannot convert the name of counter %1% (name: \"%2%\") " + errorMessage = boost::str(boost::format("Cannot convert the name of counter %1% (name: %2%) " "to an SWTrace string") % counterUid % counterName); @@ -457,7 +480,7 @@ bool SendCounterPacket::CreateEventRecord(const CounterPtr& counter, std::vector<uint32_t> counterDescriptionBuffer; if (!StringToSwTraceString<SwTraceCharPolicy>(counterDescription, counterDescriptionBuffer)) { - errorMessage = boost::str(boost::format("Cannot convert the description of counter %1% (description: \"%2%\") " + errorMessage = boost::str(boost::format("Cannot convert the description of counter %1% (description: %2%) " "to an SWTrace string") % counterUid % counterName); @@ -481,7 +504,7 @@ bool SendCounterPacket::CreateEventRecord(const CounterPtr& counter, // Convert the counter units into a SWTrace namestring if (!StringToSwTraceString<SwTraceNameCharPolicy>(counterUnits, counterUnitsBuffer)) { - errorMessage = boost::str(boost::format("Cannot convert the units of counter %1% (units: \"%2%\") " + errorMessage = boost::str(boost::format("Cannot convert the units of counter %1% (units: %2%) " "to an SWTrace string") % counterUid % counterName); diff --git a/src/profiling/TimelineUtilityMethods.cpp b/src/profiling/TimelineUtilityMethods.cpp index 8a39ea7c4a..de30b4d4ef 100644 --- a/src/profiling/TimelineUtilityMethods.cpp +++ b/src/profiling/TimelineUtilityMethods.cpp @@ -14,7 +14,7 @@ namespace profiling std::unique_ptr<TimelineUtilityMethods> TimelineUtilityMethods::GetTimelineUtils(ProfilingService& profilingService) { - if (profilingService.IsProfilingEnabled()) + if (profilingService.GetCurrentState() == ProfilingState::Active && profilingService.IsTimelineReportingEnabled()) { std::unique_ptr<ISendTimelinePacket> sendTimelinepacket = profilingService.GetSendTimelinePacket(); return std::make_unique<TimelineUtilityMethods>(sendTimelinepacket); diff --git a/src/profiling/test/ProfilingMocks.hpp b/src/profiling/test/ProfilingMocks.hpp index eeb641e878..ada55d8dff 100644 --- a/src/profiling/test/ProfilingMocks.hpp +++ b/src/profiling/test/ProfilingMocks.hpp @@ -51,6 +51,8 @@ public: PerJobCounterSelection, TimelineMessageDirectory, PeriodicCounterCapture, + ActivateTimelineReporting, + DeactivateTimelineReporting, Unknown }; @@ -85,7 +87,7 @@ public: switch (packetFamily) { case 0: - packetType = packetId < 6 ? PacketType(packetId) : PacketType::Unknown; + packetType = packetId < 8 ? PacketType(packetId) : PacketType::Unknown; break; case 1: packetType = packetId == 0 ? PacketType::TimelineMessageDirectory : PacketType::Unknown; @@ -628,7 +630,8 @@ public: const CaptureData& captureData) : m_SendCounterPacket(mockBufferManager), m_IsProfilingEnabled(isProfilingEnabled), - m_CaptureData(captureData) {} + m_CaptureData(captureData) + {} /// Return the next random Guid in the sequence ProfilingDynamicGuid NextGuid() override @@ -682,10 +685,10 @@ public: private: ProfilingGuidGenerator m_GuidGenerator; - CounterIdMap m_CounterMapping; - SendCounterPacket m_SendCounterPacket; - bool m_IsProfilingEnabled; - CaptureData m_CaptureData; + CounterIdMap m_CounterMapping; + SendCounterPacket m_SendCounterPacket; + bool m_IsProfilingEnabled; + CaptureData m_CaptureData; }; } // namespace profiling diff --git a/src/profiling/test/ProfilingTestUtils.cpp b/src/profiling/test/ProfilingTestUtils.cpp index 244051c785..8de69f14ec 100644 --- a/src/profiling/test/ProfilingTestUtils.cpp +++ b/src/profiling/test/ProfilingTestUtils.cpp @@ -297,7 +297,14 @@ void VerifyPostOptimisationStructureTestImpl(armnn::BackendId backendId) // Create runtime in which test will run armnn::IRuntime::CreationOptions options; options.m_ProfilingOptions.m_EnableProfiling = true; + options.m_ProfilingOptions.m_TimelineEnabled = true; armnn::Runtime runtime(options); + GetProfilingService(&runtime).ResetExternalProfilingOptions(options.m_ProfilingOptions, false); + + profiling::ProfilingServiceRuntimeHelper profilingServiceHelper(GetProfilingService(&runtime)); + profilingServiceHelper.ForceTransitionToState(ProfilingState::NotConnected); + profilingServiceHelper.ForceTransitionToState(ProfilingState::WaitingForAck); + profilingServiceHelper.ForceTransitionToState(ProfilingState::Active); // build up the structure of the network INetworkPtr net(INetwork::Create()); @@ -363,7 +370,6 @@ void VerifyPostOptimisationStructureTestImpl(armnn::BackendId backendId) armnn::NetworkId netId; BOOST_TEST(runtime.LoadNetwork(netId, std::move(optNet)) == Status::Success); - profiling::ProfilingServiceRuntimeHelper profilingServiceHelper(GetProfilingService(&runtime)); profiling::BufferManager& bufferManager = profilingServiceHelper.GetProfilingBufferManager(); auto readableBuffer = bufferManager.GetReadableBuffer(); diff --git a/src/profiling/test/ProfilingTestUtils.hpp b/src/profiling/test/ProfilingTestUtils.hpp index 459d62435b..816ffd3dc6 100644 --- a/src/profiling/test/ProfilingTestUtils.hpp +++ b/src/profiling/test/ProfilingTestUtils.hpp @@ -6,6 +6,7 @@ #pragma once #include "ProfilingUtils.hpp" +#include "Runtime.hpp" #include <armnn/BackendId.hpp> #include <armnn/Optional.hpp> @@ -68,6 +69,11 @@ public: return GetBufferManager(m_ProfilingService); } armnn::profiling::ProfilingService& m_ProfilingService; + + void ForceTransitionToState(ProfilingState newState) + { + TransitionToState(m_ProfilingService, newState); + } }; } // namespace profiling diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index 74b93d7c4f..f252579022 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -4,6 +4,7 @@ // #include "ProfilingTests.hpp" +#include "ProfilingTestUtils.hpp" #include <backends/BackendProfiling.hpp> #include <CommandHandler.hpp> @@ -1823,6 +1824,132 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData) BOOST_TEST(period == armnn::LOWEST_CAPTURE_PERIOD); // capture period } +BOOST_AUTO_TEST_CASE(CheckTimelineActivationAndDeactivation) +{ + class TestReportStructure : public IReportStructure + { + public: + virtual void ReportStructure() override + { + m_ReportStructureCalled = true; + } + + bool m_ReportStructureCalled = false; + }; + + class TestNotifyBackends : public INotifyBackends + { + public: + TestNotifyBackends() : m_timelineReporting(false) {} + virtual void NotifyBackendsForTimelineReporting() override + { + m_TestNotifyBackendsCalled = m_timelineReporting.load(); + } + + bool m_TestNotifyBackendsCalled = false; + std::atomic<bool> m_timelineReporting; + }; + + PacketVersionResolver packetVersionResolver; + + BufferManager bufferManager(512); + SendTimelinePacket sendTimelinePacket(bufferManager); + ProfilingStateMachine stateMachine; + TestReportStructure testReportStructure; + TestNotifyBackends testNotifyBackends; + + profiling::ActivateTimelineReportingCommandHandler activateTimelineReportingCommandHandler(0, + 6, + packetVersionResolver.ResolvePacketVersion(0, 6) + .GetEncodedValue(), + sendTimelinePacket, + stateMachine, + testReportStructure, + testNotifyBackends.m_timelineReporting, + testNotifyBackends); + + // Write an "ActivateTimelineReporting" packet into the mock profiling connection, to simulate an input from an + // external profiling service + const uint32_t packetFamily1 = 0; + const uint32_t packetId1 = 6; + uint32_t packetHeader1 = ConstructHeader(packetFamily1, packetId1); + + // Create the ActivateTimelineReportingPacket + Packet ActivateTimelineReportingPacket(packetHeader1); // Length == 0 + + BOOST_CHECK_THROW( + activateTimelineReportingCommandHandler.operator()(ActivateTimelineReportingPacket), armnn::Exception); + + stateMachine.TransitionToState(ProfilingState::NotConnected); + BOOST_CHECK_THROW( + activateTimelineReportingCommandHandler.operator()(ActivateTimelineReportingPacket), armnn::Exception); + + stateMachine.TransitionToState(ProfilingState::WaitingForAck); + BOOST_CHECK_THROW( + activateTimelineReportingCommandHandler.operator()(ActivateTimelineReportingPacket), armnn::Exception); + + stateMachine.TransitionToState(ProfilingState::Active); + activateTimelineReportingCommandHandler.operator()(ActivateTimelineReportingPacket); + + BOOST_CHECK(testReportStructure.m_ReportStructureCalled); + BOOST_CHECK(testNotifyBackends.m_TestNotifyBackendsCalled); + BOOST_CHECK(testNotifyBackends.m_timelineReporting.load()); + + DeactivateTimelineReportingCommandHandler deactivateTimelineReportingCommandHandler(0, + 7, + packetVersionResolver.ResolvePacketVersion(0, 7).GetEncodedValue(), + testNotifyBackends.m_timelineReporting, + stateMachine, + testNotifyBackends); + + const uint32_t packetFamily2 = 0; + const uint32_t packetId2 = 7; + uint32_t packetHeader2 = ConstructHeader(packetFamily2, packetId2); + + // Create the DeactivateTimelineReportingPacket + Packet deactivateTimelineReportingPacket(packetHeader2); // Length == 0 + + stateMachine.Reset(); + BOOST_CHECK_THROW( + deactivateTimelineReportingCommandHandler.operator()(deactivateTimelineReportingPacket), armnn::Exception); + + stateMachine.TransitionToState(ProfilingState::NotConnected); + BOOST_CHECK_THROW( + deactivateTimelineReportingCommandHandler.operator()(deactivateTimelineReportingPacket), armnn::Exception); + + stateMachine.TransitionToState(ProfilingState::WaitingForAck); + BOOST_CHECK_THROW( + deactivateTimelineReportingCommandHandler.operator()(deactivateTimelineReportingPacket), armnn::Exception); + + stateMachine.TransitionToState(ProfilingState::Active); + deactivateTimelineReportingCommandHandler.operator()(deactivateTimelineReportingPacket); + + BOOST_CHECK(!testNotifyBackends.m_TestNotifyBackendsCalled); + BOOST_CHECK(!testNotifyBackends.m_timelineReporting.load()); +} + +BOOST_AUTO_TEST_CASE(CheckProfilingServiceNotActive) +{ + using namespace armnn; + using namespace armnn::profiling; + + // Create runtime in which the test will run + armnn::IRuntime::CreationOptions options; + options.m_ProfilingOptions.m_EnableProfiling = true; + + armnn::Runtime runtime(options); + profiling::ProfilingServiceRuntimeHelper profilingServiceHelper(GetProfilingService(&runtime)); + profilingServiceHelper.ForceTransitionToState(ProfilingState::NotConnected); + profilingServiceHelper.ForceTransitionToState(ProfilingState::WaitingForAck); + profilingServiceHelper.ForceTransitionToState(ProfilingState::Active); + + profiling::BufferManager& bufferManager = profilingServiceHelper.GetProfilingBufferManager(); + auto readableBuffer = bufferManager.GetReadableBuffer(); + + // Profiling is enabled, the post-optimisation structure should be created + BOOST_CHECK(readableBuffer == nullptr); +} + BOOST_AUTO_TEST_CASE(CheckConnectionAcknowledged) { using boost::numeric_cast; @@ -3395,8 +3522,7 @@ BOOST_AUTO_TEST_CASE(CheckRegisterCounters) MockBufferManager mockBuffer(1024); CaptureData captureData; - MockProfilingService mockProfilingService( - mockBuffer, options.m_ProfilingOptions.m_EnableProfiling, captureData); + MockProfilingService mockProfilingService(mockBuffer, options.m_ProfilingOptions.m_EnableProfiling, captureData); armnn::BackendId cpuRefId(armnn::Compute::CpuRef); mockProfilingService.RegisterMapping(6, 0, cpuRefId); diff --git a/src/profiling/test/ProfilingTests.hpp b/src/profiling/test/ProfilingTests.hpp index a5971e0a4b..d1052cea97 100644 --- a/src/profiling/test/ProfilingTests.hpp +++ b/src/profiling/test/ProfilingTests.hpp @@ -242,7 +242,7 @@ public: uint32_t length = 0, uint32_t timeout = 1000) { - long packetCount = mockProfilingConnection->CheckForPacket({packetType, length}); + long packetCount = mockProfilingConnection->CheckForPacket({ packetType, length }); // The first packet we receive may not be the one we are looking for, so keep looping until till we find it, // or until WaitForPacketsSent times out while(packetCount == 0 && timeout != 0) diff --git a/src/profiling/test/SendCounterPacketTests.cpp b/src/profiling/test/SendCounterPacketTests.cpp index d7dc7e2d9e..51f049ddc6 100644 --- a/src/profiling/test/SendCounterPacketTests.cpp +++ b/src/profiling/test/SendCounterPacketTests.cpp @@ -1172,7 +1172,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) BOOST_CHECK(counterDirectory.GetCategoryCount() == 2); BOOST_CHECK(category2); - uint16_t numberOfCores = 3; + uint16_t numberOfCores = 4; // Register a counter associated to "category1" const Counter* counter1 = nullptr; @@ -1186,7 +1186,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) "counter1description", std::string("counter1units"), numberOfCores)); - BOOST_CHECK(counterDirectory.GetCounterCount() == 3); + BOOST_CHECK(counterDirectory.GetCounterCount() == 4); BOOST_CHECK(counter1); // Register a counter associated to "category1" @@ -1203,7 +1203,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) armnn::EmptyOptional(), device2->m_Uid, 0)); - BOOST_CHECK(counterDirectory.GetCounterCount() == 4); + BOOST_CHECK(counterDirectory.GetCounterCount() == 5); BOOST_CHECK(counter2); // Register a counter associated to "category2" @@ -1217,7 +1217,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) "counter3", "counter3description", armnn::EmptyOptional(), - 5, + numberOfCores, device2->m_Uid, counterSet1->m_Uid)); BOOST_CHECK(counterDirectory.GetCounterCount() == 9); @@ -1236,7 +1236,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) uint32_t packetHeaderWord1 = ReadUint32(readBuffer, 4); BOOST_TEST(((packetHeaderWord0 >> 26) & 0x3F) == 0); // packet_family BOOST_TEST(((packetHeaderWord0 >> 16) & 0x3FF) == 2); // packet_id - BOOST_TEST(packetHeaderWord1 == 928); // data_length + BOOST_TEST(packetHeaderWord1 == 432); // data_length // Check the body header uint32_t bodyHeaderWord0 = ReadUint32(readBuffer, 8); @@ -1269,7 +1269,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) uint32_t categoryRecordOffset0 = ReadUint32(readBuffer, 44); uint32_t categoryRecordOffset1 = ReadUint32(readBuffer, 48); BOOST_TEST(categoryRecordOffset0 == 64); // Category record offset for "category1" - BOOST_TEST(categoryRecordOffset1 == 472); // Category record offset for "category2" + BOOST_TEST(categoryRecordOffset1 == 168); // Category record offset for "category2" // Get the device record pool offset uint32_t uint32_t_size = sizeof(uint32_t); @@ -1584,7 +1584,8 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) const Category* category = counterDirectory.GetCategory(categoryRecord.name); BOOST_CHECK(category); BOOST_CHECK(category->m_Name == categoryRecord.name); - BOOST_CHECK(category->m_Counters.size() == categoryRecord.event_count); + BOOST_CHECK(category->m_Counters.size() == categoryRecord.event_count + static_cast<size_t>(numberOfCores) -1); + BOOST_CHECK(category->m_Counters.size() == categoryRecord.event_count + static_cast<size_t>(numberOfCores) -1); // Check that the event records are correct for (const EventRecord& eventRecord : categoryRecord.event_records) diff --git a/src/profiling/test/SendTimelinePacketTests.cpp b/src/profiling/test/SendTimelinePacketTests.cpp index 98b161f65e..4a13ebf824 100644 --- a/src/profiling/test/SendTimelinePacketTests.cpp +++ b/src/profiling/test/SendTimelinePacketTests.cpp @@ -15,6 +15,7 @@ #include <LabelsAndEventClasses.hpp> #include <functional> +#include <Runtime.hpp> using namespace armnn::profiling; @@ -395,10 +396,12 @@ BOOST_AUTO_TEST_CASE(SendTimelinePacketTests3) BOOST_AUTO_TEST_CASE(GetGuidsFromProfilingService) { - armnn::IRuntime::CreationOptions::ExternalProfilingOptions options; - options.m_EnableProfiling = true; - armnn::profiling::ProfilingService profilingService; - profilingService.ResetExternalProfilingOptions(options, true); + armnn::IRuntime::CreationOptions options; + options.m_ProfilingOptions.m_EnableProfiling = true; + armnn::Runtime runtime(options); + armnn::profiling::ProfilingService profilingService(runtime); + + profilingService.ResetExternalProfilingOptions(options.m_ProfilingOptions, true); ProfilingStaticGuid staticGuid = profilingService.GetStaticId("dummy"); std::hash<std::string> hasher; uint64_t hash = static_cast<uint64_t>(hasher("dummy")); |