diff options
author | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-10-10 14:30:29 +0100 |
---|---|---|
committer | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-10-10 14:30:56 +0100 |
commit | 8efc500a7465c03877db8bbe443134f2b1bbc1af (patch) | |
tree | fcf4f4a56fa5f20ace8a1e5dfc175cecfd2373a7 /src/profiling | |
parent | a3600ba71978225e4d21399fb781d4850f2bb25f (diff) | |
download | armnn-8efc500a7465c03877db8bbe443134f2b1bbc1af.tar.gz |
IVGCVSW-3963 Implement the Request Counter Directory Handler
* Integrated the RequestCounterDirectoryCommandHandler in the
ProfilingService class
* Code refactoring
* Added/Updated unit tests
Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
Change-Id: I60d9f8acf166e29b3dabc921dbdb8149461bd85f
Diffstat (limited to 'src/profiling')
-rw-r--r-- | src/profiling/ProfilingService.hpp | 14 | ||||
-rw-r--r-- | src/profiling/RequestCounterDirectoryCommandHandler.cpp | 36 | ||||
-rw-r--r-- | src/profiling/RequestCounterDirectoryCommandHandler.hpp | 19 | ||||
-rw-r--r-- | src/profiling/test/ProfilingTests.cpp | 205 | ||||
-rw-r--r-- | src/profiling/test/ProfilingTests.hpp | 16 |
5 files changed, 238 insertions, 52 deletions
diff --git a/src/profiling/ProfilingService.hpp b/src/profiling/ProfilingService.hpp index edeb6bde90..0e66924267 100644 --- a/src/profiling/ProfilingService.hpp +++ b/src/profiling/ProfilingService.hpp @@ -13,6 +13,7 @@ #include "BufferManager.hpp" #include "SendCounterPacket.hpp" #include "ConnectionAcknowledgedCommandHandler.hpp" +#include "RequestCounterDirectoryCommandHandler.hpp" namespace armnn { @@ -81,6 +82,7 @@ private: BufferManager m_BufferManager; SendCounterPacket m_SendCounterPacket; ConnectionAcknowledgedCommandHandler m_ConnectionAcknowledgedCommandHandler; + RequestCounterDirectoryCommandHandler m_RequestCounterDirectoryCommandHandler; protected: // Default constructor/destructor kept protected for testing @@ -103,9 +105,17 @@ protected: , m_ConnectionAcknowledgedCommandHandler(1, m_PacketVersionResolver.ResolvePacketVersion(1).GetEncodedValue(), m_StateMachine) + , m_RequestCounterDirectoryCommandHandler(3, + m_PacketVersionResolver.ResolvePacketVersion(3).GetEncodedValue(), + m_CounterDirectory, + m_SendCounterPacket, + m_StateMachine) { // Register the "Connection Acknowledged" command handler m_CommandHandlerRegistry.RegisterFunctor(&m_ConnectionAcknowledgedCommandHandler); + + // Register the "Request Counter Directory" command handler + m_CommandHandlerRegistry.RegisterFunctor(&m_RequestCounterDirectoryCommandHandler); } ~ProfilingService() = default; @@ -124,6 +134,10 @@ protected: { return instance.m_ProfilingConnection.get(); } + void TransitionToState(ProfilingService& instance, ProfilingState newState) + { + instance.m_StateMachine.TransitionToState(newState); + } }; } // namespace profiling diff --git a/src/profiling/RequestCounterDirectoryCommandHandler.cpp b/src/profiling/RequestCounterDirectoryCommandHandler.cpp index 0fdcf10de4..e85acb4215 100644 --- a/src/profiling/RequestCounterDirectoryCommandHandler.cpp +++ b/src/profiling/RequestCounterDirectoryCommandHandler.cpp @@ -5,7 +5,7 @@ #include "RequestCounterDirectoryCommandHandler.hpp" -#include <boost/assert.hpp> +#include <boost/format.hpp> namespace armnn { @@ -15,10 +15,36 @@ namespace profiling void RequestCounterDirectoryCommandHandler::operator()(const Packet& packet) { - BOOST_ASSERT(packet.GetLength() == 0); - - // Write packet to Counter Stream Buffer - m_SendCounterPacket.SendCounterDirectoryPacket(m_CounterDirectory); + ProfilingState currentState = m_StateMachine.GetCurrentState(); + switch (currentState) + { + case ProfilingState::Uninitialised: + case ProfilingState::NotConnected: + case ProfilingState::WaitingForAck: + throw RuntimeException(boost::str(boost::format("Request Counter Directory Handler invoked while in an " + "wrong state: %1%") + % GetProfilingStateName(currentState))); + case ProfilingState::Active: + // Process the packet + if (!(packet.GetPacketFamily() == 0u && packet.GetPacketId() == 3u)) + { + throw armnn::InvalidArgumentException(boost::str(boost::format("Expected Packet family = 0, id = 3 but " + "received family = %1%, id = %2%") + % packet.GetPacketFamily() + % packet.GetPacketId())); + } + + // Write a Counter Directory packet to the Counter Stream Buffer + m_SendCounterPacket.SendCounterDirectoryPacket(m_CounterDirectory); + + // Notify the Send Thread that new data is available in the Counter Stream Buffer + m_SendCounterPacket.SetReadyToRead(); + + break; + default: + throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1%") + % static_cast<int>(currentState))); + } } } // namespace profiling diff --git a/src/profiling/RequestCounterDirectoryCommandHandler.hpp b/src/profiling/RequestCounterDirectoryCommandHandler.hpp index a03300af48..02bf64d17a 100644 --- a/src/profiling/RequestCounterDirectoryCommandHandler.hpp +++ b/src/profiling/RequestCounterDirectoryCommandHandler.hpp @@ -8,6 +8,7 @@ #include "CommandHandlerFunctor.hpp" #include "ISendCounterPacket.hpp" #include "Packet.hpp" +#include "ProfilingStateMachine.hpp" namespace armnn { @@ -19,23 +20,25 @@ class RequestCounterDirectoryCommandHandler : public CommandHandlerFunctor { public: - RequestCounterDirectoryCommandHandler(uint32_t packetId, uint32_t version, + RequestCounterDirectoryCommandHandler(uint32_t packetId, + uint32_t version, ICounterDirectory& counterDirectory, - ISendCounterPacket& sendCounterPacket) - : CommandHandlerFunctor(packetId, version), - m_CounterDirectory(counterDirectory), - m_SendCounterPacket(sendCounterPacket) + ISendCounterPacket& sendCounterPacket, + ProfilingStateMachine& profilingStateMachine) + : CommandHandlerFunctor(packetId, version) + , m_CounterDirectory(counterDirectory) + , m_SendCounterPacket(sendCounterPacket) + , m_StateMachine(profilingStateMachine) {} void operator()(const Packet& packet) override; - private: - ICounterDirectory& m_CounterDirectory; + const ICounterDirectory& m_CounterDirectory; ISendCounterPacket& m_SendCounterPacket; + const ProfilingStateMachine& m_StateMachine; }; } // namespace profiling } // namespace armnn - diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index 80d99dd7ab..57a11d0503 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -2119,75 +2119,97 @@ BOOST_AUTO_TEST_CASE(CheckPeriodicCounterCaptureThread) BOOST_TEST((valueB * numSteps) == readValue); } -BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest0) +BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest1) { using boost::numeric_cast; + const uint32_t packetId = 3; + const uint32_t version = 1; ProfilingStateMachine profilingStateMachine; + CounterDirectory counterDirectory; + MockBufferManager mockBuffer(1024); + SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); + RequestCounterDirectoryCommandHandler commandHandler(packetId, + version, + counterDirectory, + sendCounterPacket, + profilingStateMachine); - const uint32_t packetId = 0x30000; - const uint32_t version = 1; + const uint32_t wrongPacketId = 47; + const uint32_t wrongHeader = (wrongPacketId & 0x000003FF) << 16; - std::unique_ptr<char[]> packetData; + Packet wrongPacket(wrongHeader); - Packet packetA(packetId, 0, packetData); + profilingStateMachine.TransitionToState(ProfilingState::Uninitialised); + BOOST_CHECK_THROW(commandHandler(wrongPacket), armnn::RuntimeException); // Wrong profiling state + profilingStateMachine.TransitionToState(ProfilingState::NotConnected); + BOOST_CHECK_THROW(commandHandler(wrongPacket), armnn::RuntimeException); // Wrong profiling state + profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck); + BOOST_CHECK_THROW(commandHandler(wrongPacket), armnn::RuntimeException); // Wrong profiling state + profilingStateMachine.TransitionToState(ProfilingState::Active); + BOOST_CHECK_THROW(commandHandler(wrongPacket), armnn::InvalidArgumentException); // Wrong packet - MockBufferManager mockBuffer(1024); - SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); + const uint32_t rightHeader = (packetId & 0x000003FF) << 16; - CounterDirectory counterDirectory; + Packet rightPacket(rightHeader); - RequestCounterDirectoryCommandHandler commandHandler(packetId, version, counterDirectory, sendCounterPacket); - commandHandler(packetA); + BOOST_CHECK_NO_THROW(commandHandler(rightPacket)); // Right packet auto readBuffer = mockBuffer.GetReadableBuffer(); uint32_t headerWord0 = ReadUint32(readBuffer, 0); uint32_t headerWord1 = ReadUint32(readBuffer, 4); - BOOST_TEST(((headerWord0 >> 26) & 0x3F) == 0); // packet family - BOOST_TEST(((headerWord0 >> 16) & 0x3FF) == 2); // packet id - BOOST_TEST(headerWord1 == 24); // data length + BOOST_TEST(((headerWord0 >> 26) & 0x0000003F) == 0); // packet family + BOOST_TEST(((headerWord0 >> 16) & 0x000003FF) == 2); // packet id + BOOST_TEST(headerWord1 == 24); // data length uint32_t bodyHeaderWord0 = ReadUint32(readBuffer, 8); uint16_t deviceRecordCount = numeric_cast<uint16_t>(bodyHeaderWord0 >> 16); BOOST_TEST(deviceRecordCount == 0); // device_records_count } -BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest1) +BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest2) { using boost::numeric_cast; - ProfilingStateMachine profilingStateMachine; - - const uint32_t packetId = 0x30000; + const uint32_t packetId = 3; const uint32_t version = 1; - - std::unique_ptr<char[]> packetData; - - Packet packetA(packetId, 0, packetData); - + ProfilingStateMachine profilingStateMachine; + CounterDirectory counterDirectory; MockBufferManager mockBuffer(1024); SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); + RequestCounterDirectoryCommandHandler commandHandler(packetId, + version, + counterDirectory, + sendCounterPacket, + profilingStateMachine); + const uint32_t header = (packetId & 0x000003FF) << 16; + Packet packet(header); - CounterDirectory counterDirectory; const Device* device = counterDirectory.RegisterDevice("deviceA", 1); const CounterSet* counterSet = counterDirectory.RegisterCounterSet("countersetA"); counterDirectory.RegisterCategory("categoryA", device->m_Uid, counterSet->m_Uid); counterDirectory.RegisterCounter("categoryA", 0, 1, 2.0f, "counterA", "descA"); counterDirectory.RegisterCounter("categoryA", 1, 1, 3.0f, "counterB", "descB"); - RequestCounterDirectoryCommandHandler commandHandler(packetId, version, counterDirectory, sendCounterPacket); - commandHandler(packetA); + profilingStateMachine.TransitionToState(ProfilingState::Uninitialised); + BOOST_CHECK_THROW(commandHandler(packet), armnn::RuntimeException); // Wrong profiling state + profilingStateMachine.TransitionToState(ProfilingState::NotConnected); + BOOST_CHECK_THROW(commandHandler(packet), armnn::RuntimeException); // Wrong profiling state + profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck); + BOOST_CHECK_THROW(commandHandler(packet), armnn::RuntimeException); // Wrong profiling state + profilingStateMachine.TransitionToState(ProfilingState::Active); + BOOST_CHECK_NO_THROW(commandHandler(packet)); auto readBuffer = mockBuffer.GetReadableBuffer(); uint32_t headerWord0 = ReadUint32(readBuffer, 0); uint32_t headerWord1 = ReadUint32(readBuffer, 4); - BOOST_TEST(((headerWord0 >> 26) & 0x3F) == 0); // packet family - BOOST_TEST(((headerWord0 >> 16) & 0x3FF) == 2); // packet id - BOOST_TEST(headerWord1 == 240); // data length + BOOST_TEST(((headerWord0 >> 26) & 0x0000003F) == 0); // packet family + BOOST_TEST(((headerWord0 >> 16) & 0x000003FF) == 2); // packet id + BOOST_TEST(headerWord1 == 240); // data length uint32_t bodyHeaderWord0 = ReadUint32(readBuffer, 8); uint32_t bodyHeaderWord1 = ReadUint32(readBuffer, 12); @@ -2357,4 +2379,131 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket) profilingService.ResetExternalProfilingOptions(options, true); } +BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadRequestCounterDirectoryPacket) +{ + // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output + LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning); + + // Swap the profiling connection factory in the profiling service instance with our mock one + SwapProfilingConnectionFactoryHelper helper; + + // Redirect the standard output to a local stream so that we can parse the warning message + std::stringstream ss; + StreamRedirector streamRedirector(std::cout, ss.rdbuf()); + + // Reset the profiling service to the uninitialized state + armnn::Runtime::CreationOptions::ExternalProfilingOptions options; + options.m_EnableProfiling = true; + ProfilingService& profilingService = ProfilingService::Instance(); + profilingService.ResetExternalProfilingOptions(options, true); + + // Bring the profiling service to the "Active" state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); + helper.ForceTransitionToState(ProfilingState::NotConnected); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); + profilingService.Update(); // Create the profiling connection + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); + profilingService.Update(); // Start the threads + helper.ForceTransitionToState(ProfilingState::Active); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); + + // Get the mock profiling connection + MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection(); + BOOST_CHECK(mockProfilingConnection); + + // Write a valid "Request Counter Directory" packet into the mock profiling connection, to simulate a valid + // reply from an external profiling service + + // Request Counter Directory packet header (word 0, word 1 is always zero): + // 26:31 [6] packet_family: Control Packet Family, value 0b000000 + // 16:25 [10] packet_id: Packet identifier, value 0b0000000011 + // 8:15 [8] reserved: Reserved, value 0b00000000 + // 0:7 [8] reserved: Reserved, value 0b00000000 + uint32_t packetFamily = 0; + uint32_t packetId = 123; // Wrong packet id!!! + uint32_t header = ((packetFamily & 0x0000003F) << 26) | + ((packetId & 0x000003FF) << 16); + + // Create the Request Counter Directory packet + Packet requestCounterDirectoryPacket(header); + + // Write the packet to the mock profiling connection + mockProfilingConnection->WritePacket(std::move(requestCounterDirectoryPacket)); + + // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that + // the Create the Request Counter packet gets processed by the profiling service + std::this_thread::sleep_for(std::chrono::seconds(2)); + + // Check that the expected error has occurred and logged to the standard output + BOOST_CHECK(boost::contains(ss.str(), "Functor with requested PacketId=123 and Version=4194304 does not exist")); + + // The Connection Acknowledged Command Handler should not have updated the profiling state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); + + // Reset the profiling service to stop any running thread + options.m_EnableProfiling = false; + profilingService.ResetExternalProfilingOptions(options, true); +} + +BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodRequestCounterDirectoryPacket) +{ + // Swap the profiling connection factory in the profiling service instance with our mock one + SwapProfilingConnectionFactoryHelper helper; + + // Reset the profiling service to the uninitialized state + armnn::Runtime::CreationOptions::ExternalProfilingOptions options; + options.m_EnableProfiling = true; + ProfilingService& profilingService = ProfilingService::Instance(); + profilingService.ResetExternalProfilingOptions(options, true); + + // Bring the profiling service to the "Active" state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); + profilingService.Update(); // Initialize the counter directory + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); + profilingService.Update(); // Create the profiling connection + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); + profilingService.Update(); // Start the threads + helper.ForceTransitionToState(ProfilingState::Active); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); + + // Get the mock profiling connection + MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection(); + BOOST_CHECK(mockProfilingConnection); + + // Write a valid "Request Counter Directory" packet into the mock profiling connection, to simulate a valid + // reply from an external profiling service + + // Request Counter Directory packet header (word 0, word 1 is always zero): + // 26:31 [6] packet_family: Control Packet Family, value 0b000000 + // 16:25 [10] packet_id: Packet identifier, value 0b0000000011 + // 8:15 [8] reserved: Reserved, value 0b00000000 + // 0:7 [8] reserved: Reserved, value 0b00000000 + uint32_t packetFamily = 0; + uint32_t packetId = 3; + uint32_t header = ((packetFamily & 0x0000003F) << 26) | + ((packetId & 0x000003FF) << 16); + + // Create the Request Counter Directory packet + Packet requestCounterDirectoryPacket(header); + + // Write the packet to the mock profiling connection + mockProfilingConnection->WritePacket(std::move(requestCounterDirectoryPacket)); + + // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that + // the Create the Request Counter packet gets processed by the profiling service + std::this_thread::sleep_for(std::chrono::seconds(2)); + + // The Connection Acknowledged Command Handler should not have updated the profiling state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); + + // Check that the mock profiling connection contains one Counter Directory packet + const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData(); + BOOST_TEST(writtenData.size() == 1); + BOOST_TEST(writtenData[0] == 416); // The size of a valid Counter Directory packet + + // Reset the profiling service to stop any running thread + options.m_EnableProfiling = false; + profilingService.ResetExternalProfilingOptions(options, true); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/src/profiling/test/ProfilingTests.hpp b/src/profiling/test/ProfilingTests.hpp index 3e6cf63efe..e1686162db 100644 --- a/src/profiling/test/ProfilingTests.hpp +++ b/src/profiling/test/ProfilingTests.hpp @@ -40,17 +40,6 @@ public: } }; -struct CoutRedirect -{ -public: - CoutRedirect(std::streambuf* newStreamBuffer) - : m_Old(std::cout.rdbuf(newStreamBuffer)) {} - ~CoutRedirect() { std::cout.rdbuf(m_Old); } - -private: - std::streambuf* m_Old; -}; - struct StreamRedirector { public: @@ -190,6 +179,11 @@ public: return boost::polymorphic_downcast<MockProfilingConnection*>(profilingConnection); } + void ForceTransitionToState(ProfilingState newState) + { + TransitionToState(ProfilingService::Instance(), newState); + } + private: MockProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory; IProfilingConnectionFactory* m_BackupProfilingConnectionFactory; |