From d0613b56cea7eba0604e0548bddffd773a4eb554 Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Wed, 9 Oct 2019 16:47:04 +0100 Subject: IVGCVSW-3937 Improve the Connection Acknowledged Handler * The Connection Acknowledged Handler should report an error is it's called while in a wrong state * Stopping the threads in the ProfilingService before having to start them again * Updated the unit tests to check the changes * Removed unnecessary Packet.cpp file * Fixed memory leak Signed-off-by: Matteo Martincigh Change-Id: I8c4d33b4d97994df86fe6c9f8c659f880ec64c16 --- src/profiling/CommandHandler.cpp | 23 +- .../ConnectionAcknowledgedCommandHandler.cpp | 35 ++- src/profiling/Packet.cpp | 51 ----- src/profiling/Packet.hpp | 39 ++-- src/profiling/ProfilingService.cpp | 28 ++- src/profiling/ProfilingService.hpp | 6 +- .../RequestCounterDirectoryCommandHandler.cpp | 4 +- src/profiling/SendCounterPacket.cpp | 2 +- src/profiling/test/ProfilingTests.cpp | 239 ++++----------------- src/profiling/test/ProfilingTests.hpp | 200 +++++++++++++++++ src/profiling/test/SendCounterPacketTests.cpp | 6 +- src/profiling/test/SendCounterPacketTests.hpp | 49 ++++- 12 files changed, 379 insertions(+), 303 deletions(-) delete mode 100644 src/profiling/Packet.cpp create mode 100644 src/profiling/test/ProfilingTests.hpp (limited to 'src/profiling') diff --git a/src/profiling/CommandHandler.cpp b/src/profiling/CommandHandler.cpp index 86fa2571df..cc68dcf74d 100644 --- a/src/profiling/CommandHandler.cpp +++ b/src/profiling/CommandHandler.cpp @@ -5,6 +5,8 @@ #include "CommandHandler.hpp" +#include + namespace armnn { @@ -39,7 +41,14 @@ void CommandHandler::HandleCommands(IProfilingConnection& profilingConnection) { try { - Packet packet = profilingConnection.ReadPacket(m_Timeout); + Packet packet = profilingConnection.ReadPacket(m_Timeout.load()); + + if (packet.IsEmpty()) + { + // Nothing to do, continue + continue; + } + Version version = m_PacketVersionResolver.ResolvePacketVersion(packet.GetPacketId()); CommandHandlerFunctor* commandHandlerFunctor = @@ -49,19 +58,15 @@ void CommandHandler::HandleCommands(IProfilingConnection& profilingConnection) } catch (const armnn::TimeoutException&) { - if (m_StopAfterTimeout) + if (m_StopAfterTimeout.load()) { - m_KeepRunning.store(false, std::memory_order_relaxed); + m_KeepRunning.store(false); } } catch (const Exception& e) { - // Log the error - BOOST_LOG_TRIVIAL(warning) << "An error has occurred when handling a command: " - << e.what(); - - // Might want to differentiate the errors more - m_KeepRunning.store(false); + // Log the error and continue + BOOST_LOG_TRIVIAL(warning) << "An error has occurred when handling a command: " << e.what() << std::endl; } } while (m_KeepRunning.load()); diff --git a/src/profiling/ConnectionAcknowledgedCommandHandler.cpp b/src/profiling/ConnectionAcknowledgedCommandHandler.cpp index f90b601b7e..9d2d1a2bd2 100644 --- a/src/profiling/ConnectionAcknowledgedCommandHandler.cpp +++ b/src/profiling/ConnectionAcknowledgedCommandHandler.cpp @@ -7,6 +7,8 @@ #include +#include + namespace armnn { @@ -15,15 +17,34 @@ namespace profiling void ConnectionAcknowledgedCommandHandler::operator()(const Packet& packet) { - if (!(packet.GetPacketFamily() == 0u && packet.GetPacketId() == 1u)) + ProfilingState currentState = m_StateMachine.GetCurrentState(); + switch (currentState) { - throw armnn::InvalidArgumentException(std::string("Expected Packet family = 0, id = 1 but received family = ") - + std::to_string(packet.GetPacketFamily()) - + " id = " + std::to_string(packet.GetPacketId())); + case ProfilingState::Uninitialised: + case ProfilingState::NotConnected: + throw RuntimeException(boost::str(boost::format("Connection Acknowledged Handler invoked while in an " + "wrong state: %1%") + % GetProfilingStateName(currentState))); + case ProfilingState::WaitingForAck: + // Process the packet + if (!(packet.GetPacketFamily() == 0u && packet.GetPacketId() == 1u)) + { + throw armnn::InvalidArgumentException(boost::str(boost::format("Expected Packet family = 0, id = 1 but " + "received family = %1%, id = %2%") + % packet.GetPacketFamily() + % packet.GetPacketId())); + } + + // Once a Connection Acknowledged packet has been received, move to the Active state immediately + m_StateMachine.TransitionToState(ProfilingState::Active); + + break; + case ProfilingState::Active: + return; // NOP + default: + throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1%") + % static_cast(currentState))); } - - // Once a Connection Acknowledged packet has been received, move to the Active state immediately - m_StateMachine.TransitionToState(ProfilingState::Active); } } // namespace profiling diff --git a/src/profiling/Packet.cpp b/src/profiling/Packet.cpp deleted file mode 100644 index 4cfa42bbc9..0000000000 --- a/src/profiling/Packet.cpp +++ /dev/null @@ -1,51 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "Packet.hpp" - -namespace armnn -{ - -namespace profiling -{ - -std::uint32_t Packet::GetHeader() const -{ - return m_Header; -} - -std::uint32_t Packet::GetPacketFamily() const -{ - return m_PacketFamily; -} - -std::uint32_t Packet::GetPacketId() const -{ - return m_PacketId; -} - -std::uint32_t Packet::GetLength() const -{ - return m_Length; -} - -const char* const Packet::GetData() const -{ - return m_Data.get(); -} - -std::uint32_t Packet::GetPacketClass() const -{ - return (m_PacketId >> 3); -} - -std::uint32_t Packet::GetPacketType() const -{ - return (m_PacketId & 7); -} - -} // namespace profiling - -} // namespace armnn diff --git a/src/profiling/Packet.hpp b/src/profiling/Packet.hpp index 2aae14b741..fae368b64e 100644 --- a/src/profiling/Packet.hpp +++ b/src/profiling/Packet.hpp @@ -2,11 +2,12 @@ // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // + #pragma once #include -#include +#include namespace armnn { @@ -46,26 +47,32 @@ public: } } - Packet(Packet&& other) : - m_Header(other.m_Header), - m_PacketFamily(other.m_PacketFamily), - m_PacketId(other.m_PacketId), - m_Length(other.m_Length), - m_Data(std::move(other.m_Data)) - {} + Packet(Packet&& other) + : m_Header(other.m_Header) + , m_PacketFamily(other.m_PacketFamily) + , m_PacketId(other.m_PacketId) + , m_Length(other.m_Length) + , m_Data(std::move(other.m_Data)) + { + other.m_Header = 0; + other.m_PacketFamily = 0; + other.m_PacketId = 0; + other.m_Length = 0; + } + + ~Packet() = default; Packet(const Packet& other) = delete; Packet& operator=(const Packet&) = delete; Packet& operator=(Packet&&) = default; - uint32_t GetHeader() const; - uint32_t GetPacketFamily() const; - uint32_t GetPacketId() const; - uint32_t GetLength() const; - const char* const GetData() const; - - uint32_t GetPacketClass() const; - uint32_t GetPacketType() const; + uint32_t GetHeader() const { return m_Header; } + uint32_t GetPacketFamily() const { return m_PacketFamily; } + uint32_t GetPacketId() const { return m_PacketId; } + uint32_t GetPacketClass() const { return m_PacketId >> 3; } + uint32_t GetPacketType() const { return m_PacketId & 7; } + uint32_t GetLength() const { return m_Length; } + const char* const GetData() const { return m_Data.get(); } bool IsEmpty() { return m_Header == 0 && m_Length == 0; } diff --git a/src/profiling/ProfilingService.cpp b/src/profiling/ProfilingService.cpp index 19cf9cb58e..693f8337db 100644 --- a/src/profiling/ProfilingService.cpp +++ b/src/profiling/ProfilingService.cpp @@ -47,7 +47,11 @@ void ProfilingService::Update() m_StateMachine.TransitionToState(ProfilingState::NotConnected); break; case ProfilingState::NotConnected: - BOOST_ASSERT(m_ProfilingConnectionFactory); + // Stop the command thread (if running) + m_CommandHandler.Stop(); + + // Stop the send thread (if running) + m_SendCounterPacket.Stop(false); // Reset any existing profiling connection m_ProfilingConnection.reset(); @@ -55,13 +59,13 @@ void ProfilingService::Update() try { // Setup the profiling connection - //m_ProfilingConnection = m_ProfilingConnectionFactory.GetProfilingConnection(m_Options); + BOOST_ASSERT(m_ProfilingConnectionFactory); m_ProfilingConnection = m_ProfilingConnectionFactory->GetProfilingConnection(m_Options); } catch (const Exception& e) { BOOST_LOG_TRIVIAL(warning) << "An error has occurred when creating the profiling connection: " - << e.what(); + << e.what() << std::endl; } // Move to the next state @@ -229,13 +233,23 @@ void ProfilingService::InitializeCounterValue(uint16_t counterUid) void ProfilingService::Reset() { // Reset the profiling service - m_CounterDirectory.Clear(); + + // The order in which we reset/stop the components is not trivial! + + // First stop the threads (Command Handler first)... + m_CommandHandler.Stop(); + m_SendCounterPacket.Stop(false); + + // ...then destroy the profiling connection... m_ProfilingConnection.reset(); - m_StateMachine.Reset(); + + // ...then delete all the counter data and configuration... m_CounterIndex.clear(); m_CounterValues.clear(); - m_CommandHandler.Stop(); - m_SendCounterPacket.Stop(false); + m_CounterDirectory.Clear(); + + // ...finally reset the profiling state machine + m_StateMachine.Reset(); } } // namespace profiling diff --git a/src/profiling/ProfilingService.hpp b/src/profiling/ProfilingService.hpp index 50a938e33d..edeb6bde90 100644 --- a/src/profiling/ProfilingService.hpp +++ b/src/profiling/ProfilingService.hpp @@ -109,7 +109,7 @@ protected: } ~ProfilingService() = default; - // Protected method for testing + // Protected methods for testing void SwapProfilingConnectionFactory(ProfilingService& instance, IProfilingConnectionFactory* other, IProfilingConnectionFactory*& backup) @@ -120,6 +120,10 @@ protected: backup = instance.m_ProfilingConnectionFactory.release(); instance.m_ProfilingConnectionFactory.reset(other); } + IProfilingConnection* GetProfilingConnection(ProfilingService& instance) + { + return instance.m_ProfilingConnection.get(); + } }; } // namespace profiling diff --git a/src/profiling/RequestCounterDirectoryCommandHandler.cpp b/src/profiling/RequestCounterDirectoryCommandHandler.cpp index f186add357..0fdcf10de4 100644 --- a/src/profiling/RequestCounterDirectoryCommandHandler.cpp +++ b/src/profiling/RequestCounterDirectoryCommandHandler.cpp @@ -5,6 +5,8 @@ #include "RequestCounterDirectoryCommandHandler.hpp" +#include + namespace armnn { @@ -21,4 +23,4 @@ void RequestCounterDirectoryCommandHandler::operator()(const Packet& packet) } // namespace profiling -} // namespace armnn \ No newline at end of file +} // namespace armnn diff --git a/src/profiling/SendCounterPacket.cpp b/src/profiling/SendCounterPacket.cpp index b9f2b187b7..e48da3ed7c 100644 --- a/src/profiling/SendCounterPacket.cpp +++ b/src/profiling/SendCounterPacket.cpp @@ -945,7 +945,7 @@ void SendCounterPacket::Stop(bool rethrowSendThreadExceptions) // Exception handling lock scope - Begin { // Lock the mutex to handle any exception coming from the send thread - std::unique_lock lock(m_WaitMutex); + std::lock_guard lock(m_WaitMutex); // Check if there's an exception to rethrow if (m_SendThreadException) diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index de92fb9eb0..80d99dd7ab 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -3,11 +3,10 @@ // SPDX-License-Identifier: MIT // -#include "SendCounterPacketTests.hpp" +#include "ProfilingTests.hpp" #include #include -#include #include #include #include @@ -19,7 +18,6 @@ #include #include #include -#include #include #include #include @@ -27,21 +25,16 @@ #include -#include #include #include #include -#include #include #include -#include #include #include #include -#include -#include using namespace armnn::profiling; @@ -94,59 +87,6 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandlerKeyComparisons) BOOST_CHECK(vect == expectedVect); } -class TestProfilingConnectionBase :public IProfilingConnection -{ -public: - TestProfilingConnectionBase() = default; - ~TestProfilingConnectionBase() = default; - - bool IsOpen() const override { return true; } - - void Close() override {} - - bool WritePacket(const unsigned char* buffer, uint32_t length) override { return false; } - - Packet ReadPacket(uint32_t timeout) override - { - std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); - std::unique_ptr packetData; - - // Return connection acknowledged packet - return { 65536, 0, packetData }; - } -}; - -class TestProfilingConnectionTimeoutError : public TestProfilingConnectionBase -{ -public: - Packet ReadPacket(uint32_t timeout) { - if (readRequests < 3) - { - readRequests++; - throw armnn::TimeoutException("Simulate a timeout"); - } - std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); - std::unique_ptr packetData; - - // Return connection acknowledged packet after three timeouts - return { 65536, 0, packetData }; - } - -private: - int readRequests = 0; -}; - -class TestProfilingConnectionArmnnError :public TestProfilingConnectionBase -{ -public: - - Packet ReadPacket(uint32_t timeout) - { - std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); - throw armnn::Exception(": Simulate a non timeout error"); - } -}; - BOOST_AUTO_TEST_CASE(CheckCommandHandler) { PacketVersionResolver packetVersionResolver; @@ -180,7 +120,7 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandler) profilingStateMachine.TransitionToState(ProfilingState::NotConnected); profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck); // commandHandler1 should give up after one timeout - CommandHandler commandHandler1(1, + CommandHandler commandHandler1(10, true, commandHandlerRegistry, packetVersionResolver); @@ -204,32 +144,24 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandler) break; } - std::this_thread::sleep_for(std::chrono::milliseconds(5)); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); } commandHandler1.Stop(); BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active); - CommandHandler commandHandler2(1, + CommandHandler commandHandler2(100, false, commandHandlerRegistry, packetVersionResolver); commandHandler2.Start(testProfilingConnectionArmnnError); - for (int i = 0; i < 100; i++) - { - if (!commandHandler2.IsRunning()) - { - // commandHandler2 should stop once it encounters a non timing error - return; - } - - std::this_thread::sleep_for(std::chrono::milliseconds(5)); - } + // commandHandler2 should not stop once it encounters a non timing error + std::this_thread::sleep_for(std::chrono::milliseconds(500)); - BOOST_ERROR("commandHandler2 has failed to stop"); + BOOST_CHECK(commandHandler2.IsRunning()); commandHandler2.Stop(); } @@ -300,33 +232,6 @@ BOOST_AUTO_TEST_CASE(CheckPacketClass) BOOST_CHECK(packetTest4.GetPacketClass() == 5); } -// Create Derived Classes -class TestFunctorA : public CommandHandlerFunctor -{ -public: - using CommandHandlerFunctor::CommandHandlerFunctor; - - int GetCount() { return m_Count; } - - void operator()(const Packet& packet) override - { - m_Count++; - } - -private: - int m_Count = 0; -}; - -class TestFunctorB : public TestFunctorA -{ - using TestFunctorA::TestFunctorA; -}; - -class TestFunctorC : public TestFunctorA -{ - using TestFunctorA::TestFunctorA; -}; - BOOST_AUTO_TEST_CASE(CheckCommandHandlerFunctor) { // Hard code the version as it will be the same during a single profiling session @@ -455,6 +360,7 @@ BOOST_AUTO_TEST_CASE(CheckPacketVersionResolver) BOOST_TEST(resolvedVersion == expectedVersion); } } + void ProfilingCurrentStateThreadImpl(ProfilingStateMachine& states) { ProfilingState newState = ProfilingState::NotConnected; @@ -664,32 +570,6 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceDisabled) BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); } -struct LogLevelSwapper -{ -public: - LogLevelSwapper(armnn::LogSeverity severity) - { - // Set the new log level - armnn::ConfigureLogging(true, true, severity); - } - ~LogLevelSwapper() - { - // The default log level for unit tests is "Fatal" - armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal); - } -}; - -struct CoutRedirect -{ -public: - CoutRedirect(std::streambuf* newStreamBuffer) - : old(std::cout.rdbuf(newStreamBuffer)) {} - ~CoutRedirect() { std::cout.rdbuf(old); } - -private: - std::streambuf* old; -}; - BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabled) { // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output @@ -705,7 +585,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabled) // Redirect the output to a local stream so that we can parse the warning message std::stringstream ss; - CoutRedirect coutRedirect(ss.rdbuf()); + StreamRedirector streamRedirector(std::cout, ss.rdbuf()); profilingService.Update(); BOOST_CHECK(boost::contains(ss.str(), "Cannot connect to stream socket: Connection refused")); } @@ -729,7 +609,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabledRuntime) // Redirect the output to a local stream so that we can parse the warning message std::stringstream ss; - CoutRedirect coutRedirect(ss.rdbuf()); + StreamRedirector streamRedirector(std::cout, ss.rdbuf()); profilingService.Update(); BOOST_CHECK(boost::contains(ss.str(), "Cannot connect to stream socket: Connection refused")); } @@ -1949,16 +1829,18 @@ BOOST_AUTO_TEST_CASE(CheckConnectionAcknowledged) profilingState.TransitionToState(ProfilingState::WaitingForAck); BOOST_CHECK(profilingState.GetCurrentState() == ProfilingState::WaitingForAck); // command handler received packet on ProfilingState::WaitingForAck - commandHandler(packetA); + BOOST_CHECK_NO_THROW(commandHandler(packetA)); BOOST_CHECK(profilingState.GetCurrentState() == ProfilingState::Active); // command handler received packet on ProfilingState::Active - commandHandler(packetA); + BOOST_CHECK_NO_THROW(commandHandler(packetA)); BOOST_CHECK(profilingState.GetCurrentState() == ProfilingState::Active); // command handler received different packet const uint32_t differentPacketId = 0x40000; Packet packetB(differentPacketId, dataLength1, uniqueData1); + profilingState.TransitionToState(ProfilingState::NotConnected); + profilingState.TransitionToState(ProfilingState::WaitingForAck); ConnectionAcknowledgedCommandHandler differentCommandHandler(differentPacketId, version, profilingState); BOOST_CHECK_THROW(differentCommandHandler(packetB), armnn::Exception); } @@ -2333,62 +2215,17 @@ BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest1) BOOST_TEST(categoryRecordOffset == 44); } -class MockProfilingConnectionFactory : public IProfilingConnectionFactory -{ -public: - MockProfilingConnectionFactory() - : m_MockProfilingConnection(new MockProfilingConnection()) - {} - - IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override - { - return std::unique_ptr(m_MockProfilingConnection); - } - - MockProfilingConnection* GetMockProfilingConnection() { return m_MockProfilingConnection; } - -private: - MockProfilingConnection* m_MockProfilingConnection; -}; - -class SwapProfilingConnectionFactoryHelper : public ProfilingService -{ -public: - SwapProfilingConnectionFactoryHelper() - : ProfilingService() - , m_MockProfilingConnectionFactory(new MockProfilingConnectionFactory()) - , m_BackupProfilingConnectionFactory(nullptr) - { - SwapProfilingConnectionFactory(ProfilingService::Instance(), - m_MockProfilingConnectionFactory.get(), - m_BackupProfilingConnectionFactory); - } - ~SwapProfilingConnectionFactoryHelper() - { - IProfilingConnectionFactory* temp = nullptr; - SwapProfilingConnectionFactory(ProfilingService::Instance(), - m_BackupProfilingConnectionFactory, - temp); - } - - IProfilingConnectionFactory* GetMockProfilingConnectionFactory() { return m_MockProfilingConnectionFactory.get(); } - -private: - IProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory; - IProfilingConnectionFactory* m_BackupProfilingConnectionFactory; -}; - BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadConnectionAcknowledgedPacket) { // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning); + // Swap the profiling connection factory in the profiling service instance with our mock one SwapProfilingConnectionFactoryHelper helper; - MockProfilingConnectionFactory* mockProfilingConnectionFactory = - boost::polymorphic_downcast(helper.GetMockProfilingConnectionFactory()); - BOOST_CHECK(mockProfilingConnectionFactory); - MockProfilingConnection* mockProfilingConnection = mockProfilingConnectionFactory->GetMockProfilingConnection(); - BOOST_CHECK(mockProfilingConnection); + + // Redirect the standard output to a local stream so that we can parse the warning message + std::stringstream ss; + StreamRedirector streamRedirector(std::cout, ss.rdbuf()); // Calculate the size of a Stream Metadata packet std::string processName = GetProcessName().substr(0, 60); @@ -2408,15 +2245,15 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadConnectionAcknowledgedPacket) BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); profilingService.Update(); - // Redirect the output to a local stream so that we can parse the warning message - std::stringstream ss; - CoutRedirect coutRedirect(ss.rdbuf()); - // Wait for a bit to make sure that we get the packet std::this_thread::sleep_for(std::chrono::milliseconds(100)); + // Get the mock profiling connection + MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection(); + BOOST_CHECK(mockProfilingConnection); + // Check that the mock profiling connection contains one Stream Metadata packet - const std::vector& writtenData = mockProfilingConnection->GetWrittenData(); + const std::vector writtenData = mockProfilingConnection->GetWrittenData(); BOOST_TEST(writtenData.size() == 1); BOOST_TEST(writtenData[0] == streamMetadataPacketsize); @@ -2433,7 +2270,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadConnectionAcknowledgedPacket) uint32_t header = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16); - // Connection Acknowledged Packet + // Create the Connection Acknowledged Packet Packet connectionAcknowledgedPacket(header); // Write the packet to the mock profiling connection @@ -2441,23 +2278,23 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadConnectionAcknowledgedPacket) // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that // the Connection Acknowledged packet gets processed by the profiling service - std::this_thread::sleep_for(std::chrono::seconds(1)); + std::this_thread::sleep_for(std::chrono::seconds(2)); // Check that the expected error has occurred and logged to the standard output BOOST_CHECK(boost::contains(ss.str(), "Functor with requested PacketId=37 and Version=4194304 does not exist")); // The Connection Acknowledged Command Handler should not have updated the profiling state BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); + + // Reset the profiling service to stop any running thread + options.m_EnableProfiling = false; + profilingService.ResetExternalProfilingOptions(options, true); } BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket) { + // Swap the profiling connection factory in the profiling service instance with our mock one SwapProfilingConnectionFactoryHelper helper; - MockProfilingConnectionFactory* mockProfilingConnectionFactory = - boost::polymorphic_downcast(helper.GetMockProfilingConnectionFactory()); - BOOST_CHECK(mockProfilingConnectionFactory); - MockProfilingConnection* mockProfilingConnection = mockProfilingConnectionFactory->GetMockProfilingConnection(); - BOOST_CHECK(mockProfilingConnection); // Calculate the size of a Stream Metadata packet std::string processName = GetProcessName().substr(0, 60); @@ -2480,8 +2317,12 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket) // Wait for a bit to make sure that we get the packet std::this_thread::sleep_for(std::chrono::milliseconds(100)); + // Get the mock profiling connection + MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection(); + BOOST_CHECK(mockProfilingConnection); + // Check that the mock profiling connection contains one Stream Metadata packet - const std::vector& writtenData = mockProfilingConnection->GetWrittenData(); + const std::vector writtenData = mockProfilingConnection->GetWrittenData(); BOOST_TEST(writtenData.size() == 1); BOOST_TEST(writtenData[0] == streamMetadataPacketsize); @@ -2498,7 +2339,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket) uint32_t header = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16); - // Connection Acknowledged Packet + // Create the Connection Acknowledged Packet Packet connectionAcknowledgedPacket(header); // Write the packet to the mock profiling connection @@ -2506,10 +2347,14 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket) // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that // the Connection Acknowledged packet gets processed by the profiling service - std::this_thread::sleep_for(std::chrono::seconds(1)); + std::this_thread::sleep_for(std::chrono::seconds(2)); // The Connection Acknowledged Command Handler should have updated the profiling state accordingly BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); + + // Reset the profiling service to stop any running thread + options.m_EnableProfiling = false; + profilingService.ResetExternalProfilingOptions(options, true); } BOOST_AUTO_TEST_SUITE_END() diff --git a/src/profiling/test/ProfilingTests.hpp b/src/profiling/test/ProfilingTests.hpp new file mode 100644 index 0000000000..3e6cf63efe --- /dev/null +++ b/src/profiling/test/ProfilingTests.hpp @@ -0,0 +1,200 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "SendCounterPacketTests.hpp" + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace armnn +{ + +namespace profiling +{ + +struct LogLevelSwapper +{ +public: + LogLevelSwapper(armnn::LogSeverity severity) + { + // Set the new log level + armnn::ConfigureLogging(true, true, severity); + } + ~LogLevelSwapper() + { + // The default log level for unit tests is "Fatal" + armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal); + } +}; + +struct CoutRedirect +{ +public: + CoutRedirect(std::streambuf* newStreamBuffer) + : m_Old(std::cout.rdbuf(newStreamBuffer)) {} + ~CoutRedirect() { std::cout.rdbuf(m_Old); } + +private: + std::streambuf* m_Old; +}; + +struct StreamRedirector +{ +public: + StreamRedirector(std::ostream& stream, std::streambuf* newStreamBuffer) + : m_Stream(stream) + , m_BackupBuffer(m_Stream.rdbuf(newStreamBuffer)) + {} + ~StreamRedirector() { m_Stream.rdbuf(m_BackupBuffer); } + +private: + std::ostream& m_Stream; + std::streambuf* m_BackupBuffer; +}; + +class TestProfilingConnectionBase : public IProfilingConnection +{ +public: + TestProfilingConnectionBase() = default; + ~TestProfilingConnectionBase() = default; + + bool IsOpen() const override { return true; } + + void Close() override {} + + bool WritePacket(const unsigned char* buffer, uint32_t length) override { return false; } + + Packet ReadPacket(uint32_t timeout) override + { + std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); + + // Return connection acknowledged packet + std::unique_ptr packetData; + return Packet(65536, 0, packetData); + } +}; + +class TestProfilingConnectionTimeoutError : public TestProfilingConnectionBase +{ +public: + TestProfilingConnectionTimeoutError() + : m_ReadRequests(0) + {} + + Packet ReadPacket(uint32_t timeout) override + { + std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); + + if (m_ReadRequests < 3) + { + m_ReadRequests++; + throw armnn::TimeoutException("Simulate a timeout error\n"); + } + + // Return connection acknowledged packet after three timeouts + std::unique_ptr packetData; + return Packet(65536, 0, packetData); + } + +private: + int m_ReadRequests; +}; + +class TestProfilingConnectionArmnnError : public TestProfilingConnectionBase +{ +public: + Packet ReadPacket(uint32_t timeout) override + { + std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); + + throw armnn::Exception("Simulate a non-timeout error"); + } +}; + +class TestFunctorA : public CommandHandlerFunctor +{ +public: + using CommandHandlerFunctor::CommandHandlerFunctor; + + int GetCount() { return m_Count; } + + void operator()(const Packet& packet) override + { + m_Count++; + } + +private: + int m_Count = 0; +}; + +class TestFunctorB : public TestFunctorA +{ + using TestFunctorA::TestFunctorA; +}; + +class TestFunctorC : public TestFunctorA +{ + using TestFunctorA::TestFunctorA; +}; + +class MockProfilingConnectionFactory : public IProfilingConnectionFactory +{ +public: + IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override + { + return std::make_unique(); + } +}; + +class SwapProfilingConnectionFactoryHelper : public ProfilingService +{ +public: + using MockProfilingConnectionFactoryPtr = std::unique_ptr; + + SwapProfilingConnectionFactoryHelper() + : ProfilingService() + , m_MockProfilingConnectionFactory(new MockProfilingConnectionFactory()) + , m_BackupProfilingConnectionFactory(nullptr) + { + BOOST_CHECK(m_MockProfilingConnectionFactory); + SwapProfilingConnectionFactory(ProfilingService::Instance(), + m_MockProfilingConnectionFactory.get(), + m_BackupProfilingConnectionFactory); + BOOST_CHECK(m_BackupProfilingConnectionFactory); + } + ~SwapProfilingConnectionFactoryHelper() + { + BOOST_CHECK(m_BackupProfilingConnectionFactory); + IProfilingConnectionFactory* temp = nullptr; + SwapProfilingConnectionFactory(ProfilingService::Instance(), + m_BackupProfilingConnectionFactory, + temp); + } + + MockProfilingConnection* GetMockProfilingConnection() + { + IProfilingConnection* profilingConnection = GetProfilingConnection(ProfilingService::Instance()); + return boost::polymorphic_downcast(profilingConnection); + } + +private: + MockProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory; + IProfilingConnectionFactory* m_BackupProfilingConnectionFactory; +}; + +} // namespace profiling + +} // namespace armnn diff --git a/src/profiling/test/SendCounterPacketTests.cpp b/src/profiling/test/SendCounterPacketTests.cpp index 1216420383..00dad38078 100644 --- a/src/profiling/test/SendCounterPacketTests.cpp +++ b/src/profiling/test/SendCounterPacketTests.cpp @@ -2322,7 +2322,7 @@ BOOST_AUTO_TEST_CASE(SendThreadBufferTest1) BOOST_TEST(reservedBuffer.get()); // Check that data was actually written to the profiling connection in any order - const std::vector& writtenData = mockProfilingConnection.GetWrittenData(); + const std::vector writtenData = mockProfilingConnection.GetWrittenData(); BOOST_TEST(writtenData.size() == 3); bool foundStreamMetaDataPacket = std::find(writtenData.begin(), writtenData.end(), streamMetadataPacketsize) != writtenData.end(); @@ -2391,7 +2391,7 @@ BOOST_AUTO_TEST_CASE(SendThreadSendStreamMetadataPacket3) BOOST_CHECK_NO_THROW(sendCounterPacket.Stop()); // Check that the buffer contains one Stream Metadata packet - const std::vector& writtenData = mockProfilingConnection.GetWrittenData(); + const std::vector writtenData = mockProfilingConnection.GetWrittenData(); BOOST_TEST(writtenData.size() == 1); BOOST_TEST(writtenData[0] == streamMetadataPacketsize); } @@ -2420,7 +2420,7 @@ BOOST_AUTO_TEST_CASE(SendThreadSendStreamMetadataPacket4) BOOST_TEST((profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck)); // Check that the buffer contains one Stream Metadata packet - const std::vector& writtenData = mockProfilingConnection.GetWrittenData(); + const std::vector writtenData = mockProfilingConnection.GetWrittenData(); BOOST_TEST(writtenData.size() == 1); BOOST_TEST(writtenData[0] == streamMetadataPacketsize); diff --git a/src/profiling/test/SendCounterPacketTests.hpp b/src/profiling/test/SendCounterPacketTests.hpp index 48bab025dd..871ca74124 100644 --- a/src/profiling/test/SendCounterPacketTests.hpp +++ b/src/profiling/test/SendCounterPacketTests.hpp @@ -12,6 +12,7 @@ #include #include +#include #include namespace armnn @@ -19,6 +20,7 @@ namespace armnn namespace profiling { + class MockProfilingConnection : public IProfilingConnection { public: @@ -28,9 +30,19 @@ public: , m_Packet() {} - bool IsOpen() const override { return m_IsOpen; } + bool IsOpen() const override + { + std::lock_guard lock(m_Mutex); + + return m_IsOpen; + } + + void Close() override + { + std::lock_guard lock(m_Mutex); - void Close() override { m_IsOpen = false; } + m_IsOpen = false; + } bool WritePacket(const unsigned char* buffer, uint32_t length) override { @@ -39,11 +51,15 @@ public: return false; } + std::lock_guard lock(m_Mutex); + m_WrittenData.push_back(length); return true; } bool WritePacket(Packet&& packet) { + std::lock_guard lock(m_Mutex); + m_Packet = std::move(packet); return true; } @@ -51,19 +67,32 @@ public: Packet ReadPacket(uint32_t timeout) override { // Simulate a delay in the reading process - std::this_thread::sleep_for(std::chrono::milliseconds(500)); + std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); + + std::lock_guard lock(m_Mutex); return std::move(m_Packet); } - const std::vector& GetWrittenData() const { return m_WrittenData; } + const std::vector GetWrittenData() const + { + std::lock_guard lock(m_Mutex); + + return m_WrittenData; + } + + void Clear() + { + std::lock_guard lock(m_Mutex); - void Clear() { m_WrittenData.clear(); } + m_WrittenData.clear(); + } private: bool m_IsOpen; std::vector m_WrittenData; Packet m_Packet; + mutable std::mutex m_Mutex; }; class MockPacketBuffer : public IPacketBuffer @@ -162,7 +191,7 @@ public: IPacketBufferPtr Reserve(unsigned int requestedSize, unsigned int& reservedSize) override { - std::unique_lock lock(m_Mutex); + std::lock_guard lock(m_Mutex); reservedSize = 0; if (requestedSize > m_MaxBufferSize) @@ -176,7 +205,7 @@ public: void Commit(IPacketBufferPtr& packetBuffer, unsigned int size) override { - std::unique_lock lock(m_Mutex); + std::lock_guard lock(m_Mutex); packetBuffer->Commit(size); m_BufferList.push_back(std::move(packetBuffer)); @@ -185,14 +214,14 @@ public: void Release(IPacketBufferPtr& packetBuffer) override { - std::unique_lock lock(m_Mutex); + std::lock_guard lock(m_Mutex); packetBuffer->Release(); } IPacketBufferPtr GetReadableBuffer() override { - std::unique_lock lock(m_Mutex); + std::lock_guard lock(m_Mutex); if (m_BufferList.empty()) { @@ -206,7 +235,7 @@ public: void MarkRead(IPacketBufferPtr& packetBuffer) override { - std::unique_lock lock(m_Mutex); + std::lock_guard lock(m_Mutex); m_ReadSize += packetBuffer->GetSize(); packetBuffer->MarkRead(); -- cgit v1.2.1