diff options
Diffstat (limited to 'src/profiling/test/ProfilingTests.cpp')
-rw-r--r-- | src/profiling/test/ProfilingTests.cpp | 257 |
1 files changed, 236 insertions, 21 deletions
diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index 24ab779412..de92fb9eb0 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -27,6 +27,9 @@ #include <armnn/Conversion.hpp> +#include <Logging.hpp> +#include <armnn/Utils.hpp> + #include <boost/algorithm/string.hpp> #include <boost/numeric/conversion/cast.hpp> #include <boost/test/unit_test.hpp> @@ -97,18 +100,19 @@ public: TestProfilingConnectionBase() = default; ~TestProfilingConnectionBase() = default; - bool IsOpen() { return true; } + bool IsOpen() const override { return true; } - void Close() {} + void Close() override {} - bool WritePacket(const unsigned char* buffer, uint32_t length) { return false; } + bool WritePacket(const unsigned char* buffer, uint32_t length) override { return false; } - Packet ReadPacket(uint32_t timeout) + Packet ReadPacket(uint32_t timeout) override { std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); std::unique_ptr<char[]> packetData; - //Return connection acknowledged packet - return {65536 ,0 , packetData}; + + // Return connection acknowledged packet + return { 65536, 0, packetData }; } }; @@ -119,12 +123,13 @@ public: if (readRequests < 3) { readRequests++; - throw armnn::TimeoutException(": Simulate a timeout"); + throw armnn::TimeoutException("Simulate a timeout"); } std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); std::unique_ptr<char[]> packetData; - //Return connection acknowledged packet after three timeouts - return {65536 ,0 , packetData}; + + // Return connection acknowledged packet after three timeouts + return { 65536, 0, packetData }; } private: @@ -655,15 +660,31 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceDisabled) ProfilingService& profilingService = ProfilingService::Instance(); profilingService.ResetExternalProfilingOptions(options, true); BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); - profilingService.Run(); + profilingService.Update(); BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); } -struct cerr_redirect +struct LogLevelSwapper { - cerr_redirect(std::streambuf* new_buffer) - : old(std::cerr.rdbuf(new_buffer)) {} - ~cerr_redirect() { std::cerr.rdbuf(old); } +public: + LogLevelSwapper(armnn::LogSeverity severity) + { + // Set the new log level + armnn::ConfigureLogging(true, true, severity); + } + ~LogLevelSwapper() + { + // The default log level for unit tests is "Fatal" + armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal); + } +}; + +struct CoutRedirect +{ +public: + CoutRedirect(std::streambuf* newStreamBuffer) + : old(std::cout.rdbuf(newStreamBuffer)) {} + ~CoutRedirect() { std::cout.rdbuf(old); } private: std::streambuf* old; @@ -671,35 +692,45 @@ private: BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabled) { + // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output + LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning); + armnn::Runtime::CreationOptions::ExternalProfilingOptions options; options.m_EnableProfiling = true; ProfilingService& profilingService = ProfilingService::Instance(); profilingService.ResetExternalProfilingOptions(options, true); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); + profilingService.Update(); BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); - // As there is no daemon running a connection cannot be made so expect a std::cerr to console + // Redirect the output to a local stream so that we can parse the warning message std::stringstream ss; - cerr_redirect guard(ss.rdbuf()); - profilingService.Run(); + CoutRedirect coutRedirect(ss.rdbuf()); + profilingService.Update(); BOOST_CHECK(boost::contains(ss.str(), "Cannot connect to stream socket: Connection refused")); } BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabledRuntime) { + // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output + LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning); + armnn::Runtime::CreationOptions::ExternalProfilingOptions options; ProfilingService& profilingService = ProfilingService::Instance(); profilingService.ResetExternalProfilingOptions(options, true); BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); - profilingService.Run(); + profilingService.Update(); BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); options.m_EnableProfiling = true; profilingService.ResetExternalProfilingOptions(options); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); + profilingService.Update(); BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); - // As there is no daemon running a connection cannot be made so expect a std::cerr to console + // Redirect the output to a local stream so that we can parse the warning message std::stringstream ss; - cerr_redirect guard(ss.rdbuf()); - profilingService.Run(); + CoutRedirect coutRedirect(ss.rdbuf()); + profilingService.Update(); BOOST_CHECK(boost::contains(ss.str(), "Cannot connect to stream socket: Connection refused")); } @@ -711,11 +742,15 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceCounterDirectory) const ICounterDirectory& counterDirectory0 = profilingService.GetCounterDirectory(); BOOST_CHECK(counterDirectory0.GetCounterCount() == 0); + profilingService.Update(); + BOOST_CHECK(counterDirectory0.GetCounterCount() == 0); options.m_EnableProfiling = true; profilingService.ResetExternalProfilingOptions(options); const ICounterDirectory& counterDirectory1 = profilingService.GetCounterDirectory(); + BOOST_CHECK(counterDirectory1.GetCounterCount() == 0); + profilingService.Update(); BOOST_CHECK(counterDirectory1.GetCounterCount() != 0); } @@ -726,6 +761,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceCounterValues) ProfilingService& profilingService = ProfilingService::Instance(); profilingService.ResetExternalProfilingOptions(options, true); + profilingService.Update(); const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory(); const Counters& counters = counterDirectory.GetCounters(); BOOST_CHECK(!counters.empty()); @@ -2297,4 +2333,183 @@ BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest1) BOOST_TEST(categoryRecordOffset == 44); } +class MockProfilingConnectionFactory : public IProfilingConnectionFactory +{ +public: + MockProfilingConnectionFactory() + : m_MockProfilingConnection(new MockProfilingConnection()) + {} + + IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override + { + return std::unique_ptr<MockProfilingConnection>(m_MockProfilingConnection); + } + + MockProfilingConnection* GetMockProfilingConnection() { return m_MockProfilingConnection; } + +private: + MockProfilingConnection* m_MockProfilingConnection; +}; + +class SwapProfilingConnectionFactoryHelper : public ProfilingService +{ +public: + SwapProfilingConnectionFactoryHelper() + : ProfilingService() + , m_MockProfilingConnectionFactory(new MockProfilingConnectionFactory()) + , m_BackupProfilingConnectionFactory(nullptr) + { + SwapProfilingConnectionFactory(ProfilingService::Instance(), + m_MockProfilingConnectionFactory.get(), + m_BackupProfilingConnectionFactory); + } + ~SwapProfilingConnectionFactoryHelper() + { + IProfilingConnectionFactory* temp = nullptr; + SwapProfilingConnectionFactory(ProfilingService::Instance(), + m_BackupProfilingConnectionFactory, + temp); + } + + IProfilingConnectionFactory* GetMockProfilingConnectionFactory() { return m_MockProfilingConnectionFactory.get(); } + +private: + IProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory; + IProfilingConnectionFactory* m_BackupProfilingConnectionFactory; +}; + +BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadConnectionAcknowledgedPacket) +{ + // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output + LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning); + + SwapProfilingConnectionFactoryHelper helper; + MockProfilingConnectionFactory* mockProfilingConnectionFactory = + boost::polymorphic_downcast<MockProfilingConnectionFactory*>(helper.GetMockProfilingConnectionFactory()); + BOOST_CHECK(mockProfilingConnectionFactory); + MockProfilingConnection* mockProfilingConnection = mockProfilingConnectionFactory->GetMockProfilingConnection(); + BOOST_CHECK(mockProfilingConnection); + + // Calculate the size of a Stream Metadata packet + std::string processName = GetProcessName().substr(0, 60); + unsigned int processNameSize = processName.empty() ? 0 : boost::numeric_cast<unsigned int>(processName.size()) + 1; + unsigned int streamMetadataPacketsize = 118 + processNameSize; + + armnn::Runtime::CreationOptions::ExternalProfilingOptions options; + options.m_EnableProfiling = true; + ProfilingService& profilingService = ProfilingService::Instance(); + profilingService.ResetExternalProfilingOptions(options, true); + + // Bring the profiling service to the "WaitingForAck" state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); + profilingService.Update(); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); + profilingService.Update(); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); + profilingService.Update(); + + // Redirect the output to a local stream so that we can parse the warning message + std::stringstream ss; + CoutRedirect coutRedirect(ss.rdbuf()); + + // Wait for a bit to make sure that we get the packet + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // Check that the mock profiling connection contains one Stream Metadata packet + const std::vector<uint32_t>& writtenData = mockProfilingConnection->GetWrittenData(); + BOOST_TEST(writtenData.size() == 1); + BOOST_TEST(writtenData[0] == streamMetadataPacketsize); + + // Write a valid "Connection Acknowledged" packet into the mock profiling connection, to simulate a valid + // reply from an external profiling service + + // Connection Acknowledged Packet header (word 0, word 1 is always zero): + // 26:31 [6] packet_family: Control Packet Family, value 0b000000 + // 16:25 [10] packet_id: Packet identifier, value 0b0000000001 + // 8:15 [8] reserved: Reserved, value 0b00000000 + // 0:7 [8] reserved: Reserved, value 0b00000000 + uint32_t packetFamily = 0; + uint32_t packetId = 37; // Wrong packet id!!! + uint32_t header = ((packetFamily & 0x0000003F) << 26) | + ((packetId & 0x000003FF) << 16); + + // Connection Acknowledged Packet + Packet connectionAcknowledgedPacket(header); + + // Write the packet to the mock profiling connection + mockProfilingConnection->WritePacket(std::move(connectionAcknowledgedPacket)); + + // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that + // the Connection Acknowledged packet gets processed by the profiling service + std::this_thread::sleep_for(std::chrono::seconds(1)); + + // Check that the expected error has occurred and logged to the standard output + BOOST_CHECK(boost::contains(ss.str(), "Functor with requested PacketId=37 and Version=4194304 does not exist")); + + // The Connection Acknowledged Command Handler should not have updated the profiling state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); +} + +BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket) +{ + SwapProfilingConnectionFactoryHelper helper; + MockProfilingConnectionFactory* mockProfilingConnectionFactory = + boost::polymorphic_downcast<MockProfilingConnectionFactory*>(helper.GetMockProfilingConnectionFactory()); + BOOST_CHECK(mockProfilingConnectionFactory); + MockProfilingConnection* mockProfilingConnection = mockProfilingConnectionFactory->GetMockProfilingConnection(); + BOOST_CHECK(mockProfilingConnection); + + // Calculate the size of a Stream Metadata packet + std::string processName = GetProcessName().substr(0, 60); + unsigned int processNameSize = processName.empty() ? 0 : boost::numeric_cast<unsigned int>(processName.size()) + 1; + unsigned int streamMetadataPacketsize = 118 + processNameSize; + + armnn::Runtime::CreationOptions::ExternalProfilingOptions options; + options.m_EnableProfiling = true; + ProfilingService& profilingService = ProfilingService::Instance(); + profilingService.ResetExternalProfilingOptions(options, true); + + // Bring the profiling service to the "WaitingForAck" state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); + profilingService.Update(); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); + profilingService.Update(); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); + profilingService.Update(); + + // Wait for a bit to make sure that we get the packet + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // Check that the mock profiling connection contains one Stream Metadata packet + const std::vector<uint32_t>& writtenData = mockProfilingConnection->GetWrittenData(); + BOOST_TEST(writtenData.size() == 1); + BOOST_TEST(writtenData[0] == streamMetadataPacketsize); + + // Write a valid "Connection Acknowledged" packet into the mock profiling connection, to simulate a valid + // reply from an external profiling service + + // Connection Acknowledged Packet header (word 0, word 1 is always zero): + // 26:31 [6] packet_family: Control Packet Family, value 0b000000 + // 16:25 [10] packet_id: Packet identifier, value 0b0000000001 + // 8:15 [8] reserved: Reserved, value 0b00000000 + // 0:7 [8] reserved: Reserved, value 0b00000000 + uint32_t packetFamily = 0; + uint32_t packetId = 1; + uint32_t header = ((packetFamily & 0x0000003F) << 26) | + ((packetId & 0x000003FF) << 16); + + // Connection Acknowledged Packet + Packet connectionAcknowledgedPacket(header); + + // Write the packet to the mock profiling connection + mockProfilingConnection->WritePacket(std::move(connectionAcknowledgedPacket)); + + // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that + // the Connection Acknowledged packet gets processed by the profiling service + std::this_thread::sleep_for(std::chrono::seconds(1)); + + // The Connection Acknowledged Command Handler should have updated the profiling state accordingly + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); +} + BOOST_AUTO_TEST_SUITE_END() |