diff options
Diffstat (limited to 'src/profiling/test')
-rw-r--r-- | src/profiling/test/ProfilingMocks.hpp | 15 | ||||
-rw-r--r-- | src/profiling/test/ProfilingTestUtils.cpp | 8 | ||||
-rw-r--r-- | src/profiling/test/ProfilingTestUtils.hpp | 6 | ||||
-rw-r--r-- | src/profiling/test/ProfilingTests.cpp | 130 | ||||
-rw-r--r-- | src/profiling/test/ProfilingTests.hpp | 2 | ||||
-rw-r--r-- | src/profiling/test/SendCounterPacketTests.cpp | 15 | ||||
-rw-r--r-- | src/profiling/test/SendTimelinePacketTests.cpp | 11 |
7 files changed, 166 insertions, 21 deletions
diff --git a/src/profiling/test/ProfilingMocks.hpp b/src/profiling/test/ProfilingMocks.hpp index eeb641e878..ada55d8dff 100644 --- a/src/profiling/test/ProfilingMocks.hpp +++ b/src/profiling/test/ProfilingMocks.hpp @@ -51,6 +51,8 @@ public: PerJobCounterSelection, TimelineMessageDirectory, PeriodicCounterCapture, + ActivateTimelineReporting, + DeactivateTimelineReporting, Unknown }; @@ -85,7 +87,7 @@ public: switch (packetFamily) { case 0: - packetType = packetId < 6 ? PacketType(packetId) : PacketType::Unknown; + packetType = packetId < 8 ? PacketType(packetId) : PacketType::Unknown; break; case 1: packetType = packetId == 0 ? PacketType::TimelineMessageDirectory : PacketType::Unknown; @@ -628,7 +630,8 @@ public: const CaptureData& captureData) : m_SendCounterPacket(mockBufferManager), m_IsProfilingEnabled(isProfilingEnabled), - m_CaptureData(captureData) {} + m_CaptureData(captureData) + {} /// Return the next random Guid in the sequence ProfilingDynamicGuid NextGuid() override @@ -682,10 +685,10 @@ public: private: ProfilingGuidGenerator m_GuidGenerator; - CounterIdMap m_CounterMapping; - SendCounterPacket m_SendCounterPacket; - bool m_IsProfilingEnabled; - CaptureData m_CaptureData; + CounterIdMap m_CounterMapping; + SendCounterPacket m_SendCounterPacket; + bool m_IsProfilingEnabled; + CaptureData m_CaptureData; }; } // namespace profiling diff --git a/src/profiling/test/ProfilingTestUtils.cpp b/src/profiling/test/ProfilingTestUtils.cpp index 244051c785..8de69f14ec 100644 --- a/src/profiling/test/ProfilingTestUtils.cpp +++ b/src/profiling/test/ProfilingTestUtils.cpp @@ -297,7 +297,14 @@ void VerifyPostOptimisationStructureTestImpl(armnn::BackendId backendId) // Create runtime in which test will run armnn::IRuntime::CreationOptions options; options.m_ProfilingOptions.m_EnableProfiling = true; + options.m_ProfilingOptions.m_TimelineEnabled = true; armnn::Runtime runtime(options); + GetProfilingService(&runtime).ResetExternalProfilingOptions(options.m_ProfilingOptions, false); + + profiling::ProfilingServiceRuntimeHelper profilingServiceHelper(GetProfilingService(&runtime)); + profilingServiceHelper.ForceTransitionToState(ProfilingState::NotConnected); + profilingServiceHelper.ForceTransitionToState(ProfilingState::WaitingForAck); + profilingServiceHelper.ForceTransitionToState(ProfilingState::Active); // build up the structure of the network INetworkPtr net(INetwork::Create()); @@ -363,7 +370,6 @@ void VerifyPostOptimisationStructureTestImpl(armnn::BackendId backendId) armnn::NetworkId netId; BOOST_TEST(runtime.LoadNetwork(netId, std::move(optNet)) == Status::Success); - profiling::ProfilingServiceRuntimeHelper profilingServiceHelper(GetProfilingService(&runtime)); profiling::BufferManager& bufferManager = profilingServiceHelper.GetProfilingBufferManager(); auto readableBuffer = bufferManager.GetReadableBuffer(); diff --git a/src/profiling/test/ProfilingTestUtils.hpp b/src/profiling/test/ProfilingTestUtils.hpp index 459d62435b..816ffd3dc6 100644 --- a/src/profiling/test/ProfilingTestUtils.hpp +++ b/src/profiling/test/ProfilingTestUtils.hpp @@ -6,6 +6,7 @@ #pragma once #include "ProfilingUtils.hpp" +#include "Runtime.hpp" #include <armnn/BackendId.hpp> #include <armnn/Optional.hpp> @@ -68,6 +69,11 @@ public: return GetBufferManager(m_ProfilingService); } armnn::profiling::ProfilingService& m_ProfilingService; + + void ForceTransitionToState(ProfilingState newState) + { + TransitionToState(m_ProfilingService, newState); + } }; } // namespace profiling diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index 74b93d7c4f..f252579022 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -4,6 +4,7 @@ // #include "ProfilingTests.hpp" +#include "ProfilingTestUtils.hpp" #include <backends/BackendProfiling.hpp> #include <CommandHandler.hpp> @@ -1823,6 +1824,132 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData) BOOST_TEST(period == armnn::LOWEST_CAPTURE_PERIOD); // capture period } +BOOST_AUTO_TEST_CASE(CheckTimelineActivationAndDeactivation) +{ + class TestReportStructure : public IReportStructure + { + public: + virtual void ReportStructure() override + { + m_ReportStructureCalled = true; + } + + bool m_ReportStructureCalled = false; + }; + + class TestNotifyBackends : public INotifyBackends + { + public: + TestNotifyBackends() : m_timelineReporting(false) {} + virtual void NotifyBackendsForTimelineReporting() override + { + m_TestNotifyBackendsCalled = m_timelineReporting.load(); + } + + bool m_TestNotifyBackendsCalled = false; + std::atomic<bool> m_timelineReporting; + }; + + PacketVersionResolver packetVersionResolver; + + BufferManager bufferManager(512); + SendTimelinePacket sendTimelinePacket(bufferManager); + ProfilingStateMachine stateMachine; + TestReportStructure testReportStructure; + TestNotifyBackends testNotifyBackends; + + profiling::ActivateTimelineReportingCommandHandler activateTimelineReportingCommandHandler(0, + 6, + packetVersionResolver.ResolvePacketVersion(0, 6) + .GetEncodedValue(), + sendTimelinePacket, + stateMachine, + testReportStructure, + testNotifyBackends.m_timelineReporting, + testNotifyBackends); + + // Write an "ActivateTimelineReporting" packet into the mock profiling connection, to simulate an input from an + // external profiling service + const uint32_t packetFamily1 = 0; + const uint32_t packetId1 = 6; + uint32_t packetHeader1 = ConstructHeader(packetFamily1, packetId1); + + // Create the ActivateTimelineReportingPacket + Packet ActivateTimelineReportingPacket(packetHeader1); // Length == 0 + + BOOST_CHECK_THROW( + activateTimelineReportingCommandHandler.operator()(ActivateTimelineReportingPacket), armnn::Exception); + + stateMachine.TransitionToState(ProfilingState::NotConnected); + BOOST_CHECK_THROW( + activateTimelineReportingCommandHandler.operator()(ActivateTimelineReportingPacket), armnn::Exception); + + stateMachine.TransitionToState(ProfilingState::WaitingForAck); + BOOST_CHECK_THROW( + activateTimelineReportingCommandHandler.operator()(ActivateTimelineReportingPacket), armnn::Exception); + + stateMachine.TransitionToState(ProfilingState::Active); + activateTimelineReportingCommandHandler.operator()(ActivateTimelineReportingPacket); + + BOOST_CHECK(testReportStructure.m_ReportStructureCalled); + BOOST_CHECK(testNotifyBackends.m_TestNotifyBackendsCalled); + BOOST_CHECK(testNotifyBackends.m_timelineReporting.load()); + + DeactivateTimelineReportingCommandHandler deactivateTimelineReportingCommandHandler(0, + 7, + packetVersionResolver.ResolvePacketVersion(0, 7).GetEncodedValue(), + testNotifyBackends.m_timelineReporting, + stateMachine, + testNotifyBackends); + + const uint32_t packetFamily2 = 0; + const uint32_t packetId2 = 7; + uint32_t packetHeader2 = ConstructHeader(packetFamily2, packetId2); + + // Create the DeactivateTimelineReportingPacket + Packet deactivateTimelineReportingPacket(packetHeader2); // Length == 0 + + stateMachine.Reset(); + BOOST_CHECK_THROW( + deactivateTimelineReportingCommandHandler.operator()(deactivateTimelineReportingPacket), armnn::Exception); + + stateMachine.TransitionToState(ProfilingState::NotConnected); + BOOST_CHECK_THROW( + deactivateTimelineReportingCommandHandler.operator()(deactivateTimelineReportingPacket), armnn::Exception); + + stateMachine.TransitionToState(ProfilingState::WaitingForAck); + BOOST_CHECK_THROW( + deactivateTimelineReportingCommandHandler.operator()(deactivateTimelineReportingPacket), armnn::Exception); + + stateMachine.TransitionToState(ProfilingState::Active); + deactivateTimelineReportingCommandHandler.operator()(deactivateTimelineReportingPacket); + + BOOST_CHECK(!testNotifyBackends.m_TestNotifyBackendsCalled); + BOOST_CHECK(!testNotifyBackends.m_timelineReporting.load()); +} + +BOOST_AUTO_TEST_CASE(CheckProfilingServiceNotActive) +{ + using namespace armnn; + using namespace armnn::profiling; + + // Create runtime in which the test will run + armnn::IRuntime::CreationOptions options; + options.m_ProfilingOptions.m_EnableProfiling = true; + + armnn::Runtime runtime(options); + profiling::ProfilingServiceRuntimeHelper profilingServiceHelper(GetProfilingService(&runtime)); + profilingServiceHelper.ForceTransitionToState(ProfilingState::NotConnected); + profilingServiceHelper.ForceTransitionToState(ProfilingState::WaitingForAck); + profilingServiceHelper.ForceTransitionToState(ProfilingState::Active); + + profiling::BufferManager& bufferManager = profilingServiceHelper.GetProfilingBufferManager(); + auto readableBuffer = bufferManager.GetReadableBuffer(); + + // Profiling is enabled, the post-optimisation structure should be created + BOOST_CHECK(readableBuffer == nullptr); +} + BOOST_AUTO_TEST_CASE(CheckConnectionAcknowledged) { using boost::numeric_cast; @@ -3395,8 +3522,7 @@ BOOST_AUTO_TEST_CASE(CheckRegisterCounters) MockBufferManager mockBuffer(1024); CaptureData captureData; - MockProfilingService mockProfilingService( - mockBuffer, options.m_ProfilingOptions.m_EnableProfiling, captureData); + MockProfilingService mockProfilingService(mockBuffer, options.m_ProfilingOptions.m_EnableProfiling, captureData); armnn::BackendId cpuRefId(armnn::Compute::CpuRef); mockProfilingService.RegisterMapping(6, 0, cpuRefId); diff --git a/src/profiling/test/ProfilingTests.hpp b/src/profiling/test/ProfilingTests.hpp index a5971e0a4b..d1052cea97 100644 --- a/src/profiling/test/ProfilingTests.hpp +++ b/src/profiling/test/ProfilingTests.hpp @@ -242,7 +242,7 @@ public: uint32_t length = 0, uint32_t timeout = 1000) { - long packetCount = mockProfilingConnection->CheckForPacket({packetType, length}); + long packetCount = mockProfilingConnection->CheckForPacket({ packetType, length }); // The first packet we receive may not be the one we are looking for, so keep looping until till we find it, // or until WaitForPacketsSent times out while(packetCount == 0 && timeout != 0) diff --git a/src/profiling/test/SendCounterPacketTests.cpp b/src/profiling/test/SendCounterPacketTests.cpp index d7dc7e2d9e..51f049ddc6 100644 --- a/src/profiling/test/SendCounterPacketTests.cpp +++ b/src/profiling/test/SendCounterPacketTests.cpp @@ -1172,7 +1172,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) BOOST_CHECK(counterDirectory.GetCategoryCount() == 2); BOOST_CHECK(category2); - uint16_t numberOfCores = 3; + uint16_t numberOfCores = 4; // Register a counter associated to "category1" const Counter* counter1 = nullptr; @@ -1186,7 +1186,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) "counter1description", std::string("counter1units"), numberOfCores)); - BOOST_CHECK(counterDirectory.GetCounterCount() == 3); + BOOST_CHECK(counterDirectory.GetCounterCount() == 4); BOOST_CHECK(counter1); // Register a counter associated to "category1" @@ -1203,7 +1203,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) armnn::EmptyOptional(), device2->m_Uid, 0)); - BOOST_CHECK(counterDirectory.GetCounterCount() == 4); + BOOST_CHECK(counterDirectory.GetCounterCount() == 5); BOOST_CHECK(counter2); // Register a counter associated to "category2" @@ -1217,7 +1217,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) "counter3", "counter3description", armnn::EmptyOptional(), - 5, + numberOfCores, device2->m_Uid, counterSet1->m_Uid)); BOOST_CHECK(counterDirectory.GetCounterCount() == 9); @@ -1236,7 +1236,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) uint32_t packetHeaderWord1 = ReadUint32(readBuffer, 4); BOOST_TEST(((packetHeaderWord0 >> 26) & 0x3F) == 0); // packet_family BOOST_TEST(((packetHeaderWord0 >> 16) & 0x3FF) == 2); // packet_id - BOOST_TEST(packetHeaderWord1 == 928); // data_length + BOOST_TEST(packetHeaderWord1 == 432); // data_length // Check the body header uint32_t bodyHeaderWord0 = ReadUint32(readBuffer, 8); @@ -1269,7 +1269,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) uint32_t categoryRecordOffset0 = ReadUint32(readBuffer, 44); uint32_t categoryRecordOffset1 = ReadUint32(readBuffer, 48); BOOST_TEST(categoryRecordOffset0 == 64); // Category record offset for "category1" - BOOST_TEST(categoryRecordOffset1 == 472); // Category record offset for "category2" + BOOST_TEST(categoryRecordOffset1 == 168); // Category record offset for "category2" // Get the device record pool offset uint32_t uint32_t_size = sizeof(uint32_t); @@ -1584,7 +1584,8 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) const Category* category = counterDirectory.GetCategory(categoryRecord.name); BOOST_CHECK(category); BOOST_CHECK(category->m_Name == categoryRecord.name); - BOOST_CHECK(category->m_Counters.size() == categoryRecord.event_count); + BOOST_CHECK(category->m_Counters.size() == categoryRecord.event_count + static_cast<size_t>(numberOfCores) -1); + BOOST_CHECK(category->m_Counters.size() == categoryRecord.event_count + static_cast<size_t>(numberOfCores) -1); // Check that the event records are correct for (const EventRecord& eventRecord : categoryRecord.event_records) diff --git a/src/profiling/test/SendTimelinePacketTests.cpp b/src/profiling/test/SendTimelinePacketTests.cpp index 98b161f65e..4a13ebf824 100644 --- a/src/profiling/test/SendTimelinePacketTests.cpp +++ b/src/profiling/test/SendTimelinePacketTests.cpp @@ -15,6 +15,7 @@ #include <LabelsAndEventClasses.hpp> #include <functional> +#include <Runtime.hpp> using namespace armnn::profiling; @@ -395,10 +396,12 @@ BOOST_AUTO_TEST_CASE(SendTimelinePacketTests3) BOOST_AUTO_TEST_CASE(GetGuidsFromProfilingService) { - armnn::IRuntime::CreationOptions::ExternalProfilingOptions options; - options.m_EnableProfiling = true; - armnn::profiling::ProfilingService profilingService; - profilingService.ResetExternalProfilingOptions(options, true); + armnn::IRuntime::CreationOptions options; + options.m_ProfilingOptions.m_EnableProfiling = true; + armnn::Runtime runtime(options); + armnn::profiling::ProfilingService profilingService(runtime); + + profilingService.ResetExternalProfilingOptions(options.m_ProfilingOptions, true); ProfilingStaticGuid staticGuid = profilingService.GetStaticId("dummy"); std::hash<std::string> hasher; uint64_t hash = static_cast<uint64_t>(hasher("dummy")); |