From e848538efbdf01aa0b067da942c3c214f8e62826 Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Thu, 10 Oct 2019 14:08:21 +0100 Subject: IVGCVSW-3964 Implement the Periodic Counter Selection command handler * Improved the PeriodicCounterPacket class to handle errors properly * Improved the PeriodicCounterSelectionCommandHandler to handle invalid counter UIDs in the selection packet * Added the Periodic Counter Selection command handler to the ProfilingService class * Code refactoring and added comments * Added WaitForPacketSent method to the SendCounterPacket class to allow waiting for the packets to be sent (useful in the unit tests) * Added unit tests and updated the old ones accordingly * Fixed threading issues with a number of unit tests Signed-off-by: Matteo Martincigh Change-Id: I271b7b0bfa801d88fe1725b934d24e30cd839ed7 --- src/profiling/test/ProfilingTests.cpp | 584 ++++++++++++++++++++++++-- src/profiling/test/ProfilingTests.hpp | 16 +- src/profiling/test/SendCounterPacketTests.hpp | 16 +- 3 files changed, 564 insertions(+), 52 deletions(-) (limited to 'src/profiling/test') diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index 27bacf7145..554b7e1936 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -35,6 +35,7 @@ #include #include #include +#include using namespace armnn::profiling; @@ -1691,11 +1692,19 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData) void Stop() override {} }; + class TestReadCounterValues : public IReadCounterValues + { + bool IsCounterRegistered(uint16_t counterUid) const override { return true; } + uint16_t GetCounterCount() const override { return 0; } + uint32_t GetCounterValue(uint16_t counterUid) const override { return 0; } + }; + const uint32_t packetId = 0x40000; uint32_t version = 1; Holder holder; TestCaptureThread captureThread; + TestReadCounterValues readCounterValues; MockBufferManager mockBuffer(512); SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); @@ -1718,16 +1727,29 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData) Packet packetA(packetId, dataLength1, uniqueData1); - PeriodicCounterSelectionCommandHandler commandHandler(packetId, version, holder, captureThread, - sendCounterPacket); - commandHandler(packetA); + PeriodicCounterSelectionCommandHandler commandHandler(packetId, + version, + holder, + captureThread, + readCounterValues, + sendCounterPacket, + profilingStateMachine); - std::vector counterIds = holder.GetCaptureData().GetCounterIds(); + profilingStateMachine.TransitionToState(ProfilingState::Uninitialised); + BOOST_CHECK_THROW(commandHandler(packetA), armnn::RuntimeException); + profilingStateMachine.TransitionToState(ProfilingState::NotConnected); + BOOST_CHECK_THROW(commandHandler(packetA), armnn::RuntimeException); + profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck); + BOOST_CHECK_THROW(commandHandler(packetA), armnn::RuntimeException); + profilingStateMachine.TransitionToState(ProfilingState::Active); + BOOST_CHECK_NO_THROW(commandHandler(packetA)); + + const std::vector counterIdsA = holder.GetCaptureData().GetCounterIds(); BOOST_TEST(holder.GetCaptureData().GetCapturePeriod() == period1); - BOOST_TEST(counterIds.size() == 2); - BOOST_TEST(counterIds[0] == 4000); - BOOST_TEST(counterIds[1] == 5000); + BOOST_TEST(counterIdsA.size() == 2); + BOOST_TEST(counterIdsA[0] == 4000); + BOOST_TEST(counterIdsA[1] == 5000); auto readBuffer = mockBuffer.GetReadableBuffer(); @@ -1766,10 +1788,10 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData) commandHandler(packetB); - counterIds = holder.GetCaptureData().GetCounterIds(); + const std::vector counterIdsB = holder.GetCaptureData().GetCounterIds(); BOOST_TEST(holder.GetCaptureData().GetCapturePeriod() == period2); - BOOST_TEST(counterIds.size() == 0); + BOOST_TEST(counterIdsB.size() == 0); readBuffer = mockBuffer.GetReadableBuffer(); @@ -2024,35 +2046,40 @@ BOOST_AUTO_TEST_CASE(CheckPeriodicCounterCaptureThread) public: CaptureReader() {} + bool IsCounterRegistered(uint16_t counterUid) const override + { + return m_Data.find(counterUid) != m_Data.end(); + } + uint16_t GetCounterCount() const override { return boost::numeric_cast(m_Data.size()); } - uint32_t GetCounterValue(uint16_t index) const override + uint32_t GetCounterValue(uint16_t counterUid) const override { - if (m_Data.find(index) == m_Data.end()) + if (m_Data.find(counterUid) == m_Data.end()) { return 0; } - return m_Data.at(index); + return m_Data.at(counterUid).load(); } - void SetCounterValue(uint16_t index, uint32_t value) + void SetCounterValue(uint16_t counterUid, uint32_t value) { - if (m_Data.find(index) == m_Data.end()) + if (m_Data.find(counterUid) == m_Data.end()) { - m_Data.insert(std::pair(index, value)); + m_Data.insert(std::make_pair(counterUid, value)); } else { - m_Data.at(index) = value; + m_Data.at(counterUid).store(value); } } private: - std::unordered_map m_Data; + std::unordered_map> m_Data; }; ProfilingStateMachine profilingStateMachine; @@ -2261,19 +2288,23 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadConnectionAcknowledgedPacket) // Bring the profiling service to the "WaitingForAck" state BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); - profilingService.Update(); + profilingService.Update(); // Initialize the counter directory BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); - profilingService.Update(); - BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); - profilingService.Update(); - - // Wait for a bit to make sure that we get the packet - std::this_thread::sleep_for(std::chrono::milliseconds(100)); + profilingService.Update();// Create the profiling connection // Get the mock profiling connection MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection(); BOOST_CHECK(mockProfilingConnection); + // Remove the packets received so far + mockProfilingConnection->Clear(); + + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); + profilingService.Update(); + + // Wait for the Stream Metadata packet to be sent + helper.WaitForProfilingPacketsSent(); + // Check that the mock profiling connection contains one Stream Metadata packet const std::vector writtenData = mockProfilingConnection->GetWrittenData(); BOOST_TEST(writtenData.size() == 1); @@ -2330,19 +2361,23 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket) // Bring the profiling service to the "WaitingForAck" state BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); - profilingService.Update(); + profilingService.Update(); // Initialize the counter directory BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); - profilingService.Update(); - BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); - profilingService.Update(); - - // Wait for a bit to make sure that we get the packet - std::this_thread::sleep_for(std::chrono::milliseconds(100)); + profilingService.Update(); // Create the profiling connection // Get the mock profiling connection MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection(); BOOST_CHECK(mockProfilingConnection); + // Remove the packets received so far + mockProfilingConnection->Clear(); + + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); + profilingService.Update(); // Start the command handler and the send thread + + // Wait for the Stream Metadata packet to be sent + helper.WaitForProfilingPacketsSent(); + // Check that the mock profiling connection contains one Stream Metadata packet const std::vector writtenData = mockProfilingConnection->GetWrittenData(); BOOST_TEST(writtenData.size() == 1); @@ -2403,7 +2438,13 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadRequestCounterDirectoryPacket) BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); profilingService.Update(); // Create the profiling connection BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); - profilingService.Update(); // Start the threads + profilingService.Update(); // Start the command handler and the send thread + + // Wait for the Stream Metadata packet the be sent + // (we are not testing the connection acknowledgement here so it will be ignored by this test) + helper.WaitForProfilingPacketsSent(); + + // Force the profiling service to the "Active" state helper.ForceTransitionToState(ProfilingState::Active); BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); @@ -2411,6 +2452,9 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadRequestCounterDirectoryPacket) MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection(); BOOST_CHECK(mockProfilingConnection); + // Remove the packets received so far + mockProfilingConnection->Clear(); + // Write a valid "Request Counter Directory" packet into the mock profiling connection, to simulate a valid // reply from an external profiling service @@ -2437,7 +2481,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadRequestCounterDirectoryPacket) // Check that the expected error has occurred and logged to the standard output BOOST_CHECK(boost::contains(ss.str(), "Functor with requested PacketId=123 and Version=4194304 does not exist")); - // The Connection Acknowledged Command Handler should not have updated the profiling state + // The Request Counter Directory Command Handler should not have updated the profiling state BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); // Reset the profiling service to stop any running thread @@ -2462,7 +2506,13 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodRequestCounterDirectoryPacket) BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); profilingService.Update(); // Create the profiling connection BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); - profilingService.Update(); // Start the threads + profilingService.Update(); // Start the command handler and the send thread + + // Wait for the Stream Metadata packet the be sent + // (we are not testing the connection acknowledgement here so it will be ignored by this test) + helper.WaitForProfilingPacketsSent(); + + // Force the profiling service to the "Active" state helper.ForceTransitionToState(ProfilingState::Active); BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); @@ -2470,6 +2520,9 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodRequestCounterDirectoryPacket) MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection(); BOOST_CHECK(mockProfilingConnection); + // Remove the packets received so far + mockProfilingConnection->Clear(); + // Write a valid "Request Counter Directory" packet into the mock profiling connection, to simulate a valid // reply from an external profiling service @@ -2489,17 +2542,470 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodRequestCounterDirectoryPacket) // Write the packet to the mock profiling connection mockProfilingConnection->WritePacket(std::move(requestCounterDirectoryPacket)); + // Wait for the Counter Directory packet to be sent + helper.WaitForProfilingPacketsSent(); + + // Check that the mock profiling connection contains one Counter Directory packet + const std::vector writtenData = mockProfilingConnection->GetWrittenData(); + BOOST_TEST(writtenData.size() == 1); + BOOST_TEST(writtenData[0] == 416); // The size of the expected Counter Directory packet + + // The Request Counter Directory Command Handler should not have updated the profiling state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); + + // Reset the profiling service to stop any running thread + options.m_EnableProfiling = false; + profilingService.ResetExternalProfilingOptions(options, true); +} + +BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadPeriodicCounterSelectionPacket) +{ + // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output + LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning); + + // Swap the profiling connection factory in the profiling service instance with our mock one + SwapProfilingConnectionFactoryHelper helper; + + // Redirect the standard output to a local stream so that we can parse the warning message + std::stringstream ss; + StreamRedirector streamRedirector(std::cout, ss.rdbuf()); + + // Reset the profiling service to the uninitialized state + armnn::Runtime::CreationOptions::ExternalProfilingOptions options; + options.m_EnableProfiling = true; + ProfilingService& profilingService = ProfilingService::Instance(); + profilingService.ResetExternalProfilingOptions(options, true); + + // Bring the profiling service to the "Active" state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); + profilingService.Update(); // Initialize the counter directory + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); + profilingService.Update(); // Create the profiling connection + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); + profilingService.Update(); // Start the command handler and the send thread + + // Wait for the Stream Metadata packet the be sent + // (we are not testing the connection acknowledgement here so it will be ignored by this test) + helper.WaitForProfilingPacketsSent(); + + // Force the profiling service to the "Active" state + helper.ForceTransitionToState(ProfilingState::Active); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); + + // Get the mock profiling connection + MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection(); + BOOST_CHECK(mockProfilingConnection); + + // Remove the packets received so far + mockProfilingConnection->Clear(); + + // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an + // external profiling service + + // Periodic Counter Selection packet header: + // 26:31 [6] packet_family: Control Packet Family, value 0b000000 + // 16:25 [10] packet_id: Packet identifier, value 0b0000000100 + // 8:15 [8] reserved: Reserved, value 0b00000000 + // 0:7 [8] reserved: Reserved, value 0b00000000 + uint32_t packetFamily = 0; + uint32_t packetId = 999; // Wrong packet id!!! + uint32_t header = ((packetFamily & 0x0000003F) << 26) | + ((packetId & 0x000003FF) << 16); + + // Create the Periodic Counter Selection packet + Packet periodicCounterSelectionPacket(header); // Length == 0, this will disable the collection of counters + + // Write the packet to the mock profiling connection + mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket)); + // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that - // the Create the Request Counter packet gets processed by the profiling service + // the Periodic Counter Selection packet gets processed by the profiling service std::this_thread::sleep_for(std::chrono::seconds(2)); - // The Connection Acknowledged Command Handler should not have updated the profiling state + // Check that the expected error has occurred and logged to the standard output + BOOST_CHECK(boost::contains(ss.str(), "Functor with requested PacketId=999 and Version=4194304 does not exist")); + + // The Periodic Counter Selection Handler should not have updated the profiling state BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); - // Check that the mock profiling connection contains one Counter Directory packet + // Reset the profiling service to stop any running thread + options.m_EnableProfiling = false; + profilingService.ResetExternalProfilingOptions(options, true); +} + +BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadPeriodicCounterSelectionPacketInvalidCounterUid) +{ + // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output + LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning); + + // Swap the profiling connection factory in the profiling service instance with our mock one + SwapProfilingConnectionFactoryHelper helper; + + // Reset the profiling service to the uninitialized state + armnn::Runtime::CreationOptions::ExternalProfilingOptions options; + options.m_EnableProfiling = true; + ProfilingService& profilingService = ProfilingService::Instance(); + profilingService.ResetExternalProfilingOptions(options, true); + + // Bring the profiling service to the "Active" state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); + profilingService.Update(); // Initialize the counter directory + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); + profilingService.Update(); // Create the profiling connection + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); + profilingService.Update(); // Start the command handler and the send thread + + // Wait for the Stream Metadata packet the be sent + // (we are not testing the connection acknowledgement here so it will be ignored by this test) + helper.WaitForProfilingPacketsSent(); + + // Force the profiling service to the "Active" state + helper.ForceTransitionToState(ProfilingState::Active); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); + + // Get the mock profiling connection + MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection(); + BOOST_CHECK(mockProfilingConnection); + + // Remove the packets received so far + mockProfilingConnection->Clear(); + + // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an + // external profiling service + + // Periodic Counter Selection packet header: + // 26:31 [6] packet_family: Control Packet Family, value 0b000000 + // 16:25 [10] packet_id: Packet identifier, value 0b0000000100 + // 8:15 [8] reserved: Reserved, value 0b00000000 + // 0:7 [8] reserved: Reserved, value 0b00000000 + uint32_t packetFamily = 0; + uint32_t packetId = 4; + uint32_t header = ((packetFamily & 0x0000003F) << 26) | + ((packetId & 0x000003FF) << 16); + + uint32_t capturePeriod = 123456; // Some capture period (microseconds) + + // Get the first valid counter UID + const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory(); + const Counters& counters = counterDirectory.GetCounters(); + BOOST_CHECK(counters.size() > 1); + uint16_t counterUidA = counters.begin()->first; // First valid counter UID + uint16_t counterUidB = 9999; // Second invalid counter UID + + uint32_t length = 8; + + auto data = std::make_unique(length); + WriteUint32(data.get(), 0, capturePeriod); + WriteUint16(data.get(), 4, counterUidA); + WriteUint16(data.get(), 6, counterUidB); + + // Create the Periodic Counter Selection packet + Packet periodicCounterSelectionPacket(header, length, data); // Length > 0, this will start the Period Counter + // Capture thread + + // Write the packet to the mock profiling connection + mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket)); + + // Expecting one Periodic Counter Selection packet and at least one Periodic Counter Capture packet + int expectedPackets = 2; + std::vector receivedPackets; + + // Keep waiting until all the expected packets have been received + do + { + helper.WaitForProfilingPacketsSent(); + const std::vector writtenData = mockProfilingConnection->GetWrittenData(); + if (writtenData.empty()) + { + BOOST_ERROR("Packets should be available for reading at this point"); + return; + } + receivedPackets.insert(receivedPackets.end(), writtenData.begin(), writtenData.end()); + expectedPackets -= boost::numeric_cast(writtenData.size()); + } + while (expectedPackets > 0); + BOOST_TEST(!receivedPackets.empty()); + + // The size of the expected Periodic Counter Selection packet + BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 14) != receivedPackets.end())); + // The size of the expected Periodic Counter Capture packet + BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 22) != receivedPackets.end())); + + // The Periodic Counter Selection Handler should not have updated the profiling state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); + + // Reset the profiling service to stop any running thread + options.m_EnableProfiling = false; + profilingService.ResetExternalProfilingOptions(options, true); +} + +BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodPeriodicCounterSelectionPacketNoCounters) +{ + // Swap the profiling connection factory in the profiling service instance with our mock one + SwapProfilingConnectionFactoryHelper helper; + + // Reset the profiling service to the uninitialized state + armnn::Runtime::CreationOptions::ExternalProfilingOptions options; + options.m_EnableProfiling = true; + ProfilingService& profilingService = ProfilingService::Instance(); + profilingService.ResetExternalProfilingOptions(options, true); + + // Bring the profiling service to the "Active" state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); + profilingService.Update(); // Initialize the counter directory + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); + profilingService.Update(); // Create the profiling connection + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); + profilingService.Update(); // Start the command handler and the send thread + + // Wait for the Stream Metadata packet the be sent + // (we are not testing the connection acknowledgement here so it will be ignored by this test) + helper.WaitForProfilingPacketsSent(); + + // Force the profiling service to the "Active" state + helper.ForceTransitionToState(ProfilingState::Active); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); + + // Get the mock profiling connection + MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection(); + BOOST_CHECK(mockProfilingConnection); + + // Remove the packets received so far + mockProfilingConnection->Clear(); + + // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an + // external profiling service + + // Periodic Counter Selection packet header: + // 26:31 [6] packet_family: Control Packet Family, value 0b000000 + // 16:25 [10] packet_id: Packet identifier, value 0b0000000100 + // 8:15 [8] reserved: Reserved, value 0b00000000 + // 0:7 [8] reserved: Reserved, value 0b00000000 + uint32_t packetFamily = 0; + uint32_t packetId = 4; + uint32_t header = ((packetFamily & 0x0000003F) << 26) | + ((packetId & 0x000003FF) << 16); + + // Create the Periodic Counter Selection packet + Packet periodicCounterSelectionPacket(header); // Length == 0, this will disable the collection of counters + + // Write the packet to the mock profiling connection + mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket)); + + // Wait for the Periodic Counter Selection packet to be sent + helper.WaitForProfilingPacketsSent(); + + // The Periodic Counter Selection Handler should not have updated the profiling state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); + + // Check that the mock profiling connection contains one Periodic Counter Selection const std::vector writtenData = mockProfilingConnection->GetWrittenData(); - BOOST_TEST(writtenData.size() == 1); - BOOST_TEST(writtenData[0] == 416); // The size of a valid Counter Directory packet + BOOST_TEST(writtenData.size() == 1); // Only one packet is expected (no Periodic Counter packets) + BOOST_TEST(writtenData[0] == 12); // The size of the expected Periodic Counter Selection (echos the sent one) + + // Reset the profiling service to stop any running thread + options.m_EnableProfiling = false; + profilingService.ResetExternalProfilingOptions(options, true); +} + +BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodPeriodicCounterSelectionPacketSingleCounter) +{ + // Swap the profiling connection factory in the profiling service instance with our mock one + SwapProfilingConnectionFactoryHelper helper; + + // Reset the profiling service to the uninitialized state + armnn::Runtime::CreationOptions::ExternalProfilingOptions options; + options.m_EnableProfiling = true; + ProfilingService& profilingService = ProfilingService::Instance(); + profilingService.ResetExternalProfilingOptions(options, true); + + // Bring the profiling service to the "Active" state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); + profilingService.Update(); // Initialize the counter directory + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); + profilingService.Update(); // Create the profiling connection + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); + profilingService.Update(); // Start the command handler and the send thread + + // Wait for the Stream Metadata packet the be sent + // (we are not testing the connection acknowledgement here so it will be ignored by this test) + helper.WaitForProfilingPacketsSent(); + + // Force the profiling service to the "Active" state + helper.ForceTransitionToState(ProfilingState::Active); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); + + // Get the mock profiling connection + MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection(); + BOOST_CHECK(mockProfilingConnection); + + // Remove the packets received so far + mockProfilingConnection->Clear(); + + // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an + // external profiling service + + // Periodic Counter Selection packet header: + // 26:31 [6] packet_family: Control Packet Family, value 0b000000 + // 16:25 [10] packet_id: Packet identifier, value 0b0000000100 + // 8:15 [8] reserved: Reserved, value 0b00000000 + // 0:7 [8] reserved: Reserved, value 0b00000000 + uint32_t packetFamily = 0; + uint32_t packetId = 4; + uint32_t header = ((packetFamily & 0x0000003F) << 26) | + ((packetId & 0x000003FF) << 16); + + uint32_t capturePeriod = 123456; // Some capture period (microseconds) + + // Get the first valid counter UID + const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory(); + const Counters& counters = counterDirectory.GetCounters(); + BOOST_CHECK(!counters.empty()); + uint16_t counterUid = counters.begin()->first; // Valid counter UID + + uint32_t length = 6; + + auto data = std::make_unique(length); + WriteUint32(data.get(), 0, capturePeriod); + WriteUint16(data.get(), 4, counterUid); + + // Create the Periodic Counter Selection packet + Packet periodicCounterSelectionPacket(header, length, data); // Length > 0, this will start the Period Counter + // Capture thread + + // Write the packet to the mock profiling connection + mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket)); + + // Expecting one Periodic Counter Selection packet and at least one Periodic Counter Capture packet + int expectedPackets = 2; + std::vector receivedPackets; + + // Keep waiting until all the expected packets have been received + do + { + helper.WaitForProfilingPacketsSent(); + const std::vector writtenData = mockProfilingConnection->GetWrittenData(); + if (writtenData.empty()) + { + BOOST_ERROR("Packets should be available for reading at this point"); + return; + } + receivedPackets.insert(receivedPackets.end(), writtenData.begin(), writtenData.end()); + expectedPackets -= boost::numeric_cast(writtenData.size()); + } + while (expectedPackets > 0); + BOOST_TEST(!receivedPackets.empty()); + + // The size of the expected Periodic Counter Selection packet (echos the sent one) + BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 14) != receivedPackets.end())); + // The size of the expected Periodic Counter Capture packet + BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 22) != receivedPackets.end())); + + // The Periodic Counter Selection Handler should not have updated the profiling state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); + + // Reset the profiling service to stop any running thread + options.m_EnableProfiling = false; + profilingService.ResetExternalProfilingOptions(options, true); +} + +BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodPeriodicCounterSelectionPacketMultipleCounters) +{ + // Swap the profiling connection factory in the profiling service instance with our mock one + SwapProfilingConnectionFactoryHelper helper; + + // Reset the profiling service to the uninitialized state + armnn::Runtime::CreationOptions::ExternalProfilingOptions options; + options.m_EnableProfiling = true; + ProfilingService& profilingService = ProfilingService::Instance(); + profilingService.ResetExternalProfilingOptions(options, true); + + // Bring the profiling service to the "Active" state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); + profilingService.Update(); // Initialize the counter directory + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); + profilingService.Update(); // Create the profiling connection + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); + profilingService.Update(); // Start the command handler and the send thread + + // Wait for the Stream Metadata packet the be sent + // (we are not testing the connection acknowledgement here so it will be ignored by this test) + helper.WaitForProfilingPacketsSent(); + + // Force the profiling service to the "Active" state + helper.ForceTransitionToState(ProfilingState::Active); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); + + // Get the mock profiling connection + MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection(); + BOOST_CHECK(mockProfilingConnection); + + // Remove the packets received so far + mockProfilingConnection->Clear(); + + // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an + // external profiling service + + // Periodic Counter Selection packet header: + // 26:31 [6] packet_family: Control Packet Family, value 0b000000 + // 16:25 [10] packet_id: Packet identifier, value 0b0000000100 + // 8:15 [8] reserved: Reserved, value 0b00000000 + // 0:7 [8] reserved: Reserved, value 0b00000000 + uint32_t packetFamily = 0; + uint32_t packetId = 4; + uint32_t header = ((packetFamily & 0x0000003F) << 26) | + ((packetId & 0x000003FF) << 16); + + uint32_t capturePeriod = 123456; // Some capture period (microseconds) + + // Get the first valid counter UID + const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory(); + const Counters& counters = counterDirectory.GetCounters(); + BOOST_CHECK(counters.size() > 1); + uint16_t counterUidA = counters.begin()->first; // First valid counter UID + uint16_t counterUidB = (counters.begin()++)->first; // Second valid counter UID + + uint32_t length = 8; + + auto data = std::make_unique(length); + WriteUint32(data.get(), 0, capturePeriod); + WriteUint16(data.get(), 4, counterUidA); + WriteUint16(data.get(), 6, counterUidB); + + // Create the Periodic Counter Selection packet + Packet periodicCounterSelectionPacket(header, length, data); // Length > 0, this will start the Period Counter + // Capture thread + + // Write the packet to the mock profiling connection + mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket)); + + // Expecting one Periodic Counter Selection packet and at least one Periodic Counter Capture packet + int expectedPackets = 2; + std::vector receivedPackets; + + // Keep waiting until all the expected packets have been received + do + { + helper.WaitForProfilingPacketsSent(); + const std::vector writtenData = mockProfilingConnection->GetWrittenData(); + if (writtenData.empty()) + { + BOOST_ERROR("Packets should be available for reading at this point"); + return; + } + receivedPackets.insert(receivedPackets.end(), writtenData.begin(), writtenData.end()); + expectedPackets -= boost::numeric_cast(writtenData.size()); + } + while (expectedPackets > 0); + BOOST_TEST(!receivedPackets.empty()); + + // The size of the expected Periodic Counter Selection packet (echos the sent one) + BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 16) != receivedPackets.end())); + // The size of the expected Periodic Counter Capture packet + BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 28) != receivedPackets.end())); + + // The Periodic Counter Selection Handler should not have updated the profiling state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); // Reset the profiling service to stop any running thread options.m_EnableProfiling = false; diff --git a/src/profiling/test/ProfilingTests.hpp b/src/profiling/test/ProfilingTests.hpp index 4d2f974344..21c98723be 100644 --- a/src/profiling/test/ProfilingTests.hpp +++ b/src/profiling/test/ProfilingTests.hpp @@ -9,14 +9,12 @@ #include #include -#include #include #include #include #include -#include #include namespace armnn @@ -137,15 +135,6 @@ class TestFunctorC : public TestFunctorA using TestFunctorA::TestFunctorA; }; -class MockProfilingConnectionFactory : public IProfilingConnectionFactory -{ -public: - IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override - { - return std::make_unique(); - } -}; - class SwapProfilingConnectionFactoryHelper : public ProfilingService { public: @@ -182,6 +171,11 @@ public: TransitionToState(ProfilingService::Instance(), newState); } + void WaitForProfilingPacketsSent() + { + return WaitForPacketSent(ProfilingService::Instance()); + } + private: MockProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory; IProfilingConnectionFactory* m_BackupProfilingConnectionFactory; diff --git a/src/profiling/test/SendCounterPacketTests.hpp b/src/profiling/test/SendCounterPacketTests.hpp index 871ca74124..73fc39b437 100644 --- a/src/profiling/test/SendCounterPacketTests.hpp +++ b/src/profiling/test/SendCounterPacketTests.hpp @@ -7,6 +7,7 @@ #include #include +#include #include #include @@ -74,11 +75,13 @@ public: return std::move(m_Packet); } - const std::vector GetWrittenData() const + const std::vector GetWrittenData() { std::lock_guard lock(m_Mutex); - return m_WrittenData; + std::vector writtenData = m_WrittenData; + m_WrittenData.clear(); + return writtenData; } void Clear() @@ -95,6 +98,15 @@ private: mutable std::mutex m_Mutex; }; +class MockProfilingConnectionFactory : public IProfilingConnectionFactory +{ +public: + IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override + { + return std::make_unique(); + } +}; + class MockPacketBuffer : public IPacketBuffer { public: -- cgit v1.2.1