diff options
author | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-10-02 12:50:57 +0100 |
---|---|---|
committer | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-10-08 15:53:43 +0100 |
commit | 54fb957c9640d61ab575d7acfc4c430a15123315 (patch) | |
tree | 51ce829032913af068071be0dcfff7c7bef409b7 /src/profiling | |
parent | c4728ad356b73915588c971f6de38f4493078397 (diff) | |
download | armnn-54fb957c9640d61ab575d7acfc4c430a15123315.tar.gz |
IVGCVSW-3937 Add the necessary components to the ProfilingService class to
process a connection to an external profiling service (e.g. gatord)
* Added the required components (CommandHandlerRegistry, CommandHandler,
SendCounterPacket, ...) to the ProfilingService class
* Reworked the ProfilingService::Run procedure and renamed it to Update
* Handling all states but Active in the Run method (future work)
* Updated the unit and tests accordingly
* Added component tests to check that the Connection Acknowledged packet
is handled correctly
* Added test util classes, made the default constructor/destructor protected
to superclass a ProfilingService object
* Added IProfilingConnectionFactory interface
Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
Change-Id: I010d94b18980c9e6394253f4b2bbe4fe5bb3fe4f
Diffstat (limited to 'src/profiling')
-rw-r--r-- | src/profiling/CommandHandler.cpp | 6 | ||||
-rw-r--r-- | src/profiling/IProfilingConnection.hpp | 2 | ||||
-rw-r--r-- | src/profiling/IProfilingConnectionFactory.hpp | 33 | ||||
-rw-r--r-- | src/profiling/Packet.hpp | 10 | ||||
-rw-r--r-- | src/profiling/ProfilingConnectionDumpToFileDecorator.cpp | 2 | ||||
-rw-r--r-- | src/profiling/ProfilingConnectionDumpToFileDecorator.hpp | 2 | ||||
-rw-r--r-- | src/profiling/ProfilingConnectionFactory.hpp | 7 | ||||
-rw-r--r-- | src/profiling/ProfilingService.cpp | 92 | ||||
-rw-r--r-- | src/profiling/ProfilingService.hpp | 67 | ||||
-rw-r--r-- | src/profiling/ProfilingStateMachine.cpp | 26 | ||||
-rw-r--r-- | src/profiling/SocketProfilingConnection.cpp | 2 | ||||
-rw-r--r-- | src/profiling/SocketProfilingConnection.hpp | 2 | ||||
-rw-r--r-- | src/profiling/test/ProfilingConnectionDumpToFileDecoratorTests.cpp | 2 | ||||
-rw-r--r-- | src/profiling/test/ProfilingTests.cpp | 257 | ||||
-rw-r--r-- | src/profiling/test/SendCounterPacketTests.hpp | 18 |
15 files changed, 447 insertions, 81 deletions
diff --git a/src/profiling/CommandHandler.cpp b/src/profiling/CommandHandler.cpp index 49784056bf..86fa2571df 100644 --- a/src/profiling/CommandHandler.cpp +++ b/src/profiling/CommandHandler.cpp @@ -54,8 +54,12 @@ void CommandHandler::HandleCommands(IProfilingConnection& profilingConnection) m_KeepRunning.store(false, std::memory_order_relaxed); } } - catch (...) + catch (const Exception& e) { + // Log the error + BOOST_LOG_TRIVIAL(warning) << "An error has occurred when handling a command: " + << e.what(); + // Might want to differentiate the errors more m_KeepRunning.store(false); } diff --git a/src/profiling/IProfilingConnection.hpp b/src/profiling/IProfilingConnection.hpp index 97f7b55477..5d6a352f1d 100644 --- a/src/profiling/IProfilingConnection.hpp +++ b/src/profiling/IProfilingConnection.hpp @@ -20,7 +20,7 @@ class IProfilingConnection public: virtual ~IProfilingConnection() {} - virtual bool IsOpen() = 0; + virtual bool IsOpen() const = 0; virtual void Close() = 0; diff --git a/src/profiling/IProfilingConnectionFactory.hpp b/src/profiling/IProfilingConnectionFactory.hpp new file mode 100644 index 0000000000..173421092e --- /dev/null +++ b/src/profiling/IProfilingConnectionFactory.hpp @@ -0,0 +1,33 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "IProfilingConnection.hpp" + +#include <Runtime.hpp> + +#include <memory> + +namespace armnn +{ + +namespace profiling +{ + +class IProfilingConnectionFactory +{ +public: + using ExternalProfilingOptions = Runtime::CreationOptions::ExternalProfilingOptions; + using IProfilingConnectionPtr = std::unique_ptr<IProfilingConnection>; + + virtual ~IProfilingConnectionFactory() {} + + virtual IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const = 0; +}; + +} // namespace profiling + +} // namespace armnn diff --git a/src/profiling/Packet.hpp b/src/profiling/Packet.hpp index 7d70a48366..2aae14b741 100644 --- a/src/profiling/Packet.hpp +++ b/src/profiling/Packet.hpp @@ -23,6 +23,15 @@ public: , m_Data(nullptr) {} + Packet(uint32_t header) + : m_Header(header) + , m_Length(0) + , m_Data(nullptr) + { + m_PacketId = ((header >> 16) & 1023); + m_PacketFamily = (header >> 26); + } + Packet(uint32_t header, uint32_t length, std::unique_ptr<char[]>& data) : m_Header(header) , m_Length(length) @@ -47,6 +56,7 @@ public: Packet(const Packet& other) = delete; Packet& operator=(const Packet&) = delete; + Packet& operator=(Packet&&) = default; uint32_t GetHeader() const; uint32_t GetPacketFamily() const; diff --git a/src/profiling/ProfilingConnectionDumpToFileDecorator.cpp b/src/profiling/ProfilingConnectionDumpToFileDecorator.cpp index cf427626ef..3d4b6bf927 100644 --- a/src/profiling/ProfilingConnectionDumpToFileDecorator.cpp +++ b/src/profiling/ProfilingConnectionDumpToFileDecorator.cpp @@ -34,7 +34,7 @@ ProfilingConnectionDumpToFileDecorator::~ProfilingConnectionDumpToFileDecorator( Close(); } -bool ProfilingConnectionDumpToFileDecorator::IsOpen() +bool ProfilingConnectionDumpToFileDecorator::IsOpen() const { return m_Connection->IsOpen(); } diff --git a/src/profiling/ProfilingConnectionDumpToFileDecorator.hpp b/src/profiling/ProfilingConnectionDumpToFileDecorator.hpp index 95dbe55641..c2ae538138 100644 --- a/src/profiling/ProfilingConnectionDumpToFileDecorator.hpp +++ b/src/profiling/ProfilingConnectionDumpToFileDecorator.hpp @@ -49,7 +49,7 @@ public: ~ProfilingConnectionDumpToFileDecorator(); - bool IsOpen() override; + bool IsOpen() const override; void Close() override; diff --git a/src/profiling/ProfilingConnectionFactory.hpp b/src/profiling/ProfilingConnectionFactory.hpp index 102c82070e..c4b10c6445 100644 --- a/src/profiling/ProfilingConnectionFactory.hpp +++ b/src/profiling/ProfilingConnectionFactory.hpp @@ -5,7 +5,7 @@ #pragma once -#include "IProfilingConnection.hpp" +#include "IProfilingConnectionFactory.hpp" #include <Runtime.hpp> @@ -17,14 +17,13 @@ namespace armnn namespace profiling { -class ProfilingConnectionFactory final +class ProfilingConnectionFactory final : public IProfilingConnectionFactory { public: ProfilingConnectionFactory() = default; ~ProfilingConnectionFactory() = default; - std::unique_ptr<IProfilingConnection> GetProfilingConnection( - const Runtime::CreationOptions::ExternalProfilingOptions& options) const; + IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override; }; } // namespace profiling diff --git a/src/profiling/ProfilingService.cpp b/src/profiling/ProfilingService.cpp index 2da0f79da2..19cf9cb58e 100644 --- a/src/profiling/ProfilingService.cpp +++ b/src/profiling/ProfilingService.cpp @@ -20,37 +20,76 @@ void ProfilingService::ResetExternalProfilingOptions(const ExternalProfilingOpti // Update the profiling options m_Options = options; + // Check if the profiling service needs to be reset if (resetProfilingService) { // Reset the profiling service - m_CounterDirectory.Clear(); - m_ProfilingConnection.reset(); - m_StateMachine.Reset(); - m_CounterIndex.clear(); - m_CounterValues.clear(); + Reset(); } - - // Re-initialize the profiling service - Initialize(); } -void ProfilingService::Run() +void ProfilingService::Update() { - if (m_StateMachine.GetCurrentState() == ProfilingState::Uninitialised) + if (!m_Options.m_EnableProfiling) { - Initialize(); + // Don't run if profiling is disabled + return; } - else if (m_StateMachine.GetCurrentState() == ProfilingState::NotConnected) + + ProfilingState currentState = m_StateMachine.GetCurrentState(); + switch (currentState) { + case ProfilingState::Uninitialised: + // Initialize the profiling service + Initialize(); + + // Move to the next state + m_StateMachine.TransitionToState(ProfilingState::NotConnected); + break; + case ProfilingState::NotConnected: + BOOST_ASSERT(m_ProfilingConnectionFactory); + + // Reset any existing profiling connection + m_ProfilingConnection.reset(); + try { - m_ProfilingConnectionFactory.GetProfilingConnection(m_Options); - m_StateMachine.TransitionToState(ProfilingState::WaitingForAck); + // Setup the profiling connection + //m_ProfilingConnection = m_ProfilingConnectionFactory.GetProfilingConnection(m_Options); + m_ProfilingConnection = m_ProfilingConnectionFactory->GetProfilingConnection(m_Options); } - catch (const armnn::Exception& e) + catch (const Exception& e) { - std::cerr << e.what() << std::endl; + BOOST_LOG_TRIVIAL(warning) << "An error has occurred when creating the profiling connection: " + << e.what(); } + + // Move to the next state + m_StateMachine.TransitionToState(m_ProfilingConnection + ? ProfilingState::WaitingForAck // Profiling connection obtained, wait for ack + : ProfilingState::NotConnected); // Profiling connection failed, stay in the + // "NotConnected" state + break; + case ProfilingState::WaitingForAck: + BOOST_ASSERT(m_ProfilingConnection); + + // Start the command thread + m_CommandHandler.Start(*m_ProfilingConnection); + + // Start the send thread, while in "WaitingForAck" state it'll send out a "Stream MetaData" packet waiting for + // a valid "Connection Acknowledged" packet confirming the connection + m_SendCounterPacket.Start(*m_ProfilingConnection); + + // The connection acknowledged command handler will automatically transition the state to "Active" once a + // valid "Connection Acknowledged" packet has been received + + break; + case ProfilingState::Active: + + break; + default: + throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1") + % static_cast<int>(currentState))); } } @@ -119,12 +158,6 @@ uint32_t ProfilingService::DecrementCounterValue(uint16_t counterUid) void ProfilingService::Initialize() { - if (!m_Options.m_EnableProfiling) - { - // Skip the initialization if profiling is disabled - return; - } - // Register a category for the basic runtime counters if (!m_CounterDirectory.IsCategoryRegistered("ArmNN_Runtime")) { @@ -175,9 +208,6 @@ void ProfilingService::Initialize() BOOST_ASSERT(inferencesRunCounter); InitializeCounterValue(inferencesRunCounter->m_Uid); } - - // Initialization is done, update the profiling service state - m_StateMachine.TransitionToState(ProfilingState::NotConnected); } void ProfilingService::InitializeCounterValue(uint16_t counterUid) @@ -196,6 +226,18 @@ void ProfilingService::InitializeCounterValue(uint16_t counterUid) m_CounterIndex.at(counterUid) = counterValuePtr; } +void ProfilingService::Reset() +{ + // Reset the profiling service + m_CounterDirectory.Clear(); + m_ProfilingConnection.reset(); + m_StateMachine.Reset(); + m_CounterIndex.clear(); + m_CounterValues.clear(); + m_CommandHandler.Stop(); + m_SendCounterPacket.Stop(false); +} + } // namespace profiling } // namespace armnn diff --git a/src/profiling/ProfilingService.hpp b/src/profiling/ProfilingService.hpp index b4cdcac76e..50a938e33d 100644 --- a/src/profiling/ProfilingService.hpp +++ b/src/profiling/ProfilingService.hpp @@ -9,6 +9,10 @@ #include "ProfilingConnectionFactory.hpp" #include "CounterDirectory.hpp" #include "ICounterValues.hpp" +#include "CommandHandler.hpp" +#include "BufferManager.hpp" +#include "SendCounterPacket.hpp" +#include "ConnectionAcknowledgedCommandHandler.hpp" namespace armnn { @@ -16,10 +20,11 @@ namespace armnn namespace profiling { -class ProfilingService final : public IReadWriteCounterValues +class ProfilingService : public IReadWriteCounterValues { public: using ExternalProfilingOptions = Runtime::CreationOptions::ExternalProfilingOptions; + using IProfilingConnectionFactoryPtr = std::unique_ptr<IProfilingConnectionFactory>; using IProfilingConnectionPtr = std::unique_ptr<IProfilingConnection>; using CounterIndices = std::vector<std::atomic<uint32_t>*>; using CounterValues = std::list<std::atomic<uint32_t>>; @@ -34,8 +39,8 @@ public: // Resets the profiling options, optionally clears the profiling service entirely void ResetExternalProfilingOptions(const ExternalProfilingOptions& options, bool resetProfilingService = false); - // Runs the profiling service - void Run(); + // Updates the profiling service, making it transition to a new state if necessary + void Update(); // Getters for the profiling service state const ICounterDirectory& GetCounterDirectory() const; @@ -51,26 +56,70 @@ public: uint32_t DecrementCounterValue(uint16_t counterUid) override; private: - // Default/copy/move constructors/destructors and copy/move assignment operators are kept private - ProfilingService() = default; + // Copy/move constructors/destructors and copy/move assignment operators are deleted ProfilingService(const ProfilingService&) = delete; ProfilingService(ProfilingService&&) = delete; ProfilingService& operator=(const ProfilingService&) = delete; ProfilingService& operator=(ProfilingService&&) = delete; - ~ProfilingService() = default; - // Initialization functions + // Initialization/reset functions void Initialize(); void InitializeCounterValue(uint16_t counterUid); + void Reset(); - // Profiling service state variables + // Profiling service components ExternalProfilingOptions m_Options; CounterDirectory m_CounterDirectory; - ProfilingConnectionFactory m_ProfilingConnectionFactory; + IProfilingConnectionFactoryPtr m_ProfilingConnectionFactory; IProfilingConnectionPtr m_ProfilingConnection; ProfilingStateMachine m_StateMachine; CounterIndices m_CounterIndex; CounterValues m_CounterValues; + CommandHandlerRegistry m_CommandHandlerRegistry; + PacketVersionResolver m_PacketVersionResolver; + CommandHandler m_CommandHandler; + BufferManager m_BufferManager; + SendCounterPacket m_SendCounterPacket; + ConnectionAcknowledgedCommandHandler m_ConnectionAcknowledgedCommandHandler; + +protected: + // Default constructor/destructor kept protected for testing + ProfilingService() + : m_Options() + , m_CounterDirectory() + , m_ProfilingConnectionFactory(new ProfilingConnectionFactory()) + , m_ProfilingConnection() + , m_StateMachine() + , m_CounterIndex() + , m_CounterValues() + , m_CommandHandlerRegistry() + , m_PacketVersionResolver() + , m_CommandHandler(1000, + false, + m_CommandHandlerRegistry, + m_PacketVersionResolver) + , m_BufferManager() + , m_SendCounterPacket(m_StateMachine, m_BufferManager) + , m_ConnectionAcknowledgedCommandHandler(1, + m_PacketVersionResolver.ResolvePacketVersion(1).GetEncodedValue(), + m_StateMachine) + { + // Register the "Connection Acknowledged" command handler + m_CommandHandlerRegistry.RegisterFunctor(&m_ConnectionAcknowledgedCommandHandler); + } + ~ProfilingService() = default; + + // Protected method for testing + void SwapProfilingConnectionFactory(ProfilingService& instance, + IProfilingConnectionFactory* other, + IProfilingConnectionFactory*& backup) + { + BOOST_ASSERT(instance.m_ProfilingConnectionFactory); + BOOST_ASSERT(other); + + backup = instance.m_ProfilingConnectionFactory.release(); + instance.m_ProfilingConnectionFactory.reset(other); + } }; } // namespace profiling diff --git a/src/profiling/ProfilingStateMachine.cpp b/src/profiling/ProfilingStateMachine.cpp index 5af5bfbed0..9d3a81f64a 100644 --- a/src/profiling/ProfilingStateMachine.cpp +++ b/src/profiling/ProfilingStateMachine.cpp @@ -35,50 +35,50 @@ ProfilingState ProfilingStateMachine::GetCurrentState() const void ProfilingStateMachine::TransitionToState(ProfilingState newState) { - ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed); + ProfilingState currentState = m_State.load(std::memory_order::memory_order_relaxed); switch (newState) { case ProfilingState::Uninitialised: do { - if (!IsOneOfStates(expectedState, ProfilingState::Uninitialised)) + if (!IsOneOfStates(currentState, ProfilingState::Uninitialised)) { - ThrowStateTransitionException(expectedState, newState); + ThrowStateTransitionException(currentState, newState); } } - while (!m_State.compare_exchange_strong(expectedState, newState, std::memory_order::memory_order_relaxed)); + while (!m_State.compare_exchange_strong(currentState, newState, std::memory_order::memory_order_relaxed)); break; case ProfilingState::NotConnected: do { - if (!IsOneOfStates(expectedState, ProfilingState::Uninitialised, ProfilingState::NotConnected, + if (!IsOneOfStates(currentState, ProfilingState::Uninitialised, ProfilingState::NotConnected, ProfilingState::Active)) { - ThrowStateTransitionException(expectedState, newState); + ThrowStateTransitionException(currentState, newState); } } - while (!m_State.compare_exchange_strong(expectedState, newState, std::memory_order::memory_order_relaxed)); + while (!m_State.compare_exchange_strong(currentState, newState, std::memory_order::memory_order_relaxed)); break; case ProfilingState::WaitingForAck: do { - if (!IsOneOfStates(expectedState, ProfilingState::NotConnected, ProfilingState::WaitingForAck)) + if (!IsOneOfStates(currentState, ProfilingState::NotConnected, ProfilingState::WaitingForAck)) { - ThrowStateTransitionException(expectedState, newState); + ThrowStateTransitionException(currentState, newState); } } - while (!m_State.compare_exchange_strong(expectedState, newState, std::memory_order::memory_order_relaxed)); + while (!m_State.compare_exchange_strong(currentState, newState, std::memory_order::memory_order_relaxed)); break; case ProfilingState::Active: do { - if (!IsOneOfStates(expectedState, ProfilingState::WaitingForAck, ProfilingState::Active)) + if (!IsOneOfStates(currentState, ProfilingState::WaitingForAck, ProfilingState::Active)) { - ThrowStateTransitionException(expectedState, newState); + ThrowStateTransitionException(currentState, newState); } } - while (!m_State.compare_exchange_strong(expectedState, newState, std::memory_order::memory_order_relaxed)); + while (!m_State.compare_exchange_strong(currentState, newState, std::memory_order::memory_order_relaxed)); break; default: break; diff --git a/src/profiling/SocketProfilingConnection.cpp b/src/profiling/SocketProfilingConnection.cpp index 6955f70a48..0ae7b0e1fe 100644 --- a/src/profiling/SocketProfilingConnection.cpp +++ b/src/profiling/SocketProfilingConnection.cpp @@ -50,7 +50,7 @@ SocketProfilingConnection::SocketProfilingConnection() } } -bool SocketProfilingConnection::IsOpen() +bool SocketProfilingConnection::IsOpen() const { return m_Socket[0].fd > 0; } diff --git a/src/profiling/SocketProfilingConnection.hpp b/src/profiling/SocketProfilingConnection.hpp index 1ae9f17f7e..7c77a8bfc9 100644 --- a/src/profiling/SocketProfilingConnection.hpp +++ b/src/profiling/SocketProfilingConnection.hpp @@ -19,7 +19,7 @@ class SocketProfilingConnection : public IProfilingConnection { public: SocketProfilingConnection(); - bool IsOpen() final; + bool IsOpen() const final; void Close() final; bool WritePacket(const unsigned char* buffer, uint32_t length) final; Packet ReadPacket(uint32_t timeout) final; diff --git a/src/profiling/test/ProfilingConnectionDumpToFileDecoratorTests.cpp b/src/profiling/test/ProfilingConnectionDumpToFileDecoratorTests.cpp index 3e06cb353b..fac93c5ddf 100644 --- a/src/profiling/test/ProfilingConnectionDumpToFileDecoratorTests.cpp +++ b/src/profiling/test/ProfilingConnectionDumpToFileDecoratorTests.cpp @@ -41,7 +41,7 @@ public: ~DummyProfilingConnection() = default; - bool IsOpen() override + bool IsOpen() const override { return m_Open; } diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index 24ab779412..de92fb9eb0 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -27,6 +27,9 @@ #include <armnn/Conversion.hpp> +#include <Logging.hpp> +#include <armnn/Utils.hpp> + #include <boost/algorithm/string.hpp> #include <boost/numeric/conversion/cast.hpp> #include <boost/test/unit_test.hpp> @@ -97,18 +100,19 @@ public: TestProfilingConnectionBase() = default; ~TestProfilingConnectionBase() = default; - bool IsOpen() { return true; } + bool IsOpen() const override { return true; } - void Close() {} + void Close() override {} - bool WritePacket(const unsigned char* buffer, uint32_t length) { return false; } + bool WritePacket(const unsigned char* buffer, uint32_t length) override { return false; } - Packet ReadPacket(uint32_t timeout) + Packet ReadPacket(uint32_t timeout) override { std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); std::unique_ptr<char[]> packetData; - //Return connection acknowledged packet - return {65536 ,0 , packetData}; + + // Return connection acknowledged packet + return { 65536, 0, packetData }; } }; @@ -119,12 +123,13 @@ public: if (readRequests < 3) { readRequests++; - throw armnn::TimeoutException(": Simulate a timeout"); + throw armnn::TimeoutException("Simulate a timeout"); } std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); std::unique_ptr<char[]> packetData; - //Return connection acknowledged packet after three timeouts - return {65536 ,0 , packetData}; + + // Return connection acknowledged packet after three timeouts + return { 65536, 0, packetData }; } private: @@ -655,15 +660,31 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceDisabled) ProfilingService& profilingService = ProfilingService::Instance(); profilingService.ResetExternalProfilingOptions(options, true); BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); - profilingService.Run(); + profilingService.Update(); BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); } -struct cerr_redirect +struct LogLevelSwapper { - cerr_redirect(std::streambuf* new_buffer) - : old(std::cerr.rdbuf(new_buffer)) {} - ~cerr_redirect() { std::cerr.rdbuf(old); } +public: + LogLevelSwapper(armnn::LogSeverity severity) + { + // Set the new log level + armnn::ConfigureLogging(true, true, severity); + } + ~LogLevelSwapper() + { + // The default log level for unit tests is "Fatal" + armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal); + } +}; + +struct CoutRedirect +{ +public: + CoutRedirect(std::streambuf* newStreamBuffer) + : old(std::cout.rdbuf(newStreamBuffer)) {} + ~CoutRedirect() { std::cout.rdbuf(old); } private: std::streambuf* old; @@ -671,35 +692,45 @@ private: BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabled) { + // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output + LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning); + armnn::Runtime::CreationOptions::ExternalProfilingOptions options; options.m_EnableProfiling = true; ProfilingService& profilingService = ProfilingService::Instance(); profilingService.ResetExternalProfilingOptions(options, true); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); + profilingService.Update(); BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); - // As there is no daemon running a connection cannot be made so expect a std::cerr to console + // Redirect the output to a local stream so that we can parse the warning message std::stringstream ss; - cerr_redirect guard(ss.rdbuf()); - profilingService.Run(); + CoutRedirect coutRedirect(ss.rdbuf()); + profilingService.Update(); BOOST_CHECK(boost::contains(ss.str(), "Cannot connect to stream socket: Connection refused")); } BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabledRuntime) { + // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output + LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning); + armnn::Runtime::CreationOptions::ExternalProfilingOptions options; ProfilingService& profilingService = ProfilingService::Instance(); profilingService.ResetExternalProfilingOptions(options, true); BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); - profilingService.Run(); + profilingService.Update(); BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); options.m_EnableProfiling = true; profilingService.ResetExternalProfilingOptions(options); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); + profilingService.Update(); BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); - // As there is no daemon running a connection cannot be made so expect a std::cerr to console + // Redirect the output to a local stream so that we can parse the warning message std::stringstream ss; - cerr_redirect guard(ss.rdbuf()); - profilingService.Run(); + CoutRedirect coutRedirect(ss.rdbuf()); + profilingService.Update(); BOOST_CHECK(boost::contains(ss.str(), "Cannot connect to stream socket: Connection refused")); } @@ -711,11 +742,15 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceCounterDirectory) const ICounterDirectory& counterDirectory0 = profilingService.GetCounterDirectory(); BOOST_CHECK(counterDirectory0.GetCounterCount() == 0); + profilingService.Update(); + BOOST_CHECK(counterDirectory0.GetCounterCount() == 0); options.m_EnableProfiling = true; profilingService.ResetExternalProfilingOptions(options); const ICounterDirectory& counterDirectory1 = profilingService.GetCounterDirectory(); + BOOST_CHECK(counterDirectory1.GetCounterCount() == 0); + profilingService.Update(); BOOST_CHECK(counterDirectory1.GetCounterCount() != 0); } @@ -726,6 +761,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceCounterValues) ProfilingService& profilingService = ProfilingService::Instance(); profilingService.ResetExternalProfilingOptions(options, true); + profilingService.Update(); const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory(); const Counters& counters = counterDirectory.GetCounters(); BOOST_CHECK(!counters.empty()); @@ -2297,4 +2333,183 @@ BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest1) BOOST_TEST(categoryRecordOffset == 44); } +class MockProfilingConnectionFactory : public IProfilingConnectionFactory +{ +public: + MockProfilingConnectionFactory() + : m_MockProfilingConnection(new MockProfilingConnection()) + {} + + IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override + { + return std::unique_ptr<MockProfilingConnection>(m_MockProfilingConnection); + } + + MockProfilingConnection* GetMockProfilingConnection() { return m_MockProfilingConnection; } + +private: + MockProfilingConnection* m_MockProfilingConnection; +}; + +class SwapProfilingConnectionFactoryHelper : public ProfilingService +{ +public: + SwapProfilingConnectionFactoryHelper() + : ProfilingService() + , m_MockProfilingConnectionFactory(new MockProfilingConnectionFactory()) + , m_BackupProfilingConnectionFactory(nullptr) + { + SwapProfilingConnectionFactory(ProfilingService::Instance(), + m_MockProfilingConnectionFactory.get(), + m_BackupProfilingConnectionFactory); + } + ~SwapProfilingConnectionFactoryHelper() + { + IProfilingConnectionFactory* temp = nullptr; + SwapProfilingConnectionFactory(ProfilingService::Instance(), + m_BackupProfilingConnectionFactory, + temp); + } + + IProfilingConnectionFactory* GetMockProfilingConnectionFactory() { return m_MockProfilingConnectionFactory.get(); } + +private: + IProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory; + IProfilingConnectionFactory* m_BackupProfilingConnectionFactory; +}; + +BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadConnectionAcknowledgedPacket) +{ + // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output + LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning); + + SwapProfilingConnectionFactoryHelper helper; + MockProfilingConnectionFactory* mockProfilingConnectionFactory = + boost::polymorphic_downcast<MockProfilingConnectionFactory*>(helper.GetMockProfilingConnectionFactory()); + BOOST_CHECK(mockProfilingConnectionFactory); + MockProfilingConnection* mockProfilingConnection = mockProfilingConnectionFactory->GetMockProfilingConnection(); + BOOST_CHECK(mockProfilingConnection); + + // Calculate the size of a Stream Metadata packet + std::string processName = GetProcessName().substr(0, 60); + unsigned int processNameSize = processName.empty() ? 0 : boost::numeric_cast<unsigned int>(processName.size()) + 1; + unsigned int streamMetadataPacketsize = 118 + processNameSize; + + armnn::Runtime::CreationOptions::ExternalProfilingOptions options; + options.m_EnableProfiling = true; + ProfilingService& profilingService = ProfilingService::Instance(); + profilingService.ResetExternalProfilingOptions(options, true); + + // Bring the profiling service to the "WaitingForAck" state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); + profilingService.Update(); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); + profilingService.Update(); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); + profilingService.Update(); + + // Redirect the output to a local stream so that we can parse the warning message + std::stringstream ss; + CoutRedirect coutRedirect(ss.rdbuf()); + + // Wait for a bit to make sure that we get the packet + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // Check that the mock profiling connection contains one Stream Metadata packet + const std::vector<uint32_t>& writtenData = mockProfilingConnection->GetWrittenData(); + BOOST_TEST(writtenData.size() == 1); + BOOST_TEST(writtenData[0] == streamMetadataPacketsize); + + // Write a valid "Connection Acknowledged" packet into the mock profiling connection, to simulate a valid + // reply from an external profiling service + + // Connection Acknowledged Packet header (word 0, word 1 is always zero): + // 26:31 [6] packet_family: Control Packet Family, value 0b000000 + // 16:25 [10] packet_id: Packet identifier, value 0b0000000001 + // 8:15 [8] reserved: Reserved, value 0b00000000 + // 0:7 [8] reserved: Reserved, value 0b00000000 + uint32_t packetFamily = 0; + uint32_t packetId = 37; // Wrong packet id!!! + uint32_t header = ((packetFamily & 0x0000003F) << 26) | + ((packetId & 0x000003FF) << 16); + + // Connection Acknowledged Packet + Packet connectionAcknowledgedPacket(header); + + // Write the packet to the mock profiling connection + mockProfilingConnection->WritePacket(std::move(connectionAcknowledgedPacket)); + + // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that + // the Connection Acknowledged packet gets processed by the profiling service + std::this_thread::sleep_for(std::chrono::seconds(1)); + + // Check that the expected error has occurred and logged to the standard output + BOOST_CHECK(boost::contains(ss.str(), "Functor with requested PacketId=37 and Version=4194304 does not exist")); + + // The Connection Acknowledged Command Handler should not have updated the profiling state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); +} + +BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket) +{ + SwapProfilingConnectionFactoryHelper helper; + MockProfilingConnectionFactory* mockProfilingConnectionFactory = + boost::polymorphic_downcast<MockProfilingConnectionFactory*>(helper.GetMockProfilingConnectionFactory()); + BOOST_CHECK(mockProfilingConnectionFactory); + MockProfilingConnection* mockProfilingConnection = mockProfilingConnectionFactory->GetMockProfilingConnection(); + BOOST_CHECK(mockProfilingConnection); + + // Calculate the size of a Stream Metadata packet + std::string processName = GetProcessName().substr(0, 60); + unsigned int processNameSize = processName.empty() ? 0 : boost::numeric_cast<unsigned int>(processName.size()) + 1; + unsigned int streamMetadataPacketsize = 118 + processNameSize; + + armnn::Runtime::CreationOptions::ExternalProfilingOptions options; + options.m_EnableProfiling = true; + ProfilingService& profilingService = ProfilingService::Instance(); + profilingService.ResetExternalProfilingOptions(options, true); + + // Bring the profiling service to the "WaitingForAck" state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); + profilingService.Update(); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); + profilingService.Update(); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); + profilingService.Update(); + + // Wait for a bit to make sure that we get the packet + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // Check that the mock profiling connection contains one Stream Metadata packet + const std::vector<uint32_t>& writtenData = mockProfilingConnection->GetWrittenData(); + BOOST_TEST(writtenData.size() == 1); + BOOST_TEST(writtenData[0] == streamMetadataPacketsize); + + // Write a valid "Connection Acknowledged" packet into the mock profiling connection, to simulate a valid + // reply from an external profiling service + + // Connection Acknowledged Packet header (word 0, word 1 is always zero): + // 26:31 [6] packet_family: Control Packet Family, value 0b000000 + // 16:25 [10] packet_id: Packet identifier, value 0b0000000001 + // 8:15 [8] reserved: Reserved, value 0b00000000 + // 0:7 [8] reserved: Reserved, value 0b00000000 + uint32_t packetFamily = 0; + uint32_t packetId = 1; + uint32_t header = ((packetFamily & 0x0000003F) << 26) | + ((packetId & 0x000003FF) << 16); + + // Connection Acknowledged Packet + Packet connectionAcknowledgedPacket(header); + + // Write the packet to the mock profiling connection + mockProfilingConnection->WritePacket(std::move(connectionAcknowledgedPacket)); + + // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that + // the Connection Acknowledged packet gets processed by the profiling service + std::this_thread::sleep_for(std::chrono::seconds(1)); + + // The Connection Acknowledged Command Handler should have updated the profiling state accordingly + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/src/profiling/test/SendCounterPacketTests.hpp b/src/profiling/test/SendCounterPacketTests.hpp index cae02b064d..48bab025dd 100644 --- a/src/profiling/test/SendCounterPacketTests.hpp +++ b/src/profiling/test/SendCounterPacketTests.hpp @@ -24,9 +24,11 @@ class MockProfilingConnection : public IProfilingConnection public: MockProfilingConnection() : m_IsOpen(true) + , m_WrittenData() + , m_Packet() {} - bool IsOpen() override { return m_IsOpen; } + bool IsOpen() const override { return m_IsOpen; } void Close() override { m_IsOpen = false; } @@ -40,8 +42,19 @@ public: m_WrittenData.push_back(length); return true; } + bool WritePacket(Packet&& packet) + { + m_Packet = std::move(packet); + return true; + } - Packet ReadPacket(uint32_t timeout) override { return Packet(); } + Packet ReadPacket(uint32_t timeout) override + { + // Simulate a delay in the reading process + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + + return std::move(m_Packet); + } const std::vector<uint32_t>& GetWrittenData() const { return m_WrittenData; } @@ -50,6 +63,7 @@ public: private: bool m_IsOpen; std::vector<uint32_t> m_WrittenData; + Packet m_Packet; }; class MockPacketBuffer : public IPacketBuffer |