From fe5a24beeef6e9a41366e694f41093565e748048 Mon Sep 17 00:00:00 2001 From: Finn Williams Date: Thu, 9 Apr 2020 16:05:28 +0100 Subject: IVGCVSW-4666 Call EnableProfiling when state switches to active * Move the call to EnableProfiling() into ConnectionAcknowledgedHandler * Fix an issue with MockGatord forcing some command handlers to be quiet * Add some small unrelated improvements and typo fixes to the periodic counter command handlers Signed-off-by: Finn Williams Change-Id: I9e6066b78d1f782cfaf27c11571c0ec5cb5d126f --- src/armnn/Runtime.cpp | 21 ++++++--------------- src/backends/backendsCommon/test/MockBackend.hpp | 14 +++++++++----- .../ConnectionAcknowledgedCommandHandler.cpp | 14 +++++++++++++- .../ConnectionAcknowledgedCommandHandler.hpp | 11 +++++++++-- src/profiling/CounterDirectory.cpp | 2 +- src/profiling/PeriodicCounterCapture.cpp | 2 +- src/profiling/PeriodicCounterCapture.hpp | 4 ++-- .../PeriodicCounterSelectionCommandHandler.cpp | 5 ++--- .../PeriodicCounterSelectionCommandHandler.hpp | 12 ++++++------ src/profiling/ProfilingService.hpp | 3 ++- tests/profiling/gatordmock/GatordMockService.hpp | 8 ++++---- .../profiling/gatordmock/tests/GatordMockTests.cpp | 7 +++++-- 12 files changed, 60 insertions(+), 43 deletions(-) diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp index 32c7c39f8a..483eea7165 100644 --- a/src/armnn/Runtime.cpp +++ b/src/armnn/Runtime.cpp @@ -178,9 +178,6 @@ Runtime::Runtime(const CreationOptions& options) throw RuntimeException("It is not possible to enable timeline reporting without profiling being enabled"); } - // pass configuration info to the profiling service - m_ProfilingService.ConfigureProfilingService(options.m_ProfilingOptions); - // Load any available/compatible dynamic backend before the runtime // goes through the backend registry LoadDynamicBackends(options.m_DynamicBackendsPath); @@ -213,24 +210,19 @@ Runtime::Runtime(const CreationOptions& options) // Backends that don't support profiling will return a null profiling context. if (profilingContext) { - // Enable profiling on the backend and assert that it returns true - if(profilingContext->EnableProfiling(true)) - { - // Pass the context onto the profiling service. - m_ProfilingService.AddBackendProfilingContext(id, profilingContext); - } - else - { - throw BackendProfilingException("Unable to enable profiling on Backend Id: " + id.Get()); - } + // Pass the context onto the profiling service. + m_ProfilingService.AddBackendProfilingContext(id, profilingContext); } } catch (const BackendUnavailableException&) { // Ignore backends which are unavailable } - } + + // pass configuration info to the profiling service + m_ProfilingService.ConfigureProfilingService(options.m_ProfilingOptions); + m_DeviceSpec.AddSupportedBackends(supportedBackends); } @@ -273,7 +265,6 @@ Runtime::~Runtime() } } - // Clear all dynamic backends. DynamicBackendUtils::DeregisterDynamicBackends(m_DeviceSpec.GetDynamicBackends()); m_DeviceSpec.ClearDynamicBackends(); diff --git a/src/backends/backendsCommon/test/MockBackend.hpp b/src/backends/backendsCommon/test/MockBackend.hpp index e1570ff920..d90ad798da 100644 --- a/src/backends/backendsCommon/test/MockBackend.hpp +++ b/src/backends/backendsCommon/test/MockBackend.hpp @@ -45,18 +45,19 @@ public: uint16_t RegisterCounters(uint16_t currentMaxGlobalCounterId) { std::unique_ptr counterRegistrar = - m_BackendProfiling->GetCounterRegistrationInterface(currentMaxGlobalCounterId); + m_BackendProfiling->GetCounterRegistrationInterface(static_cast(currentMaxGlobalCounterId)); std::string categoryName("MockCounters"); counterRegistrar->RegisterCategory(categoryName); - uint16_t nextMaxGlobalCounterId = - counterRegistrar->RegisterCounter(0, categoryName, 0, 0, 1.f, "Mock Counter One", "Some notional counter"); - nextMaxGlobalCounterId = counterRegistrar->RegisterCounter(1, categoryName, 0, 0, 1.f, "Mock Counter Two", + counterRegistrar->RegisterCounter(0, categoryName, 0, 0, 1.f, "Mock Counter One", "Some notional counter"); + + counterRegistrar->RegisterCounter(1, categoryName, 0, 0, 1.f, "Mock Counter Two", "Another notional counter"); std::string units("microseconds"); - nextMaxGlobalCounterId = counterRegistrar->RegisterCounter(2, categoryName, 0, 0, 1.f, "Mock MultiCore Counter", + uint16_t nextMaxGlobalCounterId = + counterRegistrar->RegisterCounter(2, categoryName, 0, 0, 1.f, "Mock MultiCore Counter", "A dummy four core counter", units, 4); return nextMaxGlobalCounterId; } @@ -91,6 +92,9 @@ public: bool EnableProfiling(bool) { + auto sendTimelinePacket = m_BackendProfiling->GetSendTimelinePacket(); + sendTimelinePacket->SendTimelineEntityBinaryPacket(4256); + sendTimelinePacket->Commit(); return true; } diff --git a/src/profiling/ConnectionAcknowledgedCommandHandler.cpp b/src/profiling/ConnectionAcknowledgedCommandHandler.cpp index 0071bfc11e..995562fb3f 100644 --- a/src/profiling/ConnectionAcknowledgedCommandHandler.cpp +++ b/src/profiling/ConnectionAcknowledgedCommandHandler.cpp @@ -41,9 +41,21 @@ void ConnectionAcknowledgedCommandHandler::operator()(const Packet& packet) // Send the counter directory packet. m_SendCounterPacket.SendCounterDirectoryPacket(m_CounterDirectory); m_SendTimelinePacket.SendTimelineMessageDirectoryPackage(); - TimelineUtilityMethods::SendWellKnownLabelsAndEventClasses(m_SendTimelinePacket); + if(m_BackendProfilingContext.has_value()) + { + for (auto backendContext : m_BackendProfilingContext.value()) + { + // Enable profiling on the backend and assert that it returns true + if(!backendContext.second->EnableProfiling(true)) + { + throw BackendProfilingException( + "Unable to enable profiling on Backend Id: " + backendContext.first.Get()); + } + } + } + break; case ProfilingState::Active: return; // NOP diff --git a/src/profiling/ConnectionAcknowledgedCommandHandler.hpp b/src/profiling/ConnectionAcknowledgedCommandHandler.hpp index 6054306da8..e2bdff8e96 100644 --- a/src/profiling/ConnectionAcknowledgedCommandHandler.hpp +++ b/src/profiling/ConnectionAcknowledgedCommandHandler.hpp @@ -5,11 +5,13 @@ #pragma once +#include #include "CommandHandlerFunctor.hpp" #include "ISendCounterPacket.hpp" #include "armnn/profiling/ISendTimelinePacket.hpp" #include "Packet.hpp" #include "ProfilingStateMachine.hpp" +#include namespace armnn { @@ -20,6 +22,9 @@ namespace profiling class ConnectionAcknowledgedCommandHandler final : public CommandHandlerFunctor { +typedef const std::unordered_map>& + BackendProfilingContexts; + public: ConnectionAcknowledgedCommandHandler(uint32_t familyId, uint32_t packetId, @@ -27,12 +32,14 @@ public: ICounterDirectory& counterDirectory, ISendCounterPacket& sendCounterPacket, ISendTimelinePacket& sendTimelinePacket, - ProfilingStateMachine& profilingStateMachine) + ProfilingStateMachine& profilingStateMachine, + Optional backendProfilingContexts = EmptyOptional()) : CommandHandlerFunctor(familyId, packetId, version) , m_CounterDirectory(counterDirectory) , m_SendCounterPacket(sendCounterPacket) , m_SendTimelinePacket(sendTimelinePacket) , m_StateMachine(profilingStateMachine) + , m_BackendProfilingContext(backendProfilingContexts) {} void operator()(const Packet& packet) override; @@ -42,7 +49,7 @@ private: ISendCounterPacket& m_SendCounterPacket; ISendTimelinePacket& m_SendTimelinePacket; ProfilingStateMachine& m_StateMachine; - + Optional m_BackendProfilingContext; }; } // namespace profiling diff --git a/src/profiling/CounterDirectory.cpp b/src/profiling/CounterDirectory.cpp index 415a66072f..ae1c49796c 100644 --- a/src/profiling/CounterDirectory.cpp +++ b/src/profiling/CounterDirectory.cpp @@ -498,7 +498,7 @@ CountersIt CounterDirectory::FindCounter(const std::string& counterName) const return std::find_if(m_Counters.begin(), m_Counters.end(), [&counterName](const auto& pair) { ARMNN_ASSERT(pair.second); - ARMNN_ASSERT(pair.second->m_Uid == pair.first); + ARMNN_ASSERT(pair.first >= pair.second->m_Uid && pair.first <= pair.second->m_MaxCounterUid); return pair.second->m_Name == counterName; }); diff --git a/src/profiling/PeriodicCounterCapture.cpp b/src/profiling/PeriodicCounterCapture.cpp index b143295bc1..4ad1d113b6 100644 --- a/src/profiling/PeriodicCounterCapture.cpp +++ b/src/profiling/PeriodicCounterCapture.cpp @@ -125,7 +125,7 @@ void PeriodicCounterCapture::Capture(const IReadCounterValues& readCounterValues for_each(activeBackends.begin(), activeBackends.end(), [&](const armnn::BackendId& backendId) { DispatchPeriodicCounterCapturePacket( - backendId, m_BackendProfilingContext.at(backendId)->ReportCounterValues()); + backendId, m_BackendProfilingContexts.at(backendId)->ReportCounterValues()); }); // Wait the indicated capture period (microseconds) diff --git a/src/profiling/PeriodicCounterCapture.hpp b/src/profiling/PeriodicCounterCapture.hpp index ff0562377c..51ac273860 100644 --- a/src/profiling/PeriodicCounterCapture.hpp +++ b/src/profiling/PeriodicCounterCapture.hpp @@ -39,7 +39,7 @@ public: , m_ReadCounterValues(readCounterValue) , m_SendCounterPacket(packet) , m_CounterIdMap(counterIdMap) - , m_BackendProfilingContext(backendProfilingContexts) + , m_BackendProfilingContexts(backendProfilingContexts) {} ~PeriodicCounterCapture() { Stop(); } @@ -61,7 +61,7 @@ private: ISendCounterPacket& m_SendCounterPacket; const ICounterMappings& m_CounterIdMap; const std::unordered_map>& m_BackendProfilingContext; + std::shared_ptr>& m_BackendProfilingContexts; }; } // namespace profiling diff --git a/src/profiling/PeriodicCounterSelectionCommandHandler.cpp b/src/profiling/PeriodicCounterSelectionCommandHandler.cpp index d218433d93..4e3e6e554b 100644 --- a/src/profiling/PeriodicCounterSelectionCommandHandler.cpp +++ b/src/profiling/PeriodicCounterSelectionCommandHandler.cpp @@ -140,7 +140,6 @@ void PeriodicCounterSelectionCommandHandler::operator()(const Packet& packet) // save the new backend counter ids for next time m_PrevBackendCounterIds = backendCounterIds; - // Set the capture data with only the valid armnn counter UIDs m_CaptureDataHolder.SetCaptureData(capturePeriod, {validCounterIds.begin(), backendIdStart}, activeBackends); @@ -168,8 +167,8 @@ void PeriodicCounterSelectionCommandHandler::operator()(const Packet& packet) std::set PeriodicCounterSelectionCommandHandler::ProcessBackendCounterIds( const u_int32_t capturePeriod, - std::set newCounterIds, - std::set unusedCounterIds) + const std::set newCounterIds, + const std::set unusedCounterIds) { std::set changedBackends; std::set activeBackends = m_CaptureDataHolder.GetCaptureData().GetActiveBackends(); diff --git a/src/profiling/PeriodicCounterSelectionCommandHandler.hpp b/src/profiling/PeriodicCounterSelectionCommandHandler.hpp index 437d7128be..b59d84cffa 100644 --- a/src/profiling/PeriodicCounterSelectionCommandHandler.hpp +++ b/src/profiling/PeriodicCounterSelectionCommandHandler.hpp @@ -37,7 +37,7 @@ public: uint32_t version, const std::unordered_map>& - backendProfilingContext, + backendProfilingContexts, const ICounterMappings& counterIdMap, Holder& captureDataHolder, const uint16_t maxArmnnCounterId, @@ -46,7 +46,7 @@ public: ISendCounterPacket& sendCounterPacket, const ProfilingStateMachine& profilingStateMachine) : CommandHandlerFunctor(familyId, packetId, version) - , m_BackendProfilingContext(backendProfilingContext) + , m_BackendProfilingContexts(backendProfilingContexts) , m_CounterIdMap(counterIdMap) , m_CaptureDataHolder(captureDataHolder) , m_MaxArmCounterId(maxArmnnCounterId) @@ -66,7 +66,7 @@ private: std::unordered_map> m_BackendCounterMap; const std::unordered_map>& m_BackendProfilingContext; + std::shared_ptr>& m_BackendProfilingContexts; const ICounterMappings& m_CounterIdMap; Holder& m_CaptureDataHolder; const uint16_t m_MaxArmCounterId; @@ -82,7 +82,7 @@ private: const std::vector counterIds) { Optional errorMsg = - m_BackendProfilingContext.at(backendId)->ActivateCounters(capturePeriod, counterIds); + m_BackendProfilingContexts.at(backendId)->ActivateCounters(capturePeriod, counterIds); if(errorMsg.has_value()) { @@ -92,8 +92,8 @@ private: } void ParseData(const Packet& packet, CaptureData& captureData); std::set ProcessBackendCounterIds(const u_int32_t capturePeriod, - std::set newCounterIds, - std::set unusedCounterIds); + const std::set newCounterIds, + const std::set unusedCounterIds); }; diff --git a/src/profiling/ProfilingService.hpp b/src/profiling/ProfilingService.hpp index a6c5e29767..f3d10e711b 100644 --- a/src/profiling/ProfilingService.hpp +++ b/src/profiling/ProfilingService.hpp @@ -80,7 +80,8 @@ public: m_CounterDirectory, m_SendCounterPacket, m_SendTimelinePacket, - m_StateMachine) + m_StateMachine, + m_BackendProfilingContexts) , m_RequestCounterDirectoryCommandHandler(0, 3, m_PacketVersionResolver.ResolvePacketVersion(0, 3).GetEncodedValue(), diff --git a/tests/profiling/gatordmock/GatordMockService.hpp b/tests/profiling/gatordmock/GatordMockService.hpp index 2ff93c9de6..9b72f72feb 100644 --- a/tests/profiling/gatordmock/GatordMockService.hpp +++ b/tests/profiling/gatordmock/GatordMockService.hpp @@ -57,16 +57,16 @@ public: , m_HandlerRegistry() , m_TimelineDecoder() , m_StreamMetadataCommandHandler( - 0, 0, m_PacketVersionResolver.ResolvePacketVersion(0, 0).GetEncodedValue(), true) + 0, 0, m_PacketVersionResolver.ResolvePacketVersion(0, 0).GetEncodedValue(), !echoPackets) , m_CounterCaptureCommandHandler( - 0, 4, m_PacketVersionResolver.ResolvePacketVersion(0, 4).GetEncodedValue(), true) + 0, 4, m_PacketVersionResolver.ResolvePacketVersion(0, 4).GetEncodedValue(), !echoPackets) , m_DirectoryCaptureCommandHandler( - 0, 2, m_PacketVersionResolver.ResolvePacketVersion(0, 2).GetEncodedValue(), true) + 0, 2, m_PacketVersionResolver.ResolvePacketVersion(0, 2).GetEncodedValue(), !echoPackets) , m_TimelineCaptureCommandHandler( 1, 1, m_PacketVersionResolver.ResolvePacketVersion(1, 1).GetEncodedValue(), m_TimelineDecoder) , m_TimelineDirectoryCaptureCommandHandler( 1, 0, m_PacketVersionResolver.ResolvePacketVersion(1, 0).GetEncodedValue(), - m_TimelineCaptureCommandHandler, true) + m_TimelineCaptureCommandHandler, !echoPackets) { m_TimelineDecoder.SetDefaultCallbacks(); diff --git a/tests/profiling/gatordmock/tests/GatordMockTests.cpp b/tests/profiling/gatordmock/tests/GatordMockTests.cpp index f8b42df674..11a96fdd7d 100644 --- a/tests/profiling/gatordmock/tests/GatordMockTests.cpp +++ b/tests/profiling/gatordmock/tests/GatordMockTests.cpp @@ -443,8 +443,11 @@ BOOST_AUTO_TEST_CASE(GatorDMockTimeLineActivation) WaitFor([&](){return timelineDecoder.GetModel().m_EventClasses.size() >= 2;}, "MockGatord did not receive well known timeline labels"); + WaitFor([&](){return timelineDecoder.GetModel().m_Entities.size() >= 1;}, + "MockGatord did not receive mock backend test entity"); + // Packets we expect from SendWellKnownLabelsAndEventClassesTest - BOOST_CHECK(timelineDecoder.GetModel().m_Entities.size() == 0); + BOOST_CHECK(timelineDecoder.GetModel().m_Entities.size() == 1); BOOST_CHECK(timelineDecoder.GetModel().m_EventClasses.size() == 2); BOOST_CHECK(timelineDecoder.GetModel().m_Labels.size() == 10); BOOST_CHECK(timelineDecoder.GetModel().m_Relationships.size() == 0); @@ -471,7 +474,7 @@ BOOST_AUTO_TEST_CASE(GatorDMockTimeLineActivation) "MockGatord did not receive well known timeline labels"); // Packets we expect from SendWellKnownLabelsAndEventClassesTest * 2 and the loaded model - BOOST_CHECK(timelineDecoder.GetModel().m_Entities.size() == 5); + BOOST_CHECK(timelineDecoder.GetModel().m_Entities.size() == 6); BOOST_CHECK(timelineDecoder.GetModel().m_EventClasses.size() == 4); BOOST_CHECK(timelineDecoder.GetModel().m_Labels.size() == 24); BOOST_CHECK(timelineDecoder.GetModel().m_Relationships.size() == 28); -- cgit v1.2.1