aboutsummaryrefslogtreecommitdiff
path: root/src/profiling
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2019-10-10 14:30:29 +0100
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-10-10 14:30:56 +0100
commit8efc500a7465c03877db8bbe443134f2b1bbc1af (patch)
treefcf4f4a56fa5f20ace8a1e5dfc175cecfd2373a7 /src/profiling
parenta3600ba71978225e4d21399fb781d4850f2bb25f (diff)
downloadarmnn-8efc500a7465c03877db8bbe443134f2b1bbc1af.tar.gz
IVGCVSW-3963 Implement the Request Counter Directory Handler
* Integrated the RequestCounterDirectoryCommandHandler in the ProfilingService class * Code refactoring * Added/Updated unit tests Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com> Change-Id: I60d9f8acf166e29b3dabc921dbdb8149461bd85f
Diffstat (limited to 'src/profiling')
-rw-r--r--src/profiling/ProfilingService.hpp14
-rw-r--r--src/profiling/RequestCounterDirectoryCommandHandler.cpp36
-rw-r--r--src/profiling/RequestCounterDirectoryCommandHandler.hpp19
-rw-r--r--src/profiling/test/ProfilingTests.cpp205
-rw-r--r--src/profiling/test/ProfilingTests.hpp16
5 files changed, 238 insertions, 52 deletions
diff --git a/src/profiling/ProfilingService.hpp b/src/profiling/ProfilingService.hpp
index edeb6bde90..0e66924267 100644
--- a/src/profiling/ProfilingService.hpp
+++ b/src/profiling/ProfilingService.hpp
@@ -13,6 +13,7 @@
#include "BufferManager.hpp"
#include "SendCounterPacket.hpp"
#include "ConnectionAcknowledgedCommandHandler.hpp"
+#include "RequestCounterDirectoryCommandHandler.hpp"
namespace armnn
{
@@ -81,6 +82,7 @@ private:
BufferManager m_BufferManager;
SendCounterPacket m_SendCounterPacket;
ConnectionAcknowledgedCommandHandler m_ConnectionAcknowledgedCommandHandler;
+ RequestCounterDirectoryCommandHandler m_RequestCounterDirectoryCommandHandler;
protected:
// Default constructor/destructor kept protected for testing
@@ -103,9 +105,17 @@ protected:
, m_ConnectionAcknowledgedCommandHandler(1,
m_PacketVersionResolver.ResolvePacketVersion(1).GetEncodedValue(),
m_StateMachine)
+ , m_RequestCounterDirectoryCommandHandler(3,
+ m_PacketVersionResolver.ResolvePacketVersion(3).GetEncodedValue(),
+ m_CounterDirectory,
+ m_SendCounterPacket,
+ m_StateMachine)
{
// Register the "Connection Acknowledged" command handler
m_CommandHandlerRegistry.RegisterFunctor(&m_ConnectionAcknowledgedCommandHandler);
+
+ // Register the "Request Counter Directory" command handler
+ m_CommandHandlerRegistry.RegisterFunctor(&m_RequestCounterDirectoryCommandHandler);
}
~ProfilingService() = default;
@@ -124,6 +134,10 @@ protected:
{
return instance.m_ProfilingConnection.get();
}
+ void TransitionToState(ProfilingService& instance, ProfilingState newState)
+ {
+ instance.m_StateMachine.TransitionToState(newState);
+ }
};
} // namespace profiling
diff --git a/src/profiling/RequestCounterDirectoryCommandHandler.cpp b/src/profiling/RequestCounterDirectoryCommandHandler.cpp
index 0fdcf10de4..e85acb4215 100644
--- a/src/profiling/RequestCounterDirectoryCommandHandler.cpp
+++ b/src/profiling/RequestCounterDirectoryCommandHandler.cpp
@@ -5,7 +5,7 @@
#include "RequestCounterDirectoryCommandHandler.hpp"
-#include <boost/assert.hpp>
+#include <boost/format.hpp>
namespace armnn
{
@@ -15,10 +15,36 @@ namespace profiling
void RequestCounterDirectoryCommandHandler::operator()(const Packet& packet)
{
- BOOST_ASSERT(packet.GetLength() == 0);
-
- // Write packet to Counter Stream Buffer
- m_SendCounterPacket.SendCounterDirectoryPacket(m_CounterDirectory);
+ ProfilingState currentState = m_StateMachine.GetCurrentState();
+ switch (currentState)
+ {
+ case ProfilingState::Uninitialised:
+ case ProfilingState::NotConnected:
+ case ProfilingState::WaitingForAck:
+ throw RuntimeException(boost::str(boost::format("Request Counter Directory Handler invoked while in an "
+ "wrong state: %1%")
+ % GetProfilingStateName(currentState)));
+ case ProfilingState::Active:
+ // Process the packet
+ if (!(packet.GetPacketFamily() == 0u && packet.GetPacketId() == 3u))
+ {
+ throw armnn::InvalidArgumentException(boost::str(boost::format("Expected Packet family = 0, id = 3 but "
+ "received family = %1%, id = %2%")
+ % packet.GetPacketFamily()
+ % packet.GetPacketId()));
+ }
+
+ // Write a Counter Directory packet to the Counter Stream Buffer
+ m_SendCounterPacket.SendCounterDirectoryPacket(m_CounterDirectory);
+
+ // Notify the Send Thread that new data is available in the Counter Stream Buffer
+ m_SendCounterPacket.SetReadyToRead();
+
+ break;
+ default:
+ throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1%")
+ % static_cast<int>(currentState)));
+ }
}
} // namespace profiling
diff --git a/src/profiling/RequestCounterDirectoryCommandHandler.hpp b/src/profiling/RequestCounterDirectoryCommandHandler.hpp
index a03300af48..02bf64d17a 100644
--- a/src/profiling/RequestCounterDirectoryCommandHandler.hpp
+++ b/src/profiling/RequestCounterDirectoryCommandHandler.hpp
@@ -8,6 +8,7 @@
#include "CommandHandlerFunctor.hpp"
#include "ISendCounterPacket.hpp"
#include "Packet.hpp"
+#include "ProfilingStateMachine.hpp"
namespace armnn
{
@@ -19,23 +20,25 @@ class RequestCounterDirectoryCommandHandler : public CommandHandlerFunctor
{
public:
- RequestCounterDirectoryCommandHandler(uint32_t packetId, uint32_t version,
+ RequestCounterDirectoryCommandHandler(uint32_t packetId,
+ uint32_t version,
ICounterDirectory& counterDirectory,
- ISendCounterPacket& sendCounterPacket)
- : CommandHandlerFunctor(packetId, version),
- m_CounterDirectory(counterDirectory),
- m_SendCounterPacket(sendCounterPacket)
+ ISendCounterPacket& sendCounterPacket,
+ ProfilingStateMachine& profilingStateMachine)
+ : CommandHandlerFunctor(packetId, version)
+ , m_CounterDirectory(counterDirectory)
+ , m_SendCounterPacket(sendCounterPacket)
+ , m_StateMachine(profilingStateMachine)
{}
void operator()(const Packet& packet) override;
-
private:
- ICounterDirectory& m_CounterDirectory;
+ const ICounterDirectory& m_CounterDirectory;
ISendCounterPacket& m_SendCounterPacket;
+ const ProfilingStateMachine& m_StateMachine;
};
} // namespace profiling
} // namespace armnn
-
diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp
index 80d99dd7ab..57a11d0503 100644
--- a/src/profiling/test/ProfilingTests.cpp
+++ b/src/profiling/test/ProfilingTests.cpp
@@ -2119,75 +2119,97 @@ BOOST_AUTO_TEST_CASE(CheckPeriodicCounterCaptureThread)
BOOST_TEST((valueB * numSteps) == readValue);
}
-BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest0)
+BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest1)
{
using boost::numeric_cast;
+ const uint32_t packetId = 3;
+ const uint32_t version = 1;
ProfilingStateMachine profilingStateMachine;
+ CounterDirectory counterDirectory;
+ MockBufferManager mockBuffer(1024);
+ SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer);
+ RequestCounterDirectoryCommandHandler commandHandler(packetId,
+ version,
+ counterDirectory,
+ sendCounterPacket,
+ profilingStateMachine);
- const uint32_t packetId = 0x30000;
- const uint32_t version = 1;
+ const uint32_t wrongPacketId = 47;
+ const uint32_t wrongHeader = (wrongPacketId & 0x000003FF) << 16;
- std::unique_ptr<char[]> packetData;
+ Packet wrongPacket(wrongHeader);
- Packet packetA(packetId, 0, packetData);
+ profilingStateMachine.TransitionToState(ProfilingState::Uninitialised);
+ BOOST_CHECK_THROW(commandHandler(wrongPacket), armnn::RuntimeException); // Wrong profiling state
+ profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
+ BOOST_CHECK_THROW(commandHandler(wrongPacket), armnn::RuntimeException); // Wrong profiling state
+ profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
+ BOOST_CHECK_THROW(commandHandler(wrongPacket), armnn::RuntimeException); // Wrong profiling state
+ profilingStateMachine.TransitionToState(ProfilingState::Active);
+ BOOST_CHECK_THROW(commandHandler(wrongPacket), armnn::InvalidArgumentException); // Wrong packet
- MockBufferManager mockBuffer(1024);
- SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer);
+ const uint32_t rightHeader = (packetId & 0x000003FF) << 16;
- CounterDirectory counterDirectory;
+ Packet rightPacket(rightHeader);
- RequestCounterDirectoryCommandHandler commandHandler(packetId, version, counterDirectory, sendCounterPacket);
- commandHandler(packetA);
+ BOOST_CHECK_NO_THROW(commandHandler(rightPacket)); // Right packet
auto readBuffer = mockBuffer.GetReadableBuffer();
uint32_t headerWord0 = ReadUint32(readBuffer, 0);
uint32_t headerWord1 = ReadUint32(readBuffer, 4);
- BOOST_TEST(((headerWord0 >> 26) & 0x3F) == 0); // packet family
- BOOST_TEST(((headerWord0 >> 16) & 0x3FF) == 2); // packet id
- BOOST_TEST(headerWord1 == 24); // data length
+ BOOST_TEST(((headerWord0 >> 26) & 0x0000003F) == 0); // packet family
+ BOOST_TEST(((headerWord0 >> 16) & 0x000003FF) == 2); // packet id
+ BOOST_TEST(headerWord1 == 24); // data length
uint32_t bodyHeaderWord0 = ReadUint32(readBuffer, 8);
uint16_t deviceRecordCount = numeric_cast<uint16_t>(bodyHeaderWord0 >> 16);
BOOST_TEST(deviceRecordCount == 0); // device_records_count
}
-BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest1)
+BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest2)
{
using boost::numeric_cast;
- ProfilingStateMachine profilingStateMachine;
-
- const uint32_t packetId = 0x30000;
+ const uint32_t packetId = 3;
const uint32_t version = 1;
-
- std::unique_ptr<char[]> packetData;
-
- Packet packetA(packetId, 0, packetData);
-
+ ProfilingStateMachine profilingStateMachine;
+ CounterDirectory counterDirectory;
MockBufferManager mockBuffer(1024);
SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer);
+ RequestCounterDirectoryCommandHandler commandHandler(packetId,
+ version,
+ counterDirectory,
+ sendCounterPacket,
+ profilingStateMachine);
+ const uint32_t header = (packetId & 0x000003FF) << 16;
+ Packet packet(header);
- CounterDirectory counterDirectory;
const Device* device = counterDirectory.RegisterDevice("deviceA", 1);
const CounterSet* counterSet = counterDirectory.RegisterCounterSet("countersetA");
counterDirectory.RegisterCategory("categoryA", device->m_Uid, counterSet->m_Uid);
counterDirectory.RegisterCounter("categoryA", 0, 1, 2.0f, "counterA", "descA");
counterDirectory.RegisterCounter("categoryA", 1, 1, 3.0f, "counterB", "descB");
- RequestCounterDirectoryCommandHandler commandHandler(packetId, version, counterDirectory, sendCounterPacket);
- commandHandler(packetA);
+ profilingStateMachine.TransitionToState(ProfilingState::Uninitialised);
+ BOOST_CHECK_THROW(commandHandler(packet), armnn::RuntimeException); // Wrong profiling state
+ profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
+ BOOST_CHECK_THROW(commandHandler(packet), armnn::RuntimeException); // Wrong profiling state
+ profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
+ BOOST_CHECK_THROW(commandHandler(packet), armnn::RuntimeException); // Wrong profiling state
+ profilingStateMachine.TransitionToState(ProfilingState::Active);
+ BOOST_CHECK_NO_THROW(commandHandler(packet));
auto readBuffer = mockBuffer.GetReadableBuffer();
uint32_t headerWord0 = ReadUint32(readBuffer, 0);
uint32_t headerWord1 = ReadUint32(readBuffer, 4);
- BOOST_TEST(((headerWord0 >> 26) & 0x3F) == 0); // packet family
- BOOST_TEST(((headerWord0 >> 16) & 0x3FF) == 2); // packet id
- BOOST_TEST(headerWord1 == 240); // data length
+ BOOST_TEST(((headerWord0 >> 26) & 0x0000003F) == 0); // packet family
+ BOOST_TEST(((headerWord0 >> 16) & 0x000003FF) == 2); // packet id
+ BOOST_TEST(headerWord1 == 240); // data length
uint32_t bodyHeaderWord0 = ReadUint32(readBuffer, 8);
uint32_t bodyHeaderWord1 = ReadUint32(readBuffer, 12);
@@ -2357,4 +2379,131 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket)
profilingService.ResetExternalProfilingOptions(options, true);
}
+BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadRequestCounterDirectoryPacket)
+{
+ // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
+ LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning);
+
+ // Swap the profiling connection factory in the profiling service instance with our mock one
+ SwapProfilingConnectionFactoryHelper helper;
+
+ // Redirect the standard output to a local stream so that we can parse the warning message
+ std::stringstream ss;
+ StreamRedirector streamRedirector(std::cout, ss.rdbuf());
+
+ // Reset the profiling service to the uninitialized state
+ armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
+ options.m_EnableProfiling = true;
+ ProfilingService& profilingService = ProfilingService::Instance();
+ profilingService.ResetExternalProfilingOptions(options, true);
+
+ // Bring the profiling service to the "Active" state
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
+ helper.ForceTransitionToState(ProfilingState::NotConnected);
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
+ profilingService.Update(); // Create the profiling connection
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
+ profilingService.Update(); // Start the threads
+ helper.ForceTransitionToState(ProfilingState::Active);
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+ // Get the mock profiling connection
+ MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
+ BOOST_CHECK(mockProfilingConnection);
+
+ // Write a valid "Request Counter Directory" packet into the mock profiling connection, to simulate a valid
+ // reply from an external profiling service
+
+ // Request Counter Directory packet header (word 0, word 1 is always zero):
+ // 26:31 [6] packet_family: Control Packet Family, value 0b000000
+ // 16:25 [10] packet_id: Packet identifier, value 0b0000000011
+ // 8:15 [8] reserved: Reserved, value 0b00000000
+ // 0:7 [8] reserved: Reserved, value 0b00000000
+ uint32_t packetFamily = 0;
+ uint32_t packetId = 123; // Wrong packet id!!!
+ uint32_t header = ((packetFamily & 0x0000003F) << 26) |
+ ((packetId & 0x000003FF) << 16);
+
+ // Create the Request Counter Directory packet
+ Packet requestCounterDirectoryPacket(header);
+
+ // Write the packet to the mock profiling connection
+ mockProfilingConnection->WritePacket(std::move(requestCounterDirectoryPacket));
+
+ // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that
+ // the Create the Request Counter packet gets processed by the profiling service
+ std::this_thread::sleep_for(std::chrono::seconds(2));
+
+ // Check that the expected error has occurred and logged to the standard output
+ BOOST_CHECK(boost::contains(ss.str(), "Functor with requested PacketId=123 and Version=4194304 does not exist"));
+
+ // The Connection Acknowledged Command Handler should not have updated the profiling state
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+ // Reset the profiling service to stop any running thread
+ options.m_EnableProfiling = false;
+ profilingService.ResetExternalProfilingOptions(options, true);
+}
+
+BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodRequestCounterDirectoryPacket)
+{
+ // Swap the profiling connection factory in the profiling service instance with our mock one
+ SwapProfilingConnectionFactoryHelper helper;
+
+ // Reset the profiling service to the uninitialized state
+ armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
+ options.m_EnableProfiling = true;
+ ProfilingService& profilingService = ProfilingService::Instance();
+ profilingService.ResetExternalProfilingOptions(options, true);
+
+ // Bring the profiling service to the "Active" state
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
+ profilingService.Update(); // Initialize the counter directory
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
+ profilingService.Update(); // Create the profiling connection
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
+ profilingService.Update(); // Start the threads
+ helper.ForceTransitionToState(ProfilingState::Active);
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+ // Get the mock profiling connection
+ MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
+ BOOST_CHECK(mockProfilingConnection);
+
+ // Write a valid "Request Counter Directory" packet into the mock profiling connection, to simulate a valid
+ // reply from an external profiling service
+
+ // Request Counter Directory packet header (word 0, word 1 is always zero):
+ // 26:31 [6] packet_family: Control Packet Family, value 0b000000
+ // 16:25 [10] packet_id: Packet identifier, value 0b0000000011
+ // 8:15 [8] reserved: Reserved, value 0b00000000
+ // 0:7 [8] reserved: Reserved, value 0b00000000
+ uint32_t packetFamily = 0;
+ uint32_t packetId = 3;
+ uint32_t header = ((packetFamily & 0x0000003F) << 26) |
+ ((packetId & 0x000003FF) << 16);
+
+ // Create the Request Counter Directory packet
+ Packet requestCounterDirectoryPacket(header);
+
+ // Write the packet to the mock profiling connection
+ mockProfilingConnection->WritePacket(std::move(requestCounterDirectoryPacket));
+
+ // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that
+ // the Create the Request Counter packet gets processed by the profiling service
+ std::this_thread::sleep_for(std::chrono::seconds(2));
+
+ // The Connection Acknowledged Command Handler should not have updated the profiling state
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+ // Check that the mock profiling connection contains one Counter Directory packet
+ const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
+ BOOST_TEST(writtenData.size() == 1);
+ BOOST_TEST(writtenData[0] == 416); // The size of a valid Counter Directory packet
+
+ // Reset the profiling service to stop any running thread
+ options.m_EnableProfiling = false;
+ profilingService.ResetExternalProfilingOptions(options, true);
+}
+
BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/profiling/test/ProfilingTests.hpp b/src/profiling/test/ProfilingTests.hpp
index 3e6cf63efe..e1686162db 100644
--- a/src/profiling/test/ProfilingTests.hpp
+++ b/src/profiling/test/ProfilingTests.hpp
@@ -40,17 +40,6 @@ public:
}
};
-struct CoutRedirect
-{
-public:
- CoutRedirect(std::streambuf* newStreamBuffer)
- : m_Old(std::cout.rdbuf(newStreamBuffer)) {}
- ~CoutRedirect() { std::cout.rdbuf(m_Old); }
-
-private:
- std::streambuf* m_Old;
-};
-
struct StreamRedirector
{
public:
@@ -190,6 +179,11 @@ public:
return boost::polymorphic_downcast<MockProfilingConnection*>(profilingConnection);
}
+ void ForceTransitionToState(ProfilingState newState)
+ {
+ TransitionToState(ProfilingService::Instance(), newState);
+ }
+
private:
MockProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory;
IProfilingConnectionFactory* m_BackupProfilingConnectionFactory;