aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2019-10-10 14:08:21 +0100
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-10-11 16:33:29 +0100
commite848538efbdf01aa0b067da942c3c214f8e62826 (patch)
treed700239f1316a098849fcfc39ec70e926f86fd62
parentf982deaefbe5fe5814487b27f7099829839b8666 (diff)
downloadarmnn-e848538efbdf01aa0b067da942c3c214f8e62826.tar.gz
IVGCVSW-3964 Implement the Periodic Counter Selection command handler
* Improved the PeriodicCounterPacket class to handle errors properly * Improved the PeriodicCounterSelectionCommandHandler to handle invalid counter UIDs in the selection packet * Added the Periodic Counter Selection command handler to the ProfilingService class * Code refactoring and added comments * Added WaitForPacketSent method to the SendCounterPacket class to allow waiting for the packets to be sent (useful in the unit tests) * Added unit tests and updated the old ones accordingly * Fixed threading issues with a number of unit tests Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com> Change-Id: I271b7b0bfa801d88fe1725b934d24e30cd839ed7
-rw-r--r--src/profiling/ConnectionAcknowledgedCommandHandler.cpp2
-rw-r--r--src/profiling/Holder.cpp12
-rw-r--r--src/profiling/Holder.hpp4
-rw-r--r--src/profiling/ICounterValues.hpp1
-rw-r--r--src/profiling/PeriodicCounterCapture.cpp33
-rw-r--r--src/profiling/PeriodicCounterSelectionCommandHandler.cpp115
-rw-r--r--src/profiling/PeriodicCounterSelectionCommandHandler.hpp31
-rw-r--r--src/profiling/ProfilingService.cpp32
-rw-r--r--src/profiling/ProfilingService.hpp24
-rw-r--r--src/profiling/RequestCounterDirectoryCommandHandler.cpp2
-rw-r--r--src/profiling/SendCounterPacket.cpp18
-rw-r--r--src/profiling/SendCounterPacket.hpp12
-rw-r--r--src/profiling/test/ProfilingTests.cpp584
-rw-r--r--src/profiling/test/ProfilingTests.hpp16
-rw-r--r--src/profiling/test/SendCounterPacketTests.hpp16
15 files changed, 784 insertions, 118 deletions
diff --git a/src/profiling/ConnectionAcknowledgedCommandHandler.cpp b/src/profiling/ConnectionAcknowledgedCommandHandler.cpp
index 9d2d1a2bd2..deffd1414b 100644
--- a/src/profiling/ConnectionAcknowledgedCommandHandler.cpp
+++ b/src/profiling/ConnectionAcknowledgedCommandHandler.cpp
@@ -22,7 +22,7 @@ void ConnectionAcknowledgedCommandHandler::operator()(const Packet& packet)
{
case ProfilingState::Uninitialised:
case ProfilingState::NotConnected:
- throw RuntimeException(boost::str(boost::format("Connection Acknowledged Handler invoked while in an "
+ throw RuntimeException(boost::str(boost::format("Connection Acknowledged Command Handler invoked while in an "
"wrong state: %1%")
% GetProfilingStateName(currentState)));
case ProfilingState::WaitingForAck:
diff --git a/src/profiling/Holder.cpp b/src/profiling/Holder.cpp
index 5916017eb6..750be7ec74 100644
--- a/src/profiling/Holder.cpp
+++ b/src/profiling/Holder.cpp
@@ -11,10 +11,10 @@ namespace armnn
namespace profiling
{
-CaptureData& CaptureData::operator= (const CaptureData& captureData)
+CaptureData& CaptureData::operator=(const CaptureData& other)
{
- m_CapturePeriod = captureData.m_CapturePeriod;
- m_CounterIds = captureData.m_CounterIds;
+ m_CapturePeriod = other.m_CapturePeriod;
+ m_CounterIds = other.m_CounterIds;
return *this;
}
@@ -29,12 +29,12 @@ void CaptureData::SetCounterIds(const std::vector<uint16_t>& counterIds)
m_CounterIds = counterIds;
}
-std::uint32_t CaptureData::GetCapturePeriod() const
+uint32_t CaptureData::GetCapturePeriod() const
{
return m_CapturePeriod;
}
-std::vector<uint16_t> CaptureData::GetCounterIds() const
+const std::vector<uint16_t>& CaptureData::GetCounterIds() const
{
return m_CounterIds;
}
@@ -42,12 +42,14 @@ std::vector<uint16_t> CaptureData::GetCounterIds() const
CaptureData Holder::GetCaptureData() const
{
std::lock_guard<std::mutex> lockGuard(m_CaptureThreadMutex);
+
return m_CaptureData;
}
void Holder::SetCaptureData(uint32_t capturePeriod, const std::vector<uint16_t>& counterIds)
{
std::lock_guard<std::mutex> lockGuard(m_CaptureThreadMutex);
+
m_CaptureData.SetCapturePeriod(capturePeriod);
m_CaptureData.SetCounterIds(counterIds);
}
diff --git a/src/profiling/Holder.hpp b/src/profiling/Holder.hpp
index 72ca0914a9..3143105ab4 100644
--- a/src/profiling/Holder.hpp
+++ b/src/profiling/Holder.hpp
@@ -27,12 +27,12 @@ public:
: m_CapturePeriod(captureData.m_CapturePeriod)
, m_CounterIds(captureData.m_CounterIds) {}
- CaptureData& operator= (const CaptureData& captureData);
+ CaptureData& operator=(const CaptureData& other);
void SetCapturePeriod(uint32_t capturePeriod);
void SetCounterIds(const std::vector<uint16_t>& counterIds);
uint32_t GetCapturePeriod() const;
- std::vector<uint16_t> GetCounterIds() const;
+ const std::vector<uint16_t>& GetCounterIds() const;
private:
uint32_t m_CapturePeriod;
diff --git a/src/profiling/ICounterValues.hpp b/src/profiling/ICounterValues.hpp
index 5e32ca2b37..18e34b6747 100644
--- a/src/profiling/ICounterValues.hpp
+++ b/src/profiling/ICounterValues.hpp
@@ -18,6 +18,7 @@ class IReadCounterValues
public:
virtual ~IReadCounterValues() {}
+ virtual bool IsCounterRegistered(uint16_t counterUid) const = 0;
virtual uint16_t GetCounterCount() const = 0;
virtual uint32_t GetCounterValue(uint16_t counterUid) const = 0;
};
diff --git a/src/profiling/PeriodicCounterCapture.cpp b/src/profiling/PeriodicCounterCapture.cpp
index 9002bfc065..0ccb516ae2 100644
--- a/src/profiling/PeriodicCounterCapture.cpp
+++ b/src/profiling/PeriodicCounterCapture.cpp
@@ -5,6 +5,8 @@
#include "PeriodicCounterCapture.hpp"
+#include <boost/log/trivial.hpp>
+
namespace armnn
{
@@ -34,10 +36,13 @@ void PeriodicCounterCapture::Start()
void PeriodicCounterCapture::Stop()
{
+ // Signal the capture thread to stop
m_KeepRunning.store(false);
+ // Check that the capture thread is running
if (m_PeriodCaptureThread.joinable())
{
+ // Wait for the capture thread to complete operations
m_PeriodCaptureThread.join();
}
}
@@ -51,10 +56,12 @@ void PeriodicCounterCapture::Capture(const IReadCounterValues& readCounterValues
{
while (m_KeepRunning.load())
{
+ // Check if the current capture data indicates that there's data capture
auto currentCaptureData = ReadCaptureData();
- std::vector<uint16_t> counterIds = currentCaptureData.GetCounterIds();
+ const std::vector<uint16_t>& counterIds = currentCaptureData.GetCounterIds();
if (currentCaptureData.GetCapturePeriod() == 0 || counterIds.empty())
{
+ // No data capture, terminate the thread
m_KeepRunning.store(false);
break;
}
@@ -63,12 +70,22 @@ void PeriodicCounterCapture::Capture(const IReadCounterValues& readCounterValues
auto numCounters = counterIds.size();
values.reserve(numCounters);
- // Create vector of pairs of CounterIndexes and Values
- uint32_t counterValue = 0;
+ // Create a vector of pairs of CounterIndexes and Values
for (uint16_t index = 0; index < numCounters; ++index)
{
auto requestedId = counterIds[index];
- counterValue = readCounterValues.GetCounterValue(requestedId);
+ uint32_t counterValue = 0;
+ try
+ {
+ counterValue = readCounterValues.GetCounterValue(requestedId);
+ }
+ catch (const Exception& e)
+ {
+ // Report the error and continue
+ BOOST_LOG_TRIVIAL(warning) << "An error has occurred when getting a counter value: "
+ << e.what() << std::endl;
+ continue;
+ }
values.emplace_back(std::make_pair(requestedId, counterValue));
}
@@ -81,9 +98,15 @@ void PeriodicCounterCapture::Capture(const IReadCounterValues& readCounterValues
// Take a timestamp
auto timestamp = clock::now();
+ // Write a Periodic Counter Capture packet to the Counter Stream Buffer
m_SendCounterPacket.SendPeriodicCounterCapturePacket(
static_cast<uint64_t>(timestamp.time_since_epoch().count()), values);
- std::this_thread::sleep_for(std::chrono::milliseconds(currentCaptureData.GetCapturePeriod()));
+
+ // Notify the Send Thread that new data is available in the Counter Stream Buffer
+ m_SendCounterPacket.SetReadyToRead();
+
+ // Wait the indicated capture period (microseconds)
+ std::this_thread::sleep_for(std::chrono::microseconds(currentCaptureData.GetCapturePeriod()));
}
m_IsRunning.store(false);
diff --git a/src/profiling/PeriodicCounterSelectionCommandHandler.cpp b/src/profiling/PeriodicCounterSelectionCommandHandler.cpp
index 9be37fcfd2..db09856dae 100644
--- a/src/profiling/PeriodicCounterSelectionCommandHandler.cpp
+++ b/src/profiling/PeriodicCounterSelectionCommandHandler.cpp
@@ -7,6 +7,9 @@
#include "ProfilingUtils.hpp"
#include <boost/numeric/conversion/cast.hpp>
+#include <boost/format.hpp>
+
+#include <vector>
namespace armnn
{
@@ -14,57 +17,109 @@ namespace armnn
namespace profiling
{
-using namespace std;
-using boost::numeric_cast;
-
void PeriodicCounterSelectionCommandHandler::ParseData(const Packet& packet, CaptureData& captureData)
{
std::vector<uint16_t> counterIds;
- uint32_t sizeOfUint32 = numeric_cast<uint32_t>(sizeof(uint32_t));
- uint32_t sizeOfUint16 = numeric_cast<uint32_t>(sizeof(uint16_t));
+ uint32_t sizeOfUint32 = boost::numeric_cast<uint32_t>(sizeof(uint32_t));
+ uint32_t sizeOfUint16 = boost::numeric_cast<uint32_t>(sizeof(uint16_t));
uint32_t offset = 0;
- if (packet.GetLength() > 0)
+ if (packet.GetLength() < 4)
{
- if (packet.GetLength() >= 4)
- {
- captureData.SetCapturePeriod(ReadUint32(reinterpret_cast<const unsigned char*>(packet.GetData()), offset));
+ // Insufficient packet size
+ return;
+ }
- unsigned int counters = (packet.GetLength() - 4) / 2;
+ // Parse the capture period
+ uint32_t capturePeriod = ReadUint32(packet.GetData(), offset);
- if (counters > 0)
- {
- counterIds.reserve(counters);
- offset += sizeOfUint32;
- for(unsigned int pos = 0; pos < counters; ++pos)
- {
- counterIds.emplace_back(ReadUint16(reinterpret_cast<const unsigned char*>(packet.GetData()),
- offset));
- offset += sizeOfUint16;
- }
- }
+ // Set the capture period
+ captureData.SetCapturePeriod(capturePeriod);
- captureData.SetCounterIds(counterIds);
+ // Parse the counter ids
+ unsigned int counters = (packet.GetLength() - 4) / 2;
+ if (counters > 0)
+ {
+ counterIds.reserve(counters);
+ offset += sizeOfUint32;
+ for (unsigned int i = 0; i < counters; ++i)
+ {
+ // Parse the counter id
+ uint16_t counterId = ReadUint16(packet.GetData(), offset);
+ counterIds.emplace_back(counterId);
+ offset += sizeOfUint16;
}
}
+
+ // Set the counter ids
+ captureData.SetCounterIds(counterIds);
}
void PeriodicCounterSelectionCommandHandler::operator()(const Packet& packet)
{
- CaptureData captureData;
+ ProfilingState currentState = m_StateMachine.GetCurrentState();
+ switch (currentState)
+ {
+ case ProfilingState::Uninitialised:
+ case ProfilingState::NotConnected:
+ case ProfilingState::WaitingForAck:
+ throw RuntimeException(boost::str(boost::format("Periodic Counter Selection Command Handler invoked while in "
+ "an wrong state: %1%")
+ % GetProfilingStateName(currentState)));
+ case ProfilingState::Active:
+ {
+ // Process the packet
+ if (!(packet.GetPacketFamily() == 0u && packet.GetPacketId() == 4u))
+ {
+ throw armnn::InvalidArgumentException(boost::str(boost::format("Expected Packet family = 0, id = 4 but "
+ "received family = %1%, id = %2%")
+ % packet.GetPacketFamily()
+ % packet.GetPacketId()));
+ }
+
+ // Parse the packet to get the capture period and counter UIDs
+ CaptureData captureData;
+ ParseData(packet, captureData);
- ParseData(packet, captureData);
+ // Get the capture data
+ const uint32_t capturePeriod = captureData.GetCapturePeriod();
+ const std::vector<uint16_t>& counterIds = captureData.GetCounterIds();
- vector<uint16_t> counterIds = captureData.GetCounterIds();
+ // Check whether the selected counter UIDs are valid
+ std::vector<uint16_t> validCounterIds;
+ for (uint16_t counterId : counterIds)
+ {
+ // Check whether the counter is registered
+ if (!m_ReadCounterValues.IsCounterRegistered(counterId))
+ {
+ // Invalid counter UID, ignore it and continue
+ continue;
+ }
- m_CaptureDataHolder.SetCaptureData(captureData.GetCapturePeriod(), counterIds);
+ // The counter is valid
+ validCounterIds.push_back(counterId);
+ }
- m_CaptureThread.Start();
+ // Set the capture data with only the valid counter UIDs
+ m_CaptureDataHolder.SetCaptureData(capturePeriod, validCounterIds);
- // Write packet to Counter Stream Buffer
- m_SendCounterPacket.SendPeriodicCounterSelectionPacket(captureData.GetCapturePeriod(), captureData.GetCounterIds());
+ // Echo back the Periodic Counter Selection packet to the Counter Stream Buffer
+ m_SendCounterPacket.SendPeriodicCounterSelectionPacket(capturePeriod, validCounterIds);
+
+ // Notify the Send Thread that new data is available in the Counter Stream Buffer
+ m_SendCounterPacket.SetReadyToRead();
+
+ // Start the Period Counter Capture thread (if not running already)
+ m_PeriodicCounterCapture.Start();
+
+ break;
+ }
+ default:
+ throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1%")
+ % static_cast<int>(currentState)));
+ }
}
} // namespace profiling
-} // namespace armnn \ No newline at end of file
+} // namespace armnn
diff --git a/src/profiling/PeriodicCounterSelectionCommandHandler.hpp b/src/profiling/PeriodicCounterSelectionCommandHandler.hpp
index e247e7773f..1da08e3c7a 100644
--- a/src/profiling/PeriodicCounterSelectionCommandHandler.hpp
+++ b/src/profiling/PeriodicCounterSelectionCommandHandler.hpp
@@ -10,10 +10,7 @@
#include "Holder.hpp"
#include "SendCounterPacket.hpp"
#include "IPeriodicCounterCapture.hpp"
-
-#include <vector>
-#include <thread>
-#include <atomic>
+#include "ICounterValues.hpp"
namespace armnn
{
@@ -25,22 +22,30 @@ class PeriodicCounterSelectionCommandHandler : public CommandHandlerFunctor
{
public:
- PeriodicCounterSelectionCommandHandler(uint32_t packetId, uint32_t version, Holder& captureDataHolder,
- IPeriodicCounterCapture& captureThread,
- ISendCounterPacket& sendCounterPacket)
- : CommandHandlerFunctor(packetId, version),
- m_CaptureDataHolder(captureDataHolder),
- m_CaptureThread(captureThread),
- m_SendCounterPacket(sendCounterPacket)
+ PeriodicCounterSelectionCommandHandler(uint32_t packetId,
+ uint32_t version,
+ Holder& captureDataHolder,
+ IPeriodicCounterCapture& periodicCounterCapture,
+ const IReadCounterValues& readCounterValue,
+ ISendCounterPacket& sendCounterPacket,
+ const ProfilingStateMachine& profilingStateMachine)
+ : CommandHandlerFunctor(packetId, version)
+ , m_CaptureDataHolder(captureDataHolder)
+ , m_PeriodicCounterCapture(periodicCounterCapture)
+ , m_ReadCounterValues(readCounterValue)
+ , m_SendCounterPacket(sendCounterPacket)
+ , m_StateMachine(profilingStateMachine)
{}
void operator()(const Packet& packet) override;
-
private:
Holder& m_CaptureDataHolder;
- IPeriodicCounterCapture& m_CaptureThread;
+ IPeriodicCounterCapture& m_PeriodicCounterCapture;
+ const IReadCounterValues& m_ReadCounterValues;
ISendCounterPacket& m_SendCounterPacket;
+ const ProfilingStateMachine& m_StateMachine;
+
void ParseData(const Packet& packet, CaptureData& captureData);
};
diff --git a/src/profiling/ProfilingService.cpp b/src/profiling/ProfilingService.cpp
index 693f8337db..79184416cd 100644
--- a/src/profiling/ProfilingService.cpp
+++ b/src/profiling/ProfilingService.cpp
@@ -53,6 +53,9 @@ void ProfilingService::Update()
// Stop the send thread (if running)
m_SendCounterPacket.Stop(false);
+ // Stop the periodic counter capture thread (if running)
+ m_PeriodicCounterCapture.Stop();
+
// Reset any existing profiling connection
m_ProfilingConnection.reset();
@@ -90,6 +93,9 @@ void ProfilingService::Update()
break;
case ProfilingState::Active:
+ // The period counter capture thread is started by the Periodic Counter Selection command handler upon
+ // request by an external profiling service
+
break;
default:
throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1")
@@ -112,9 +118,14 @@ uint16_t ProfilingService::GetCounterCount() const
return m_CounterDirectory.GetCounterCount();
}
+bool ProfilingService::IsCounterRegistered(uint16_t counterUid) const
+{
+ return counterUid < m_CounterIndex.size();
+}
+
uint32_t ProfilingService::GetCounterValue(uint16_t counterUid) const
{
- BOOST_ASSERT(counterUid < m_CounterIndex.size());
+ CheckCounterUid(counterUid);
std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
BOOST_ASSERT(counterValuePtr);
return counterValuePtr->load(std::memory_order::memory_order_relaxed);
@@ -122,7 +133,7 @@ uint32_t ProfilingService::GetCounterValue(uint16_t counterUid) const
void ProfilingService::SetCounterValue(uint16_t counterUid, uint32_t value)
{
- BOOST_ASSERT(counterUid < m_CounterIndex.size());
+ CheckCounterUid(counterUid);
std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
BOOST_ASSERT(counterValuePtr);
counterValuePtr->store(value, std::memory_order::memory_order_relaxed);
@@ -130,7 +141,7 @@ void ProfilingService::SetCounterValue(uint16_t counterUid, uint32_t value)
uint32_t ProfilingService::AddCounterValue(uint16_t counterUid, uint32_t value)
{
- BOOST_ASSERT(counterUid < m_CounterIndex.size());
+ CheckCounterUid(counterUid);
std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
BOOST_ASSERT(counterValuePtr);
return counterValuePtr->fetch_add(value, std::memory_order::memory_order_relaxed);
@@ -138,7 +149,7 @@ uint32_t ProfilingService::AddCounterValue(uint16_t counterUid, uint32_t value)
uint32_t ProfilingService::SubtractCounterValue(uint16_t counterUid, uint32_t value)
{
- BOOST_ASSERT(counterUid < m_CounterIndex.size());
+ CheckCounterUid(counterUid);
std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
BOOST_ASSERT(counterValuePtr);
return counterValuePtr->fetch_sub(value, std::memory_order::memory_order_relaxed);
@@ -146,7 +157,7 @@ uint32_t ProfilingService::SubtractCounterValue(uint16_t counterUid, uint32_t va
uint32_t ProfilingService::IncrementCounterValue(uint16_t counterUid)
{
- BOOST_ASSERT(counterUid < m_CounterIndex.size());
+ CheckCounterUid(counterUid);
std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
BOOST_ASSERT(counterValuePtr);
return counterValuePtr->operator++(std::memory_order::memory_order_relaxed);
@@ -154,7 +165,7 @@ uint32_t ProfilingService::IncrementCounterValue(uint16_t counterUid)
uint32_t ProfilingService::DecrementCounterValue(uint16_t counterUid)
{
- BOOST_ASSERT(counterUid < m_CounterIndex.size());
+ CheckCounterUid(counterUid);
std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
BOOST_ASSERT(counterValuePtr);
return counterValuePtr->operator--(std::memory_order::memory_order_relaxed);
@@ -239,6 +250,7 @@ void ProfilingService::Reset()
// First stop the threads (Command Handler first)...
m_CommandHandler.Stop();
m_SendCounterPacket.Stop(false);
+ m_PeriodicCounterCapture.Stop();
// ...then destroy the profiling connection...
m_ProfilingConnection.reset();
@@ -252,6 +264,14 @@ void ProfilingService::Reset()
m_StateMachine.Reset();
}
+inline void ProfilingService::CheckCounterUid(uint16_t counterUid) const
+{
+ if (!IsCounterRegistered(counterUid))
+ {
+ throw InvalidArgumentException(boost::str(boost::format("Counter UID %1% is not registered") % counterUid));
+ }
+}
+
} // namespace profiling
} // namespace armnn
diff --git a/src/profiling/ProfilingService.hpp b/src/profiling/ProfilingService.hpp
index 0e66924267..dd70af4b39 100644
--- a/src/profiling/ProfilingService.hpp
+++ b/src/profiling/ProfilingService.hpp
@@ -12,8 +12,10 @@
#include "CommandHandler.hpp"
#include "BufferManager.hpp"
#include "SendCounterPacket.hpp"
+#include "PeriodicCounterCapture.hpp"
#include "ConnectionAcknowledgedCommandHandler.hpp"
#include "RequestCounterDirectoryCommandHandler.hpp"
+#include "PeriodicCounterSelectionCommandHandler.hpp"
namespace armnn
{
@@ -46,6 +48,7 @@ public:
// Getters for the profiling service state
const ICounterDirectory& GetCounterDirectory() const;
ProfilingState GetCurrentState() const;
+ bool IsCounterRegistered(uint16_t counterUid) const override;
uint16_t GetCounterCount() const override;
uint32_t GetCounterValue(uint16_t counterUid) const override;
@@ -68,6 +71,9 @@ private:
void InitializeCounterValue(uint16_t counterUid);
void Reset();
+ // Helper function
+ void CheckCounterUid(uint16_t counterUid) const;
+
// Profiling service components
ExternalProfilingOptions m_Options;
CounterDirectory m_CounterDirectory;
@@ -81,8 +87,11 @@ private:
CommandHandler m_CommandHandler;
BufferManager m_BufferManager;
SendCounterPacket m_SendCounterPacket;
+ Holder m_Holder;
+ PeriodicCounterCapture m_PeriodicCounterCapture;
ConnectionAcknowledgedCommandHandler m_ConnectionAcknowledgedCommandHandler;
RequestCounterDirectoryCommandHandler m_RequestCounterDirectoryCommandHandler;
+ PeriodicCounterSelectionCommandHandler m_PeriodicCounterSelectionCommandHandler;
protected:
// Default constructor/destructor kept protected for testing
@@ -102,6 +111,7 @@ protected:
m_PacketVersionResolver)
, m_BufferManager()
, m_SendCounterPacket(m_StateMachine, m_BufferManager)
+ , m_PeriodicCounterCapture(m_Holder, m_SendCounterPacket, *this)
, m_ConnectionAcknowledgedCommandHandler(1,
m_PacketVersionResolver.ResolvePacketVersion(1).GetEncodedValue(),
m_StateMachine)
@@ -110,12 +120,22 @@ protected:
m_CounterDirectory,
m_SendCounterPacket,
m_StateMachine)
+ , m_PeriodicCounterSelectionCommandHandler(4,
+ m_PacketVersionResolver.ResolvePacketVersion(4).GetEncodedValue(),
+ m_Holder,
+ m_PeriodicCounterCapture,
+ *this,
+ m_SendCounterPacket,
+ m_StateMachine)
{
// Register the "Connection Acknowledged" command handler
m_CommandHandlerRegistry.RegisterFunctor(&m_ConnectionAcknowledgedCommandHandler);
// Register the "Request Counter Directory" command handler
m_CommandHandlerRegistry.RegisterFunctor(&m_RequestCounterDirectoryCommandHandler);
+
+ // Register the "Periodic Counter Selection" command handler
+ m_CommandHandlerRegistry.RegisterFunctor(&m_PeriodicCounterSelectionCommandHandler);
}
~ProfilingService() = default;
@@ -138,6 +158,10 @@ protected:
{
instance.m_StateMachine.TransitionToState(newState);
}
+ void WaitForPacketSent(ProfilingService& instance)
+ {
+ return instance.m_SendCounterPacket.WaitForPacketSent();
+ }
};
} // namespace profiling
diff --git a/src/profiling/RequestCounterDirectoryCommandHandler.cpp b/src/profiling/RequestCounterDirectoryCommandHandler.cpp
index e85acb4215..b8ac9d9426 100644
--- a/src/profiling/RequestCounterDirectoryCommandHandler.cpp
+++ b/src/profiling/RequestCounterDirectoryCommandHandler.cpp
@@ -21,7 +21,7 @@ void RequestCounterDirectoryCommandHandler::operator()(const Packet& packet)
case ProfilingState::Uninitialised:
case ProfilingState::NotConnected:
case ProfilingState::WaitingForAck:
- throw RuntimeException(boost::str(boost::format("Request Counter Directory Handler invoked while in an "
+ throw RuntimeException(boost::str(boost::format("Request Counter Directory Comand Handler invoked while in an "
"wrong state: %1%")
% GetProfilingStateName(currentState)));
case ProfilingState::Active:
diff --git a/src/profiling/SendCounterPacket.cpp b/src/profiling/SendCounterPacket.cpp
index e48da3ed7c..41adf37244 100644
--- a/src/profiling/SendCounterPacket.cpp
+++ b/src/profiling/SendCounterPacket.cpp
@@ -1035,17 +1035,21 @@ void SendCounterPacket::Send(IProfilingConnection& profilingConnection)
}
// Ensure that all readable data got written to the profiling connection before the thread is stopped
- FlushBuffer(profilingConnection);
+ // (do not notify any watcher in this case, as this is just to wrap up things before shutting down the send thread)
+ FlushBuffer(profilingConnection, false);
// Mark the send thread as not running
m_IsRunning.store(false);
}
-void SendCounterPacket::FlushBuffer(IProfilingConnection& profilingConnection)
+void SendCounterPacket::FlushBuffer(IProfilingConnection& profilingConnection, bool notifyWatchers)
{
// Get the first available readable buffer
std::unique_ptr<IPacketBuffer> packetBuffer = m_BufferManager.GetReadableBuffer();
+ // Initialize the flag that indicates whether at least a packet has been sent
+ bool packetsSent = false;
+
while (packetBuffer != nullptr)
{
// Get the data to send from the buffer
@@ -1066,6 +1070,9 @@ void SendCounterPacket::FlushBuffer(IProfilingConnection& profilingConnection)
{
// Write a packet to the profiling connection. Silently ignore any write error and continue
profilingConnection.WritePacket(readBuffer, boost::numeric_cast<uint32_t>(readBufferSize));
+
+ // Set the flag that indicates whether at least a packet has been sent
+ packetsSent = true;
}
// Mark the packet buffer as read
@@ -1074,6 +1081,13 @@ void SendCounterPacket::FlushBuffer(IProfilingConnection& profilingConnection)
// Get the next available readable buffer
packetBuffer = m_BufferManager.GetReadableBuffer();
}
+
+ // Check whether at least a packet has been sent
+ if (packetsSent && notifyWatchers)
+ {
+ // Notify to any watcher that something has been sent
+ m_PacketSentWaitCondition.notify_one();
+ }
}
} // namespace profiling
diff --git a/src/profiling/SendCounterPacket.hpp b/src/profiling/SendCounterPacket.hpp
index 9361efbc74..e1a42aa496 100644
--- a/src/profiling/SendCounterPacket.hpp
+++ b/src/profiling/SendCounterPacket.hpp
@@ -65,6 +65,14 @@ public:
void Stop(bool rethrowSendThreadExceptions = true);
bool IsRunning() { return m_IsRunning.load(); }
+ void WaitForPacketSent()
+ {
+ std::unique_lock<std::mutex> lock(m_PacketSentWaitMutex);
+
+ // Blocks until notified that at least a packet has been sent
+ m_PacketSentWaitCondition.wait(lock);
+ }
+
private:
void Send(IProfilingConnection& profilingConnection);
@@ -93,7 +101,7 @@ private:
throw ExceptionType(errorMessage);
}
- void FlushBuffer(IProfilingConnection& profilingConnection);
+ void FlushBuffer(IProfilingConnection& profilingConnection, bool notifyWatchers = true);
ProfilingStateMachine& m_StateMachine;
IBufferManager& m_BufferManager;
@@ -104,6 +112,8 @@ private:
std::atomic<bool> m_IsRunning;
std::atomic<bool> m_KeepRunning;
std::exception_ptr m_SendThreadException;
+ std::mutex m_PacketSentWaitMutex;
+ std::condition_variable m_PacketSentWaitCondition;
protected:
// Helper methods, protected for testing
diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp
index 27bacf7145..554b7e1936 100644
--- a/src/profiling/test/ProfilingTests.cpp
+++ b/src/profiling/test/ProfilingTests.cpp
@@ -35,6 +35,7 @@
#include <limits>
#include <map>
#include <random>
+#include <iostream>
using namespace armnn::profiling;
@@ -1691,11 +1692,19 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData)
void Stop() override {}
};
+ class TestReadCounterValues : public IReadCounterValues
+ {
+ bool IsCounterRegistered(uint16_t counterUid) const override { return true; }
+ uint16_t GetCounterCount() const override { return 0; }
+ uint32_t GetCounterValue(uint16_t counterUid) const override { return 0; }
+ };
+
const uint32_t packetId = 0x40000;
uint32_t version = 1;
Holder holder;
TestCaptureThread captureThread;
+ TestReadCounterValues readCounterValues;
MockBufferManager mockBuffer(512);
SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer);
@@ -1718,16 +1727,29 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData)
Packet packetA(packetId, dataLength1, uniqueData1);
- PeriodicCounterSelectionCommandHandler commandHandler(packetId, version, holder, captureThread,
- sendCounterPacket);
- commandHandler(packetA);
+ PeriodicCounterSelectionCommandHandler commandHandler(packetId,
+ version,
+ holder,
+ captureThread,
+ readCounterValues,
+ sendCounterPacket,
+ profilingStateMachine);
- std::vector<uint16_t> counterIds = holder.GetCaptureData().GetCounterIds();
+ profilingStateMachine.TransitionToState(ProfilingState::Uninitialised);
+ BOOST_CHECK_THROW(commandHandler(packetA), armnn::RuntimeException);
+ profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
+ BOOST_CHECK_THROW(commandHandler(packetA), armnn::RuntimeException);
+ profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
+ BOOST_CHECK_THROW(commandHandler(packetA), armnn::RuntimeException);
+ profilingStateMachine.TransitionToState(ProfilingState::Active);
+ BOOST_CHECK_NO_THROW(commandHandler(packetA));
+
+ const std::vector<uint16_t> counterIdsA = holder.GetCaptureData().GetCounterIds();
BOOST_TEST(holder.GetCaptureData().GetCapturePeriod() == period1);
- BOOST_TEST(counterIds.size() == 2);
- BOOST_TEST(counterIds[0] == 4000);
- BOOST_TEST(counterIds[1] == 5000);
+ BOOST_TEST(counterIdsA.size() == 2);
+ BOOST_TEST(counterIdsA[0] == 4000);
+ BOOST_TEST(counterIdsA[1] == 5000);
auto readBuffer = mockBuffer.GetReadableBuffer();
@@ -1766,10 +1788,10 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData)
commandHandler(packetB);
- counterIds = holder.GetCaptureData().GetCounterIds();
+ const std::vector<uint16_t> counterIdsB = holder.GetCaptureData().GetCounterIds();
BOOST_TEST(holder.GetCaptureData().GetCapturePeriod() == period2);
- BOOST_TEST(counterIds.size() == 0);
+ BOOST_TEST(counterIdsB.size() == 0);
readBuffer = mockBuffer.GetReadableBuffer();
@@ -2024,35 +2046,40 @@ BOOST_AUTO_TEST_CASE(CheckPeriodicCounterCaptureThread)
public:
CaptureReader() {}
+ bool IsCounterRegistered(uint16_t counterUid) const override
+ {
+ return m_Data.find(counterUid) != m_Data.end();
+ }
+
uint16_t GetCounterCount() const override
{
return boost::numeric_cast<uint16_t>(m_Data.size());
}
- uint32_t GetCounterValue(uint16_t index) const override
+ uint32_t GetCounterValue(uint16_t counterUid) const override
{
- if (m_Data.find(index) == m_Data.end())
+ if (m_Data.find(counterUid) == m_Data.end())
{
return 0;
}
- return m_Data.at(index);
+ return m_Data.at(counterUid).load();
}
- void SetCounterValue(uint16_t index, uint32_t value)
+ void SetCounterValue(uint16_t counterUid, uint32_t value)
{
- if (m_Data.find(index) == m_Data.end())
+ if (m_Data.find(counterUid) == m_Data.end())
{
- m_Data.insert(std::pair<uint16_t, uint32_t>(index, value));
+ m_Data.insert(std::make_pair(counterUid, value));
}
else
{
- m_Data.at(index) = value;
+ m_Data.at(counterUid).store(value);
}
}
private:
- std::unordered_map<uint16_t, uint32_t> m_Data;
+ std::unordered_map<uint16_t, std::atomic<uint32_t>> m_Data;
};
ProfilingStateMachine profilingStateMachine;
@@ -2261,19 +2288,23 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadConnectionAcknowledgedPacket)
// Bring the profiling service to the "WaitingForAck" state
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
- profilingService.Update();
+ profilingService.Update(); // Initialize the counter directory
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
- profilingService.Update();
- BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
- profilingService.Update();
-
- // Wait for a bit to make sure that we get the packet
- std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ profilingService.Update();// Create the profiling connection
// Get the mock profiling connection
MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
BOOST_CHECK(mockProfilingConnection);
+ // Remove the packets received so far
+ mockProfilingConnection->Clear();
+
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
+ profilingService.Update();
+
+ // Wait for the Stream Metadata packet to be sent
+ helper.WaitForProfilingPacketsSent();
+
// Check that the mock profiling connection contains one Stream Metadata packet
const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
BOOST_TEST(writtenData.size() == 1);
@@ -2330,19 +2361,23 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket)
// Bring the profiling service to the "WaitingForAck" state
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
- profilingService.Update();
+ profilingService.Update(); // Initialize the counter directory
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
- profilingService.Update();
- BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
- profilingService.Update();
-
- // Wait for a bit to make sure that we get the packet
- std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ profilingService.Update(); // Create the profiling connection
// Get the mock profiling connection
MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
BOOST_CHECK(mockProfilingConnection);
+ // Remove the packets received so far
+ mockProfilingConnection->Clear();
+
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
+ profilingService.Update(); // Start the command handler and the send thread
+
+ // Wait for the Stream Metadata packet to be sent
+ helper.WaitForProfilingPacketsSent();
+
// Check that the mock profiling connection contains one Stream Metadata packet
const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
BOOST_TEST(writtenData.size() == 1);
@@ -2403,7 +2438,13 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadRequestCounterDirectoryPacket)
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
profilingService.Update(); // Create the profiling connection
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
- profilingService.Update(); // Start the threads
+ profilingService.Update(); // Start the command handler and the send thread
+
+ // Wait for the Stream Metadata packet the be sent
+ // (we are not testing the connection acknowledgement here so it will be ignored by this test)
+ helper.WaitForProfilingPacketsSent();
+
+ // Force the profiling service to the "Active" state
helper.ForceTransitionToState(ProfilingState::Active);
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
@@ -2411,6 +2452,9 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadRequestCounterDirectoryPacket)
MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
BOOST_CHECK(mockProfilingConnection);
+ // Remove the packets received so far
+ mockProfilingConnection->Clear();
+
// Write a valid "Request Counter Directory" packet into the mock profiling connection, to simulate a valid
// reply from an external profiling service
@@ -2437,7 +2481,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadRequestCounterDirectoryPacket)
// Check that the expected error has occurred and logged to the standard output
BOOST_CHECK(boost::contains(ss.str(), "Functor with requested PacketId=123 and Version=4194304 does not exist"));
- // The Connection Acknowledged Command Handler should not have updated the profiling state
+ // The Request Counter Directory Command Handler should not have updated the profiling state
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
// Reset the profiling service to stop any running thread
@@ -2462,7 +2506,13 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodRequestCounterDirectoryPacket)
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
profilingService.Update(); // Create the profiling connection
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
- profilingService.Update(); // Start the threads
+ profilingService.Update(); // Start the command handler and the send thread
+
+ // Wait for the Stream Metadata packet the be sent
+ // (we are not testing the connection acknowledgement here so it will be ignored by this test)
+ helper.WaitForProfilingPacketsSent();
+
+ // Force the profiling service to the "Active" state
helper.ForceTransitionToState(ProfilingState::Active);
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
@@ -2470,6 +2520,9 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodRequestCounterDirectoryPacket)
MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
BOOST_CHECK(mockProfilingConnection);
+ // Remove the packets received so far
+ mockProfilingConnection->Clear();
+
// Write a valid "Request Counter Directory" packet into the mock profiling connection, to simulate a valid
// reply from an external profiling service
@@ -2489,17 +2542,470 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodRequestCounterDirectoryPacket)
// Write the packet to the mock profiling connection
mockProfilingConnection->WritePacket(std::move(requestCounterDirectoryPacket));
+ // Wait for the Counter Directory packet to be sent
+ helper.WaitForProfilingPacketsSent();
+
+ // Check that the mock profiling connection contains one Counter Directory packet
+ const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
+ BOOST_TEST(writtenData.size() == 1);
+ BOOST_TEST(writtenData[0] == 416); // The size of the expected Counter Directory packet
+
+ // The Request Counter Directory Command Handler should not have updated the profiling state
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+ // Reset the profiling service to stop any running thread
+ options.m_EnableProfiling = false;
+ profilingService.ResetExternalProfilingOptions(options, true);
+}
+
+BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadPeriodicCounterSelectionPacket)
+{
+ // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
+ LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning);
+
+ // Swap the profiling connection factory in the profiling service instance with our mock one
+ SwapProfilingConnectionFactoryHelper helper;
+
+ // Redirect the standard output to a local stream so that we can parse the warning message
+ std::stringstream ss;
+ StreamRedirector streamRedirector(std::cout, ss.rdbuf());
+
+ // Reset the profiling service to the uninitialized state
+ armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
+ options.m_EnableProfiling = true;
+ ProfilingService& profilingService = ProfilingService::Instance();
+ profilingService.ResetExternalProfilingOptions(options, true);
+
+ // Bring the profiling service to the "Active" state
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
+ profilingService.Update(); // Initialize the counter directory
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
+ profilingService.Update(); // Create the profiling connection
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
+ profilingService.Update(); // Start the command handler and the send thread
+
+ // Wait for the Stream Metadata packet the be sent
+ // (we are not testing the connection acknowledgement here so it will be ignored by this test)
+ helper.WaitForProfilingPacketsSent();
+
+ // Force the profiling service to the "Active" state
+ helper.ForceTransitionToState(ProfilingState::Active);
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+ // Get the mock profiling connection
+ MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
+ BOOST_CHECK(mockProfilingConnection);
+
+ // Remove the packets received so far
+ mockProfilingConnection->Clear();
+
+ // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an
+ // external profiling service
+
+ // Periodic Counter Selection packet header:
+ // 26:31 [6] packet_family: Control Packet Family, value 0b000000
+ // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
+ // 8:15 [8] reserved: Reserved, value 0b00000000
+ // 0:7 [8] reserved: Reserved, value 0b00000000
+ uint32_t packetFamily = 0;
+ uint32_t packetId = 999; // Wrong packet id!!!
+ uint32_t header = ((packetFamily & 0x0000003F) << 26) |
+ ((packetId & 0x000003FF) << 16);
+
+ // Create the Periodic Counter Selection packet
+ Packet periodicCounterSelectionPacket(header); // Length == 0, this will disable the collection of counters
+
+ // Write the packet to the mock profiling connection
+ mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
+
// Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that
- // the Create the Request Counter packet gets processed by the profiling service
+ // the Periodic Counter Selection packet gets processed by the profiling service
std::this_thread::sleep_for(std::chrono::seconds(2));
- // The Connection Acknowledged Command Handler should not have updated the profiling state
+ // Check that the expected error has occurred and logged to the standard output
+ BOOST_CHECK(boost::contains(ss.str(), "Functor with requested PacketId=999 and Version=4194304 does not exist"));
+
+ // The Periodic Counter Selection Handler should not have updated the profiling state
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
- // Check that the mock profiling connection contains one Counter Directory packet
+ // Reset the profiling service to stop any running thread
+ options.m_EnableProfiling = false;
+ profilingService.ResetExternalProfilingOptions(options, true);
+}
+
+BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadPeriodicCounterSelectionPacketInvalidCounterUid)
+{
+ // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
+ LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning);
+
+ // Swap the profiling connection factory in the profiling service instance with our mock one
+ SwapProfilingConnectionFactoryHelper helper;
+
+ // Reset the profiling service to the uninitialized state
+ armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
+ options.m_EnableProfiling = true;
+ ProfilingService& profilingService = ProfilingService::Instance();
+ profilingService.ResetExternalProfilingOptions(options, true);
+
+ // Bring the profiling service to the "Active" state
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
+ profilingService.Update(); // Initialize the counter directory
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
+ profilingService.Update(); // Create the profiling connection
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
+ profilingService.Update(); // Start the command handler and the send thread
+
+ // Wait for the Stream Metadata packet the be sent
+ // (we are not testing the connection acknowledgement here so it will be ignored by this test)
+ helper.WaitForProfilingPacketsSent();
+
+ // Force the profiling service to the "Active" state
+ helper.ForceTransitionToState(ProfilingState::Active);
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+ // Get the mock profiling connection
+ MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
+ BOOST_CHECK(mockProfilingConnection);
+
+ // Remove the packets received so far
+ mockProfilingConnection->Clear();
+
+ // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an
+ // external profiling service
+
+ // Periodic Counter Selection packet header:
+ // 26:31 [6] packet_family: Control Packet Family, value 0b000000
+ // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
+ // 8:15 [8] reserved: Reserved, value 0b00000000
+ // 0:7 [8] reserved: Reserved, value 0b00000000
+ uint32_t packetFamily = 0;
+ uint32_t packetId = 4;
+ uint32_t header = ((packetFamily & 0x0000003F) << 26) |
+ ((packetId & 0x000003FF) << 16);
+
+ uint32_t capturePeriod = 123456; // Some capture period (microseconds)
+
+ // Get the first valid counter UID
+ const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory();
+ const Counters& counters = counterDirectory.GetCounters();
+ BOOST_CHECK(counters.size() > 1);
+ uint16_t counterUidA = counters.begin()->first; // First valid counter UID
+ uint16_t counterUidB = 9999; // Second invalid counter UID
+
+ uint32_t length = 8;
+
+ auto data = std::make_unique<unsigned char[]>(length);
+ WriteUint32(data.get(), 0, capturePeriod);
+ WriteUint16(data.get(), 4, counterUidA);
+ WriteUint16(data.get(), 6, counterUidB);
+
+ // Create the Periodic Counter Selection packet
+ Packet periodicCounterSelectionPacket(header, length, data); // Length > 0, this will start the Period Counter
+ // Capture thread
+
+ // Write the packet to the mock profiling connection
+ mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
+
+ // Expecting one Periodic Counter Selection packet and at least one Periodic Counter Capture packet
+ int expectedPackets = 2;
+ std::vector<uint32_t> receivedPackets;
+
+ // Keep waiting until all the expected packets have been received
+ do
+ {
+ helper.WaitForProfilingPacketsSent();
+ const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
+ if (writtenData.empty())
+ {
+ BOOST_ERROR("Packets should be available for reading at this point");
+ return;
+ }
+ receivedPackets.insert(receivedPackets.end(), writtenData.begin(), writtenData.end());
+ expectedPackets -= boost::numeric_cast<int>(writtenData.size());
+ }
+ while (expectedPackets > 0);
+ BOOST_TEST(!receivedPackets.empty());
+
+ // The size of the expected Periodic Counter Selection packet
+ BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 14) != receivedPackets.end()));
+ // The size of the expected Periodic Counter Capture packet
+ BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 22) != receivedPackets.end()));
+
+ // The Periodic Counter Selection Handler should not have updated the profiling state
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+ // Reset the profiling service to stop any running thread
+ options.m_EnableProfiling = false;
+ profilingService.ResetExternalProfilingOptions(options, true);
+}
+
+BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodPeriodicCounterSelectionPacketNoCounters)
+{
+ // Swap the profiling connection factory in the profiling service instance with our mock one
+ SwapProfilingConnectionFactoryHelper helper;
+
+ // Reset the profiling service to the uninitialized state
+ armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
+ options.m_EnableProfiling = true;
+ ProfilingService& profilingService = ProfilingService::Instance();
+ profilingService.ResetExternalProfilingOptions(options, true);
+
+ // Bring the profiling service to the "Active" state
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
+ profilingService.Update(); // Initialize the counter directory
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
+ profilingService.Update(); // Create the profiling connection
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
+ profilingService.Update(); // Start the command handler and the send thread
+
+ // Wait for the Stream Metadata packet the be sent
+ // (we are not testing the connection acknowledgement here so it will be ignored by this test)
+ helper.WaitForProfilingPacketsSent();
+
+ // Force the profiling service to the "Active" state
+ helper.ForceTransitionToState(ProfilingState::Active);
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+ // Get the mock profiling connection
+ MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
+ BOOST_CHECK(mockProfilingConnection);
+
+ // Remove the packets received so far
+ mockProfilingConnection->Clear();
+
+ // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an
+ // external profiling service
+
+ // Periodic Counter Selection packet header:
+ // 26:31 [6] packet_family: Control Packet Family, value 0b000000
+ // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
+ // 8:15 [8] reserved: Reserved, value 0b00000000
+ // 0:7 [8] reserved: Reserved, value 0b00000000
+ uint32_t packetFamily = 0;
+ uint32_t packetId = 4;
+ uint32_t header = ((packetFamily & 0x0000003F) << 26) |
+ ((packetId & 0x000003FF) << 16);
+
+ // Create the Periodic Counter Selection packet
+ Packet periodicCounterSelectionPacket(header); // Length == 0, this will disable the collection of counters
+
+ // Write the packet to the mock profiling connection
+ mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
+
+ // Wait for the Periodic Counter Selection packet to be sent
+ helper.WaitForProfilingPacketsSent();
+
+ // The Periodic Counter Selection Handler should not have updated the profiling state
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+ // Check that the mock profiling connection contains one Periodic Counter Selection
const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
- BOOST_TEST(writtenData.size() == 1);
- BOOST_TEST(writtenData[0] == 416); // The size of a valid Counter Directory packet
+ BOOST_TEST(writtenData.size() == 1); // Only one packet is expected (no Periodic Counter packets)
+ BOOST_TEST(writtenData[0] == 12); // The size of the expected Periodic Counter Selection (echos the sent one)
+
+ // Reset the profiling service to stop any running thread
+ options.m_EnableProfiling = false;
+ profilingService.ResetExternalProfilingOptions(options, true);
+}
+
+BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodPeriodicCounterSelectionPacketSingleCounter)
+{
+ // Swap the profiling connection factory in the profiling service instance with our mock one
+ SwapProfilingConnectionFactoryHelper helper;
+
+ // Reset the profiling service to the uninitialized state
+ armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
+ options.m_EnableProfiling = true;
+ ProfilingService& profilingService = ProfilingService::Instance();
+ profilingService.ResetExternalProfilingOptions(options, true);
+
+ // Bring the profiling service to the "Active" state
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
+ profilingService.Update(); // Initialize the counter directory
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
+ profilingService.Update(); // Create the profiling connection
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
+ profilingService.Update(); // Start the command handler and the send thread
+
+ // Wait for the Stream Metadata packet the be sent
+ // (we are not testing the connection acknowledgement here so it will be ignored by this test)
+ helper.WaitForProfilingPacketsSent();
+
+ // Force the profiling service to the "Active" state
+ helper.ForceTransitionToState(ProfilingState::Active);
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+ // Get the mock profiling connection
+ MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
+ BOOST_CHECK(mockProfilingConnection);
+
+ // Remove the packets received so far
+ mockProfilingConnection->Clear();
+
+ // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an
+ // external profiling service
+
+ // Periodic Counter Selection packet header:
+ // 26:31 [6] packet_family: Control Packet Family, value 0b000000
+ // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
+ // 8:15 [8] reserved: Reserved, value 0b00000000
+ // 0:7 [8] reserved: Reserved, value 0b00000000
+ uint32_t packetFamily = 0;
+ uint32_t packetId = 4;
+ uint32_t header = ((packetFamily & 0x0000003F) << 26) |
+ ((packetId & 0x000003FF) << 16);
+
+ uint32_t capturePeriod = 123456; // Some capture period (microseconds)
+
+ // Get the first valid counter UID
+ const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory();
+ const Counters& counters = counterDirectory.GetCounters();
+ BOOST_CHECK(!counters.empty());
+ uint16_t counterUid = counters.begin()->first; // Valid counter UID
+
+ uint32_t length = 6;
+
+ auto data = std::make_unique<unsigned char[]>(length);
+ WriteUint32(data.get(), 0, capturePeriod);
+ WriteUint16(data.get(), 4, counterUid);
+
+ // Create the Periodic Counter Selection packet
+ Packet periodicCounterSelectionPacket(header, length, data); // Length > 0, this will start the Period Counter
+ // Capture thread
+
+ // Write the packet to the mock profiling connection
+ mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
+
+ // Expecting one Periodic Counter Selection packet and at least one Periodic Counter Capture packet
+ int expectedPackets = 2;
+ std::vector<uint32_t> receivedPackets;
+
+ // Keep waiting until all the expected packets have been received
+ do
+ {
+ helper.WaitForProfilingPacketsSent();
+ const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
+ if (writtenData.empty())
+ {
+ BOOST_ERROR("Packets should be available for reading at this point");
+ return;
+ }
+ receivedPackets.insert(receivedPackets.end(), writtenData.begin(), writtenData.end());
+ expectedPackets -= boost::numeric_cast<int>(writtenData.size());
+ }
+ while (expectedPackets > 0);
+ BOOST_TEST(!receivedPackets.empty());
+
+ // The size of the expected Periodic Counter Selection packet (echos the sent one)
+ BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 14) != receivedPackets.end()));
+ // The size of the expected Periodic Counter Capture packet
+ BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 22) != receivedPackets.end()));
+
+ // The Periodic Counter Selection Handler should not have updated the profiling state
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+ // Reset the profiling service to stop any running thread
+ options.m_EnableProfiling = false;
+ profilingService.ResetExternalProfilingOptions(options, true);
+}
+
+BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodPeriodicCounterSelectionPacketMultipleCounters)
+{
+ // Swap the profiling connection factory in the profiling service instance with our mock one
+ SwapProfilingConnectionFactoryHelper helper;
+
+ // Reset the profiling service to the uninitialized state
+ armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
+ options.m_EnableProfiling = true;
+ ProfilingService& profilingService = ProfilingService::Instance();
+ profilingService.ResetExternalProfilingOptions(options, true);
+
+ // Bring the profiling service to the "Active" state
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
+ profilingService.Update(); // Initialize the counter directory
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
+ profilingService.Update(); // Create the profiling connection
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
+ profilingService.Update(); // Start the command handler and the send thread
+
+ // Wait for the Stream Metadata packet the be sent
+ // (we are not testing the connection acknowledgement here so it will be ignored by this test)
+ helper.WaitForProfilingPacketsSent();
+
+ // Force the profiling service to the "Active" state
+ helper.ForceTransitionToState(ProfilingState::Active);
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+ // Get the mock profiling connection
+ MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
+ BOOST_CHECK(mockProfilingConnection);
+
+ // Remove the packets received so far
+ mockProfilingConnection->Clear();
+
+ // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an
+ // external profiling service
+
+ // Periodic Counter Selection packet header:
+ // 26:31 [6] packet_family: Control Packet Family, value 0b000000
+ // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
+ // 8:15 [8] reserved: Reserved, value 0b00000000
+ // 0:7 [8] reserved: Reserved, value 0b00000000
+ uint32_t packetFamily = 0;
+ uint32_t packetId = 4;
+ uint32_t header = ((packetFamily & 0x0000003F) << 26) |
+ ((packetId & 0x000003FF) << 16);
+
+ uint32_t capturePeriod = 123456; // Some capture period (microseconds)
+
+ // Get the first valid counter UID
+ const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory();
+ const Counters& counters = counterDirectory.GetCounters();
+ BOOST_CHECK(counters.size() > 1);
+ uint16_t counterUidA = counters.begin()->first; // First valid counter UID
+ uint16_t counterUidB = (counters.begin()++)->first; // Second valid counter UID
+
+ uint32_t length = 8;
+
+ auto data = std::make_unique<unsigned char[]>(length);
+ WriteUint32(data.get(), 0, capturePeriod);
+ WriteUint16(data.get(), 4, counterUidA);
+ WriteUint16(data.get(), 6, counterUidB);
+
+ // Create the Periodic Counter Selection packet
+ Packet periodicCounterSelectionPacket(header, length, data); // Length > 0, this will start the Period Counter
+ // Capture thread
+
+ // Write the packet to the mock profiling connection
+ mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
+
+ // Expecting one Periodic Counter Selection packet and at least one Periodic Counter Capture packet
+ int expectedPackets = 2;
+ std::vector<uint32_t> receivedPackets;
+
+ // Keep waiting until all the expected packets have been received
+ do
+ {
+ helper.WaitForProfilingPacketsSent();
+ const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
+ if (writtenData.empty())
+ {
+ BOOST_ERROR("Packets should be available for reading at this point");
+ return;
+ }
+ receivedPackets.insert(receivedPackets.end(), writtenData.begin(), writtenData.end());
+ expectedPackets -= boost::numeric_cast<int>(writtenData.size());
+ }
+ while (expectedPackets > 0);
+ BOOST_TEST(!receivedPackets.empty());
+
+ // The size of the expected Periodic Counter Selection packet (echos the sent one)
+ BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 16) != receivedPackets.end()));
+ // The size of the expected Periodic Counter Capture packet
+ BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 28) != receivedPackets.end()));
+
+ // The Periodic Counter Selection Handler should not have updated the profiling state
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
// Reset the profiling service to stop any running thread
options.m_EnableProfiling = false;
diff --git a/src/profiling/test/ProfilingTests.hpp b/src/profiling/test/ProfilingTests.hpp
index 4d2f974344..21c98723be 100644
--- a/src/profiling/test/ProfilingTests.hpp
+++ b/src/profiling/test/ProfilingTests.hpp
@@ -9,14 +9,12 @@
#include <CommandHandlerFunctor.hpp>
#include <IProfilingConnection.hpp>
-#include <IProfilingConnectionFactory.hpp>
#include <Logging.hpp>
#include <ProfilingService.hpp>
#include <boost/test/unit_test.hpp>
#include <chrono>
-#include <iostream>
#include <thread>
namespace armnn
@@ -137,15 +135,6 @@ class TestFunctorC : public TestFunctorA
using TestFunctorA::TestFunctorA;
};
-class MockProfilingConnectionFactory : public IProfilingConnectionFactory
-{
-public:
- IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override
- {
- return std::make_unique<MockProfilingConnection>();
- }
-};
-
class SwapProfilingConnectionFactoryHelper : public ProfilingService
{
public:
@@ -182,6 +171,11 @@ public:
TransitionToState(ProfilingService::Instance(), newState);
}
+ void WaitForProfilingPacketsSent()
+ {
+ return WaitForPacketSent(ProfilingService::Instance());
+ }
+
private:
MockProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory;
IProfilingConnectionFactory* m_BackupProfilingConnectionFactory;
diff --git a/src/profiling/test/SendCounterPacketTests.hpp b/src/profiling/test/SendCounterPacketTests.hpp
index 871ca74124..73fc39b437 100644
--- a/src/profiling/test/SendCounterPacketTests.hpp
+++ b/src/profiling/test/SendCounterPacketTests.hpp
@@ -7,6 +7,7 @@
#include <SendCounterPacket.hpp>
#include <ProfilingUtils.hpp>
+#include <IProfilingConnectionFactory.hpp>
#include <armnn/Exceptions.hpp>
#include <armnn/Optional.hpp>
@@ -74,11 +75,13 @@ public:
return std::move(m_Packet);
}
- const std::vector<uint32_t> GetWrittenData() const
+ const std::vector<uint32_t> GetWrittenData()
{
std::lock_guard<std::mutex> lock(m_Mutex);
- return m_WrittenData;
+ std::vector<uint32_t> writtenData = m_WrittenData;
+ m_WrittenData.clear();
+ return writtenData;
}
void Clear()
@@ -95,6 +98,15 @@ private:
mutable std::mutex m_Mutex;
};
+class MockProfilingConnectionFactory : public IProfilingConnectionFactory
+{
+public:
+ IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override
+ {
+ return std::make_unique<MockProfilingConnection>();
+ }
+};
+
class MockPacketBuffer : public IPacketBuffer
{
public: