aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2019-10-02 12:50:57 +0100
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-10-08 15:53:43 +0100
commit54fb957c9640d61ab575d7acfc4c430a15123315 (patch)
tree51ce829032913af068071be0dcfff7c7bef409b7
parentc4728ad356b73915588c971f6de38f4493078397 (diff)
downloadarmnn-54fb957c9640d61ab575d7acfc4c430a15123315.tar.gz
IVGCVSW-3937 Add the necessary components to the ProfilingService class to
process a connection to an external profiling service (e.g. gatord) * Added the required components (CommandHandlerRegistry, CommandHandler, SendCounterPacket, ...) to the ProfilingService class * Reworked the ProfilingService::Run procedure and renamed it to Update * Handling all states but Active in the Run method (future work) * Updated the unit and tests accordingly * Added component tests to check that the Connection Acknowledged packet is handled correctly * Added test util classes, made the default constructor/destructor protected to superclass a ProfilingService object * Added IProfilingConnectionFactory interface Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com> Change-Id: I010d94b18980c9e6394253f4b2bbe4fe5bb3fe4f
-rw-r--r--CMakeLists.txt1
-rw-r--r--src/profiling/CommandHandler.cpp6
-rw-r--r--src/profiling/IProfilingConnection.hpp2
-rw-r--r--src/profiling/IProfilingConnectionFactory.hpp33
-rw-r--r--src/profiling/Packet.hpp10
-rw-r--r--src/profiling/ProfilingConnectionDumpToFileDecorator.cpp2
-rw-r--r--src/profiling/ProfilingConnectionDumpToFileDecorator.hpp2
-rw-r--r--src/profiling/ProfilingConnectionFactory.hpp7
-rw-r--r--src/profiling/ProfilingService.cpp92
-rw-r--r--src/profiling/ProfilingService.hpp67
-rw-r--r--src/profiling/ProfilingStateMachine.cpp26
-rw-r--r--src/profiling/SocketProfilingConnection.cpp2
-rw-r--r--src/profiling/SocketProfilingConnection.hpp2
-rw-r--r--src/profiling/test/ProfilingConnectionDumpToFileDecoratorTests.cpp2
-rw-r--r--src/profiling/test/ProfilingTests.cpp257
-rw-r--r--src/profiling/test/SendCounterPacketTests.hpp18
16 files changed, 448 insertions, 81 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index fc68f3afb3..3b27d05be5 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -450,6 +450,7 @@ list(APPEND armnn_sources
src/profiling/IPacketBuffer.hpp
src/profiling/IPeriodicCounterCapture.hpp
src/profiling/IProfilingConnection.hpp
+ src/profiling/IProfilingConnectionFactory.hpp
src/profiling/Packet.cpp
src/profiling/Packet.hpp
src/profiling/PacketBuffer.cpp
diff --git a/src/profiling/CommandHandler.cpp b/src/profiling/CommandHandler.cpp
index 49784056bf..86fa2571df 100644
--- a/src/profiling/CommandHandler.cpp
+++ b/src/profiling/CommandHandler.cpp
@@ -54,8 +54,12 @@ void CommandHandler::HandleCommands(IProfilingConnection& profilingConnection)
m_KeepRunning.store(false, std::memory_order_relaxed);
}
}
- catch (...)
+ catch (const Exception& e)
{
+ // Log the error
+ BOOST_LOG_TRIVIAL(warning) << "An error has occurred when handling a command: "
+ << e.what();
+
// Might want to differentiate the errors more
m_KeepRunning.store(false);
}
diff --git a/src/profiling/IProfilingConnection.hpp b/src/profiling/IProfilingConnection.hpp
index 97f7b55477..5d6a352f1d 100644
--- a/src/profiling/IProfilingConnection.hpp
+++ b/src/profiling/IProfilingConnection.hpp
@@ -20,7 +20,7 @@ class IProfilingConnection
public:
virtual ~IProfilingConnection() {}
- virtual bool IsOpen() = 0;
+ virtual bool IsOpen() const = 0;
virtual void Close() = 0;
diff --git a/src/profiling/IProfilingConnectionFactory.hpp b/src/profiling/IProfilingConnectionFactory.hpp
new file mode 100644
index 0000000000..173421092e
--- /dev/null
+++ b/src/profiling/IProfilingConnectionFactory.hpp
@@ -0,0 +1,33 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "IProfilingConnection.hpp"
+
+#include <Runtime.hpp>
+
+#include <memory>
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+class IProfilingConnectionFactory
+{
+public:
+ using ExternalProfilingOptions = Runtime::CreationOptions::ExternalProfilingOptions;
+ using IProfilingConnectionPtr = std::unique_ptr<IProfilingConnection>;
+
+ virtual ~IProfilingConnectionFactory() {}
+
+ virtual IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const = 0;
+};
+
+} // namespace profiling
+
+} // namespace armnn
diff --git a/src/profiling/Packet.hpp b/src/profiling/Packet.hpp
index 7d70a48366..2aae14b741 100644
--- a/src/profiling/Packet.hpp
+++ b/src/profiling/Packet.hpp
@@ -23,6 +23,15 @@ public:
, m_Data(nullptr)
{}
+ Packet(uint32_t header)
+ : m_Header(header)
+ , m_Length(0)
+ , m_Data(nullptr)
+ {
+ m_PacketId = ((header >> 16) & 1023);
+ m_PacketFamily = (header >> 26);
+ }
+
Packet(uint32_t header, uint32_t length, std::unique_ptr<char[]>& data)
: m_Header(header)
, m_Length(length)
@@ -47,6 +56,7 @@ public:
Packet(const Packet& other) = delete;
Packet& operator=(const Packet&) = delete;
+ Packet& operator=(Packet&&) = default;
uint32_t GetHeader() const;
uint32_t GetPacketFamily() const;
diff --git a/src/profiling/ProfilingConnectionDumpToFileDecorator.cpp b/src/profiling/ProfilingConnectionDumpToFileDecorator.cpp
index cf427626ef..3d4b6bf927 100644
--- a/src/profiling/ProfilingConnectionDumpToFileDecorator.cpp
+++ b/src/profiling/ProfilingConnectionDumpToFileDecorator.cpp
@@ -34,7 +34,7 @@ ProfilingConnectionDumpToFileDecorator::~ProfilingConnectionDumpToFileDecorator(
Close();
}
-bool ProfilingConnectionDumpToFileDecorator::IsOpen()
+bool ProfilingConnectionDumpToFileDecorator::IsOpen() const
{
return m_Connection->IsOpen();
}
diff --git a/src/profiling/ProfilingConnectionDumpToFileDecorator.hpp b/src/profiling/ProfilingConnectionDumpToFileDecorator.hpp
index 95dbe55641..c2ae538138 100644
--- a/src/profiling/ProfilingConnectionDumpToFileDecorator.hpp
+++ b/src/profiling/ProfilingConnectionDumpToFileDecorator.hpp
@@ -49,7 +49,7 @@ public:
~ProfilingConnectionDumpToFileDecorator();
- bool IsOpen() override;
+ bool IsOpen() const override;
void Close() override;
diff --git a/src/profiling/ProfilingConnectionFactory.hpp b/src/profiling/ProfilingConnectionFactory.hpp
index 102c82070e..c4b10c6445 100644
--- a/src/profiling/ProfilingConnectionFactory.hpp
+++ b/src/profiling/ProfilingConnectionFactory.hpp
@@ -5,7 +5,7 @@
#pragma once
-#include "IProfilingConnection.hpp"
+#include "IProfilingConnectionFactory.hpp"
#include <Runtime.hpp>
@@ -17,14 +17,13 @@ namespace armnn
namespace profiling
{
-class ProfilingConnectionFactory final
+class ProfilingConnectionFactory final : public IProfilingConnectionFactory
{
public:
ProfilingConnectionFactory() = default;
~ProfilingConnectionFactory() = default;
- std::unique_ptr<IProfilingConnection> GetProfilingConnection(
- const Runtime::CreationOptions::ExternalProfilingOptions& options) const;
+ IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override;
};
} // namespace profiling
diff --git a/src/profiling/ProfilingService.cpp b/src/profiling/ProfilingService.cpp
index 2da0f79da2..19cf9cb58e 100644
--- a/src/profiling/ProfilingService.cpp
+++ b/src/profiling/ProfilingService.cpp
@@ -20,37 +20,76 @@ void ProfilingService::ResetExternalProfilingOptions(const ExternalProfilingOpti
// Update the profiling options
m_Options = options;
+ // Check if the profiling service needs to be reset
if (resetProfilingService)
{
// Reset the profiling service
- m_CounterDirectory.Clear();
- m_ProfilingConnection.reset();
- m_StateMachine.Reset();
- m_CounterIndex.clear();
- m_CounterValues.clear();
+ Reset();
}
-
- // Re-initialize the profiling service
- Initialize();
}
-void ProfilingService::Run()
+void ProfilingService::Update()
{
- if (m_StateMachine.GetCurrentState() == ProfilingState::Uninitialised)
+ if (!m_Options.m_EnableProfiling)
{
- Initialize();
+ // Don't run if profiling is disabled
+ return;
}
- else if (m_StateMachine.GetCurrentState() == ProfilingState::NotConnected)
+
+ ProfilingState currentState = m_StateMachine.GetCurrentState();
+ switch (currentState)
{
+ case ProfilingState::Uninitialised:
+ // Initialize the profiling service
+ Initialize();
+
+ // Move to the next state
+ m_StateMachine.TransitionToState(ProfilingState::NotConnected);
+ break;
+ case ProfilingState::NotConnected:
+ BOOST_ASSERT(m_ProfilingConnectionFactory);
+
+ // Reset any existing profiling connection
+ m_ProfilingConnection.reset();
+
try
{
- m_ProfilingConnectionFactory.GetProfilingConnection(m_Options);
- m_StateMachine.TransitionToState(ProfilingState::WaitingForAck);
+ // Setup the profiling connection
+ //m_ProfilingConnection = m_ProfilingConnectionFactory.GetProfilingConnection(m_Options);
+ m_ProfilingConnection = m_ProfilingConnectionFactory->GetProfilingConnection(m_Options);
}
- catch (const armnn::Exception& e)
+ catch (const Exception& e)
{
- std::cerr << e.what() << std::endl;
+ BOOST_LOG_TRIVIAL(warning) << "An error has occurred when creating the profiling connection: "
+ << e.what();
}
+
+ // Move to the next state
+ m_StateMachine.TransitionToState(m_ProfilingConnection
+ ? ProfilingState::WaitingForAck // Profiling connection obtained, wait for ack
+ : ProfilingState::NotConnected); // Profiling connection failed, stay in the
+ // "NotConnected" state
+ break;
+ case ProfilingState::WaitingForAck:
+ BOOST_ASSERT(m_ProfilingConnection);
+
+ // Start the command thread
+ m_CommandHandler.Start(*m_ProfilingConnection);
+
+ // Start the send thread, while in "WaitingForAck" state it'll send out a "Stream MetaData" packet waiting for
+ // a valid "Connection Acknowledged" packet confirming the connection
+ m_SendCounterPacket.Start(*m_ProfilingConnection);
+
+ // The connection acknowledged command handler will automatically transition the state to "Active" once a
+ // valid "Connection Acknowledged" packet has been received
+
+ break;
+ case ProfilingState::Active:
+
+ break;
+ default:
+ throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1")
+ % static_cast<int>(currentState)));
}
}
@@ -119,12 +158,6 @@ uint32_t ProfilingService::DecrementCounterValue(uint16_t counterUid)
void ProfilingService::Initialize()
{
- if (!m_Options.m_EnableProfiling)
- {
- // Skip the initialization if profiling is disabled
- return;
- }
-
// Register a category for the basic runtime counters
if (!m_CounterDirectory.IsCategoryRegistered("ArmNN_Runtime"))
{
@@ -175,9 +208,6 @@ void ProfilingService::Initialize()
BOOST_ASSERT(inferencesRunCounter);
InitializeCounterValue(inferencesRunCounter->m_Uid);
}
-
- // Initialization is done, update the profiling service state
- m_StateMachine.TransitionToState(ProfilingState::NotConnected);
}
void ProfilingService::InitializeCounterValue(uint16_t counterUid)
@@ -196,6 +226,18 @@ void ProfilingService::InitializeCounterValue(uint16_t counterUid)
m_CounterIndex.at(counterUid) = counterValuePtr;
}
+void ProfilingService::Reset()
+{
+ // Reset the profiling service
+ m_CounterDirectory.Clear();
+ m_ProfilingConnection.reset();
+ m_StateMachine.Reset();
+ m_CounterIndex.clear();
+ m_CounterValues.clear();
+ m_CommandHandler.Stop();
+ m_SendCounterPacket.Stop(false);
+}
+
} // namespace profiling
} // namespace armnn
diff --git a/src/profiling/ProfilingService.hpp b/src/profiling/ProfilingService.hpp
index b4cdcac76e..50a938e33d 100644
--- a/src/profiling/ProfilingService.hpp
+++ b/src/profiling/ProfilingService.hpp
@@ -9,6 +9,10 @@
#include "ProfilingConnectionFactory.hpp"
#include "CounterDirectory.hpp"
#include "ICounterValues.hpp"
+#include "CommandHandler.hpp"
+#include "BufferManager.hpp"
+#include "SendCounterPacket.hpp"
+#include "ConnectionAcknowledgedCommandHandler.hpp"
namespace armnn
{
@@ -16,10 +20,11 @@ namespace armnn
namespace profiling
{
-class ProfilingService final : public IReadWriteCounterValues
+class ProfilingService : public IReadWriteCounterValues
{
public:
using ExternalProfilingOptions = Runtime::CreationOptions::ExternalProfilingOptions;
+ using IProfilingConnectionFactoryPtr = std::unique_ptr<IProfilingConnectionFactory>;
using IProfilingConnectionPtr = std::unique_ptr<IProfilingConnection>;
using CounterIndices = std::vector<std::atomic<uint32_t>*>;
using CounterValues = std::list<std::atomic<uint32_t>>;
@@ -34,8 +39,8 @@ public:
// Resets the profiling options, optionally clears the profiling service entirely
void ResetExternalProfilingOptions(const ExternalProfilingOptions& options, bool resetProfilingService = false);
- // Runs the profiling service
- void Run();
+ // Updates the profiling service, making it transition to a new state if necessary
+ void Update();
// Getters for the profiling service state
const ICounterDirectory& GetCounterDirectory() const;
@@ -51,26 +56,70 @@ public:
uint32_t DecrementCounterValue(uint16_t counterUid) override;
private:
- // Default/copy/move constructors/destructors and copy/move assignment operators are kept private
- ProfilingService() = default;
+ // Copy/move constructors/destructors and copy/move assignment operators are deleted
ProfilingService(const ProfilingService&) = delete;
ProfilingService(ProfilingService&&) = delete;
ProfilingService& operator=(const ProfilingService&) = delete;
ProfilingService& operator=(ProfilingService&&) = delete;
- ~ProfilingService() = default;
- // Initialization functions
+ // Initialization/reset functions
void Initialize();
void InitializeCounterValue(uint16_t counterUid);
+ void Reset();
- // Profiling service state variables
+ // Profiling service components
ExternalProfilingOptions m_Options;
CounterDirectory m_CounterDirectory;
- ProfilingConnectionFactory m_ProfilingConnectionFactory;
+ IProfilingConnectionFactoryPtr m_ProfilingConnectionFactory;
IProfilingConnectionPtr m_ProfilingConnection;
ProfilingStateMachine m_StateMachine;
CounterIndices m_CounterIndex;
CounterValues m_CounterValues;
+ CommandHandlerRegistry m_CommandHandlerRegistry;
+ PacketVersionResolver m_PacketVersionResolver;
+ CommandHandler m_CommandHandler;
+ BufferManager m_BufferManager;
+ SendCounterPacket m_SendCounterPacket;
+ ConnectionAcknowledgedCommandHandler m_ConnectionAcknowledgedCommandHandler;
+
+protected:
+ // Default constructor/destructor kept protected for testing
+ ProfilingService()
+ : m_Options()
+ , m_CounterDirectory()
+ , m_ProfilingConnectionFactory(new ProfilingConnectionFactory())
+ , m_ProfilingConnection()
+ , m_StateMachine()
+ , m_CounterIndex()
+ , m_CounterValues()
+ , m_CommandHandlerRegistry()
+ , m_PacketVersionResolver()
+ , m_CommandHandler(1000,
+ false,
+ m_CommandHandlerRegistry,
+ m_PacketVersionResolver)
+ , m_BufferManager()
+ , m_SendCounterPacket(m_StateMachine, m_BufferManager)
+ , m_ConnectionAcknowledgedCommandHandler(1,
+ m_PacketVersionResolver.ResolvePacketVersion(1).GetEncodedValue(),
+ m_StateMachine)
+ {
+ // Register the "Connection Acknowledged" command handler
+ m_CommandHandlerRegistry.RegisterFunctor(&m_ConnectionAcknowledgedCommandHandler);
+ }
+ ~ProfilingService() = default;
+
+ // Protected method for testing
+ void SwapProfilingConnectionFactory(ProfilingService& instance,
+ IProfilingConnectionFactory* other,
+ IProfilingConnectionFactory*& backup)
+ {
+ BOOST_ASSERT(instance.m_ProfilingConnectionFactory);
+ BOOST_ASSERT(other);
+
+ backup = instance.m_ProfilingConnectionFactory.release();
+ instance.m_ProfilingConnectionFactory.reset(other);
+ }
};
} // namespace profiling
diff --git a/src/profiling/ProfilingStateMachine.cpp b/src/profiling/ProfilingStateMachine.cpp
index 5af5bfbed0..9d3a81f64a 100644
--- a/src/profiling/ProfilingStateMachine.cpp
+++ b/src/profiling/ProfilingStateMachine.cpp
@@ -35,50 +35,50 @@ ProfilingState ProfilingStateMachine::GetCurrentState() const
void ProfilingStateMachine::TransitionToState(ProfilingState newState)
{
- ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed);
+ ProfilingState currentState = m_State.load(std::memory_order::memory_order_relaxed);
switch (newState)
{
case ProfilingState::Uninitialised:
do
{
- if (!IsOneOfStates(expectedState, ProfilingState::Uninitialised))
+ if (!IsOneOfStates(currentState, ProfilingState::Uninitialised))
{
- ThrowStateTransitionException(expectedState, newState);
+ ThrowStateTransitionException(currentState, newState);
}
}
- while (!m_State.compare_exchange_strong(expectedState, newState, std::memory_order::memory_order_relaxed));
+ while (!m_State.compare_exchange_strong(currentState, newState, std::memory_order::memory_order_relaxed));
break;
case ProfilingState::NotConnected:
do
{
- if (!IsOneOfStates(expectedState, ProfilingState::Uninitialised, ProfilingState::NotConnected,
+ if (!IsOneOfStates(currentState, ProfilingState::Uninitialised, ProfilingState::NotConnected,
ProfilingState::Active))
{
- ThrowStateTransitionException(expectedState, newState);
+ ThrowStateTransitionException(currentState, newState);
}
}
- while (!m_State.compare_exchange_strong(expectedState, newState, std::memory_order::memory_order_relaxed));
+ while (!m_State.compare_exchange_strong(currentState, newState, std::memory_order::memory_order_relaxed));
break;
case ProfilingState::WaitingForAck:
do
{
- if (!IsOneOfStates(expectedState, ProfilingState::NotConnected, ProfilingState::WaitingForAck))
+ if (!IsOneOfStates(currentState, ProfilingState::NotConnected, ProfilingState::WaitingForAck))
{
- ThrowStateTransitionException(expectedState, newState);
+ ThrowStateTransitionException(currentState, newState);
}
}
- while (!m_State.compare_exchange_strong(expectedState, newState, std::memory_order::memory_order_relaxed));
+ while (!m_State.compare_exchange_strong(currentState, newState, std::memory_order::memory_order_relaxed));
break;
case ProfilingState::Active:
do
{
- if (!IsOneOfStates(expectedState, ProfilingState::WaitingForAck, ProfilingState::Active))
+ if (!IsOneOfStates(currentState, ProfilingState::WaitingForAck, ProfilingState::Active))
{
- ThrowStateTransitionException(expectedState, newState);
+ ThrowStateTransitionException(currentState, newState);
}
}
- while (!m_State.compare_exchange_strong(expectedState, newState, std::memory_order::memory_order_relaxed));
+ while (!m_State.compare_exchange_strong(currentState, newState, std::memory_order::memory_order_relaxed));
break;
default:
break;
diff --git a/src/profiling/SocketProfilingConnection.cpp b/src/profiling/SocketProfilingConnection.cpp
index 6955f70a48..0ae7b0e1fe 100644
--- a/src/profiling/SocketProfilingConnection.cpp
+++ b/src/profiling/SocketProfilingConnection.cpp
@@ -50,7 +50,7 @@ SocketProfilingConnection::SocketProfilingConnection()
}
}
-bool SocketProfilingConnection::IsOpen()
+bool SocketProfilingConnection::IsOpen() const
{
return m_Socket[0].fd > 0;
}
diff --git a/src/profiling/SocketProfilingConnection.hpp b/src/profiling/SocketProfilingConnection.hpp
index 1ae9f17f7e..7c77a8bfc9 100644
--- a/src/profiling/SocketProfilingConnection.hpp
+++ b/src/profiling/SocketProfilingConnection.hpp
@@ -19,7 +19,7 @@ class SocketProfilingConnection : public IProfilingConnection
{
public:
SocketProfilingConnection();
- bool IsOpen() final;
+ bool IsOpen() const final;
void Close() final;
bool WritePacket(const unsigned char* buffer, uint32_t length) final;
Packet ReadPacket(uint32_t timeout) final;
diff --git a/src/profiling/test/ProfilingConnectionDumpToFileDecoratorTests.cpp b/src/profiling/test/ProfilingConnectionDumpToFileDecoratorTests.cpp
index 3e06cb353b..fac93c5ddf 100644
--- a/src/profiling/test/ProfilingConnectionDumpToFileDecoratorTests.cpp
+++ b/src/profiling/test/ProfilingConnectionDumpToFileDecoratorTests.cpp
@@ -41,7 +41,7 @@ public:
~DummyProfilingConnection() = default;
- bool IsOpen() override
+ bool IsOpen() const override
{
return m_Open;
}
diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp
index 24ab779412..de92fb9eb0 100644
--- a/src/profiling/test/ProfilingTests.cpp
+++ b/src/profiling/test/ProfilingTests.cpp
@@ -27,6 +27,9 @@
#include <armnn/Conversion.hpp>
+#include <Logging.hpp>
+#include <armnn/Utils.hpp>
+
#include <boost/algorithm/string.hpp>
#include <boost/numeric/conversion/cast.hpp>
#include <boost/test/unit_test.hpp>
@@ -97,18 +100,19 @@ public:
TestProfilingConnectionBase() = default;
~TestProfilingConnectionBase() = default;
- bool IsOpen() { return true; }
+ bool IsOpen() const override { return true; }
- void Close() {}
+ void Close() override {}
- bool WritePacket(const unsigned char* buffer, uint32_t length) { return false; }
+ bool WritePacket(const unsigned char* buffer, uint32_t length) override { return false; }
- Packet ReadPacket(uint32_t timeout)
+ Packet ReadPacket(uint32_t timeout) override
{
std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
std::unique_ptr<char[]> packetData;
- //Return connection acknowledged packet
- return {65536 ,0 , packetData};
+
+ // Return connection acknowledged packet
+ return { 65536, 0, packetData };
}
};
@@ -119,12 +123,13 @@ public:
if (readRequests < 3)
{
readRequests++;
- throw armnn::TimeoutException(": Simulate a timeout");
+ throw armnn::TimeoutException("Simulate a timeout");
}
std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
std::unique_ptr<char[]> packetData;
- //Return connection acknowledged packet after three timeouts
- return {65536 ,0 , packetData};
+
+ // Return connection acknowledged packet after three timeouts
+ return { 65536, 0, packetData };
}
private:
@@ -655,15 +660,31 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceDisabled)
ProfilingService& profilingService = ProfilingService::Instance();
profilingService.ResetExternalProfilingOptions(options, true);
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
- profilingService.Run();
+ profilingService.Update();
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
}
-struct cerr_redirect
+struct LogLevelSwapper
{
- cerr_redirect(std::streambuf* new_buffer)
- : old(std::cerr.rdbuf(new_buffer)) {}
- ~cerr_redirect() { std::cerr.rdbuf(old); }
+public:
+ LogLevelSwapper(armnn::LogSeverity severity)
+ {
+ // Set the new log level
+ armnn::ConfigureLogging(true, true, severity);
+ }
+ ~LogLevelSwapper()
+ {
+ // The default log level for unit tests is "Fatal"
+ armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal);
+ }
+};
+
+struct CoutRedirect
+{
+public:
+ CoutRedirect(std::streambuf* newStreamBuffer)
+ : old(std::cout.rdbuf(newStreamBuffer)) {}
+ ~CoutRedirect() { std::cout.rdbuf(old); }
private:
std::streambuf* old;
@@ -671,35 +692,45 @@ private:
BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabled)
{
+ // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
+ LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning);
+
armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
options.m_EnableProfiling = true;
ProfilingService& profilingService = ProfilingService::Instance();
profilingService.ResetExternalProfilingOptions(options, true);
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
+ profilingService.Update();
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
- // As there is no daemon running a connection cannot be made so expect a std::cerr to console
+ // Redirect the output to a local stream so that we can parse the warning message
std::stringstream ss;
- cerr_redirect guard(ss.rdbuf());
- profilingService.Run();
+ CoutRedirect coutRedirect(ss.rdbuf());
+ profilingService.Update();
BOOST_CHECK(boost::contains(ss.str(), "Cannot connect to stream socket: Connection refused"));
}
BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabledRuntime)
{
+ // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
+ LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning);
+
armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
ProfilingService& profilingService = ProfilingService::Instance();
profilingService.ResetExternalProfilingOptions(options, true);
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
- profilingService.Run();
+ profilingService.Update();
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
options.m_EnableProfiling = true;
profilingService.ResetExternalProfilingOptions(options);
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
+ profilingService.Update();
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
- // As there is no daemon running a connection cannot be made so expect a std::cerr to console
+ // Redirect the output to a local stream so that we can parse the warning message
std::stringstream ss;
- cerr_redirect guard(ss.rdbuf());
- profilingService.Run();
+ CoutRedirect coutRedirect(ss.rdbuf());
+ profilingService.Update();
BOOST_CHECK(boost::contains(ss.str(), "Cannot connect to stream socket: Connection refused"));
}
@@ -711,11 +742,15 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceCounterDirectory)
const ICounterDirectory& counterDirectory0 = profilingService.GetCounterDirectory();
BOOST_CHECK(counterDirectory0.GetCounterCount() == 0);
+ profilingService.Update();
+ BOOST_CHECK(counterDirectory0.GetCounterCount() == 0);
options.m_EnableProfiling = true;
profilingService.ResetExternalProfilingOptions(options);
const ICounterDirectory& counterDirectory1 = profilingService.GetCounterDirectory();
+ BOOST_CHECK(counterDirectory1.GetCounterCount() == 0);
+ profilingService.Update();
BOOST_CHECK(counterDirectory1.GetCounterCount() != 0);
}
@@ -726,6 +761,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceCounterValues)
ProfilingService& profilingService = ProfilingService::Instance();
profilingService.ResetExternalProfilingOptions(options, true);
+ profilingService.Update();
const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory();
const Counters& counters = counterDirectory.GetCounters();
BOOST_CHECK(!counters.empty());
@@ -2297,4 +2333,183 @@ BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest1)
BOOST_TEST(categoryRecordOffset == 44);
}
+class MockProfilingConnectionFactory : public IProfilingConnectionFactory
+{
+public:
+ MockProfilingConnectionFactory()
+ : m_MockProfilingConnection(new MockProfilingConnection())
+ {}
+
+ IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override
+ {
+ return std::unique_ptr<MockProfilingConnection>(m_MockProfilingConnection);
+ }
+
+ MockProfilingConnection* GetMockProfilingConnection() { return m_MockProfilingConnection; }
+
+private:
+ MockProfilingConnection* m_MockProfilingConnection;
+};
+
+class SwapProfilingConnectionFactoryHelper : public ProfilingService
+{
+public:
+ SwapProfilingConnectionFactoryHelper()
+ : ProfilingService()
+ , m_MockProfilingConnectionFactory(new MockProfilingConnectionFactory())
+ , m_BackupProfilingConnectionFactory(nullptr)
+ {
+ SwapProfilingConnectionFactory(ProfilingService::Instance(),
+ m_MockProfilingConnectionFactory.get(),
+ m_BackupProfilingConnectionFactory);
+ }
+ ~SwapProfilingConnectionFactoryHelper()
+ {
+ IProfilingConnectionFactory* temp = nullptr;
+ SwapProfilingConnectionFactory(ProfilingService::Instance(),
+ m_BackupProfilingConnectionFactory,
+ temp);
+ }
+
+ IProfilingConnectionFactory* GetMockProfilingConnectionFactory() { return m_MockProfilingConnectionFactory.get(); }
+
+private:
+ IProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory;
+ IProfilingConnectionFactory* m_BackupProfilingConnectionFactory;
+};
+
+BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadConnectionAcknowledgedPacket)
+{
+ // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
+ LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning);
+
+ SwapProfilingConnectionFactoryHelper helper;
+ MockProfilingConnectionFactory* mockProfilingConnectionFactory =
+ boost::polymorphic_downcast<MockProfilingConnectionFactory*>(helper.GetMockProfilingConnectionFactory());
+ BOOST_CHECK(mockProfilingConnectionFactory);
+ MockProfilingConnection* mockProfilingConnection = mockProfilingConnectionFactory->GetMockProfilingConnection();
+ BOOST_CHECK(mockProfilingConnection);
+
+ // Calculate the size of a Stream Metadata packet
+ std::string processName = GetProcessName().substr(0, 60);
+ unsigned int processNameSize = processName.empty() ? 0 : boost::numeric_cast<unsigned int>(processName.size()) + 1;
+ unsigned int streamMetadataPacketsize = 118 + processNameSize;
+
+ armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
+ options.m_EnableProfiling = true;
+ ProfilingService& profilingService = ProfilingService::Instance();
+ profilingService.ResetExternalProfilingOptions(options, true);
+
+ // Bring the profiling service to the "WaitingForAck" state
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
+ profilingService.Update();
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
+ profilingService.Update();
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
+ profilingService.Update();
+
+ // Redirect the output to a local stream so that we can parse the warning message
+ std::stringstream ss;
+ CoutRedirect coutRedirect(ss.rdbuf());
+
+ // Wait for a bit to make sure that we get the packet
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+ // Check that the mock profiling connection contains one Stream Metadata packet
+ const std::vector<uint32_t>& writtenData = mockProfilingConnection->GetWrittenData();
+ BOOST_TEST(writtenData.size() == 1);
+ BOOST_TEST(writtenData[0] == streamMetadataPacketsize);
+
+ // Write a valid "Connection Acknowledged" packet into the mock profiling connection, to simulate a valid
+ // reply from an external profiling service
+
+ // Connection Acknowledged Packet header (word 0, word 1 is always zero):
+ // 26:31 [6] packet_family: Control Packet Family, value 0b000000
+ // 16:25 [10] packet_id: Packet identifier, value 0b0000000001
+ // 8:15 [8] reserved: Reserved, value 0b00000000
+ // 0:7 [8] reserved: Reserved, value 0b00000000
+ uint32_t packetFamily = 0;
+ uint32_t packetId = 37; // Wrong packet id!!!
+ uint32_t header = ((packetFamily & 0x0000003F) << 26) |
+ ((packetId & 0x000003FF) << 16);
+
+ // Connection Acknowledged Packet
+ Packet connectionAcknowledgedPacket(header);
+
+ // Write the packet to the mock profiling connection
+ mockProfilingConnection->WritePacket(std::move(connectionAcknowledgedPacket));
+
+ // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that
+ // the Connection Acknowledged packet gets processed by the profiling service
+ std::this_thread::sleep_for(std::chrono::seconds(1));
+
+ // Check that the expected error has occurred and logged to the standard output
+ BOOST_CHECK(boost::contains(ss.str(), "Functor with requested PacketId=37 and Version=4194304 does not exist"));
+
+ // The Connection Acknowledged Command Handler should not have updated the profiling state
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
+}
+
+BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket)
+{
+ SwapProfilingConnectionFactoryHelper helper;
+ MockProfilingConnectionFactory* mockProfilingConnectionFactory =
+ boost::polymorphic_downcast<MockProfilingConnectionFactory*>(helper.GetMockProfilingConnectionFactory());
+ BOOST_CHECK(mockProfilingConnectionFactory);
+ MockProfilingConnection* mockProfilingConnection = mockProfilingConnectionFactory->GetMockProfilingConnection();
+ BOOST_CHECK(mockProfilingConnection);
+
+ // Calculate the size of a Stream Metadata packet
+ std::string processName = GetProcessName().substr(0, 60);
+ unsigned int processNameSize = processName.empty() ? 0 : boost::numeric_cast<unsigned int>(processName.size()) + 1;
+ unsigned int streamMetadataPacketsize = 118 + processNameSize;
+
+ armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
+ options.m_EnableProfiling = true;
+ ProfilingService& profilingService = ProfilingService::Instance();
+ profilingService.ResetExternalProfilingOptions(options, true);
+
+ // Bring the profiling service to the "WaitingForAck" state
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
+ profilingService.Update();
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
+ profilingService.Update();
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
+ profilingService.Update();
+
+ // Wait for a bit to make sure that we get the packet
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+ // Check that the mock profiling connection contains one Stream Metadata packet
+ const std::vector<uint32_t>& writtenData = mockProfilingConnection->GetWrittenData();
+ BOOST_TEST(writtenData.size() == 1);
+ BOOST_TEST(writtenData[0] == streamMetadataPacketsize);
+
+ // Write a valid "Connection Acknowledged" packet into the mock profiling connection, to simulate a valid
+ // reply from an external profiling service
+
+ // Connection Acknowledged Packet header (word 0, word 1 is always zero):
+ // 26:31 [6] packet_family: Control Packet Family, value 0b000000
+ // 16:25 [10] packet_id: Packet identifier, value 0b0000000001
+ // 8:15 [8] reserved: Reserved, value 0b00000000
+ // 0:7 [8] reserved: Reserved, value 0b00000000
+ uint32_t packetFamily = 0;
+ uint32_t packetId = 1;
+ uint32_t header = ((packetFamily & 0x0000003F) << 26) |
+ ((packetId & 0x000003FF) << 16);
+
+ // Connection Acknowledged Packet
+ Packet connectionAcknowledgedPacket(header);
+
+ // Write the packet to the mock profiling connection
+ mockProfilingConnection->WritePacket(std::move(connectionAcknowledgedPacket));
+
+ // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that
+ // the Connection Acknowledged packet gets processed by the profiling service
+ std::this_thread::sleep_for(std::chrono::seconds(1));
+
+ // The Connection Acknowledged Command Handler should have updated the profiling state accordingly
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+}
+
BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/profiling/test/SendCounterPacketTests.hpp b/src/profiling/test/SendCounterPacketTests.hpp
index cae02b064d..48bab025dd 100644
--- a/src/profiling/test/SendCounterPacketTests.hpp
+++ b/src/profiling/test/SendCounterPacketTests.hpp
@@ -24,9 +24,11 @@ class MockProfilingConnection : public IProfilingConnection
public:
MockProfilingConnection()
: m_IsOpen(true)
+ , m_WrittenData()
+ , m_Packet()
{}
- bool IsOpen() override { return m_IsOpen; }
+ bool IsOpen() const override { return m_IsOpen; }
void Close() override { m_IsOpen = false; }
@@ -40,8 +42,19 @@ public:
m_WrittenData.push_back(length);
return true;
}
+ bool WritePacket(Packet&& packet)
+ {
+ m_Packet = std::move(packet);
+ return true;
+ }
- Packet ReadPacket(uint32_t timeout) override { return Packet(); }
+ Packet ReadPacket(uint32_t timeout) override
+ {
+ // Simulate a delay in the reading process
+ std::this_thread::sleep_for(std::chrono::milliseconds(500));
+
+ return std::move(m_Packet);
+ }
const std::vector<uint32_t>& GetWrittenData() const { return m_WrittenData; }
@@ -50,6 +63,7 @@ public:
private:
bool m_IsOpen;
std::vector<uint32_t> m_WrittenData;
+ Packet m_Packet;
};
class MockPacketBuffer : public IPacketBuffer