From e848538efbdf01aa0b067da942c3c214f8e62826 Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Thu, 10 Oct 2019 14:08:21 +0100 Subject: IVGCVSW-3964 Implement the Periodic Counter Selection command handler * Improved the PeriodicCounterPacket class to handle errors properly * Improved the PeriodicCounterSelectionCommandHandler to handle invalid counter UIDs in the selection packet * Added the Periodic Counter Selection command handler to the ProfilingService class * Code refactoring and added comments * Added WaitForPacketSent method to the SendCounterPacket class to allow waiting for the packets to be sent (useful in the unit tests) * Added unit tests and updated the old ones accordingly * Fixed threading issues with a number of unit tests Signed-off-by: Matteo Martincigh Change-Id: I271b7b0bfa801d88fe1725b934d24e30cd839ed7 --- .../ConnectionAcknowledgedCommandHandler.cpp | 2 +- src/profiling/Holder.cpp | 12 +- src/profiling/Holder.hpp | 4 +- src/profiling/ICounterValues.hpp | 1 + src/profiling/PeriodicCounterCapture.cpp | 33 +- .../PeriodicCounterSelectionCommandHandler.cpp | 115 ++-- .../PeriodicCounterSelectionCommandHandler.hpp | 31 +- src/profiling/ProfilingService.cpp | 32 +- src/profiling/ProfilingService.hpp | 24 + .../RequestCounterDirectoryCommandHandler.cpp | 2 +- src/profiling/SendCounterPacket.cpp | 18 +- src/profiling/SendCounterPacket.hpp | 12 +- src/profiling/test/ProfilingTests.cpp | 584 +++++++++++++++++++-- src/profiling/test/ProfilingTests.hpp | 16 +- src/profiling/test/SendCounterPacketTests.hpp | 16 +- 15 files changed, 784 insertions(+), 118 deletions(-) (limited to 'src/profiling') diff --git a/src/profiling/ConnectionAcknowledgedCommandHandler.cpp b/src/profiling/ConnectionAcknowledgedCommandHandler.cpp index 9d2d1a2bd2..deffd1414b 100644 --- a/src/profiling/ConnectionAcknowledgedCommandHandler.cpp +++ b/src/profiling/ConnectionAcknowledgedCommandHandler.cpp @@ -22,7 +22,7 @@ void ConnectionAcknowledgedCommandHandler::operator()(const Packet& packet) { case ProfilingState::Uninitialised: case ProfilingState::NotConnected: - throw RuntimeException(boost::str(boost::format("Connection Acknowledged Handler invoked while in an " + throw RuntimeException(boost::str(boost::format("Connection Acknowledged Command Handler invoked while in an " "wrong state: %1%") % GetProfilingStateName(currentState))); case ProfilingState::WaitingForAck: diff --git a/src/profiling/Holder.cpp b/src/profiling/Holder.cpp index 5916017eb6..750be7ec74 100644 --- a/src/profiling/Holder.cpp +++ b/src/profiling/Holder.cpp @@ -11,10 +11,10 @@ namespace armnn namespace profiling { -CaptureData& CaptureData::operator= (const CaptureData& captureData) +CaptureData& CaptureData::operator=(const CaptureData& other) { - m_CapturePeriod = captureData.m_CapturePeriod; - m_CounterIds = captureData.m_CounterIds; + m_CapturePeriod = other.m_CapturePeriod; + m_CounterIds = other.m_CounterIds; return *this; } @@ -29,12 +29,12 @@ void CaptureData::SetCounterIds(const std::vector& counterIds) m_CounterIds = counterIds; } -std::uint32_t CaptureData::GetCapturePeriod() const +uint32_t CaptureData::GetCapturePeriod() const { return m_CapturePeriod; } -std::vector CaptureData::GetCounterIds() const +const std::vector& CaptureData::GetCounterIds() const { return m_CounterIds; } @@ -42,12 +42,14 @@ std::vector CaptureData::GetCounterIds() const CaptureData Holder::GetCaptureData() const { std::lock_guard lockGuard(m_CaptureThreadMutex); + return m_CaptureData; } void Holder::SetCaptureData(uint32_t capturePeriod, const std::vector& counterIds) { std::lock_guard lockGuard(m_CaptureThreadMutex); + m_CaptureData.SetCapturePeriod(capturePeriod); m_CaptureData.SetCounterIds(counterIds); } diff --git a/src/profiling/Holder.hpp b/src/profiling/Holder.hpp index 72ca0914a9..3143105ab4 100644 --- a/src/profiling/Holder.hpp +++ b/src/profiling/Holder.hpp @@ -27,12 +27,12 @@ public: : m_CapturePeriod(captureData.m_CapturePeriod) , m_CounterIds(captureData.m_CounterIds) {} - CaptureData& operator= (const CaptureData& captureData); + CaptureData& operator=(const CaptureData& other); void SetCapturePeriod(uint32_t capturePeriod); void SetCounterIds(const std::vector& counterIds); uint32_t GetCapturePeriod() const; - std::vector GetCounterIds() const; + const std::vector& GetCounterIds() const; private: uint32_t m_CapturePeriod; diff --git a/src/profiling/ICounterValues.hpp b/src/profiling/ICounterValues.hpp index 5e32ca2b37..18e34b6747 100644 --- a/src/profiling/ICounterValues.hpp +++ b/src/profiling/ICounterValues.hpp @@ -18,6 +18,7 @@ class IReadCounterValues public: virtual ~IReadCounterValues() {} + virtual bool IsCounterRegistered(uint16_t counterUid) const = 0; virtual uint16_t GetCounterCount() const = 0; virtual uint32_t GetCounterValue(uint16_t counterUid) const = 0; }; diff --git a/src/profiling/PeriodicCounterCapture.cpp b/src/profiling/PeriodicCounterCapture.cpp index 9002bfc065..0ccb516ae2 100644 --- a/src/profiling/PeriodicCounterCapture.cpp +++ b/src/profiling/PeriodicCounterCapture.cpp @@ -5,6 +5,8 @@ #include "PeriodicCounterCapture.hpp" +#include + namespace armnn { @@ -34,10 +36,13 @@ void PeriodicCounterCapture::Start() void PeriodicCounterCapture::Stop() { + // Signal the capture thread to stop m_KeepRunning.store(false); + // Check that the capture thread is running if (m_PeriodCaptureThread.joinable()) { + // Wait for the capture thread to complete operations m_PeriodCaptureThread.join(); } } @@ -51,10 +56,12 @@ void PeriodicCounterCapture::Capture(const IReadCounterValues& readCounterValues { while (m_KeepRunning.load()) { + // Check if the current capture data indicates that there's data capture auto currentCaptureData = ReadCaptureData(); - std::vector counterIds = currentCaptureData.GetCounterIds(); + const std::vector& counterIds = currentCaptureData.GetCounterIds(); if (currentCaptureData.GetCapturePeriod() == 0 || counterIds.empty()) { + // No data capture, terminate the thread m_KeepRunning.store(false); break; } @@ -63,12 +70,22 @@ void PeriodicCounterCapture::Capture(const IReadCounterValues& readCounterValues auto numCounters = counterIds.size(); values.reserve(numCounters); - // Create vector of pairs of CounterIndexes and Values - uint32_t counterValue = 0; + // Create a vector of pairs of CounterIndexes and Values for (uint16_t index = 0; index < numCounters; ++index) { auto requestedId = counterIds[index]; - counterValue = readCounterValues.GetCounterValue(requestedId); + uint32_t counterValue = 0; + try + { + counterValue = readCounterValues.GetCounterValue(requestedId); + } + catch (const Exception& e) + { + // Report the error and continue + BOOST_LOG_TRIVIAL(warning) << "An error has occurred when getting a counter value: " + << e.what() << std::endl; + continue; + } values.emplace_back(std::make_pair(requestedId, counterValue)); } @@ -81,9 +98,15 @@ void PeriodicCounterCapture::Capture(const IReadCounterValues& readCounterValues // Take a timestamp auto timestamp = clock::now(); + // Write a Periodic Counter Capture packet to the Counter Stream Buffer m_SendCounterPacket.SendPeriodicCounterCapturePacket( static_cast(timestamp.time_since_epoch().count()), values); - std::this_thread::sleep_for(std::chrono::milliseconds(currentCaptureData.GetCapturePeriod())); + + // Notify the Send Thread that new data is available in the Counter Stream Buffer + m_SendCounterPacket.SetReadyToRead(); + + // Wait the indicated capture period (microseconds) + std::this_thread::sleep_for(std::chrono::microseconds(currentCaptureData.GetCapturePeriod())); } m_IsRunning.store(false); diff --git a/src/profiling/PeriodicCounterSelectionCommandHandler.cpp b/src/profiling/PeriodicCounterSelectionCommandHandler.cpp index 9be37fcfd2..db09856dae 100644 --- a/src/profiling/PeriodicCounterSelectionCommandHandler.cpp +++ b/src/profiling/PeriodicCounterSelectionCommandHandler.cpp @@ -7,6 +7,9 @@ #include "ProfilingUtils.hpp" #include +#include + +#include namespace armnn { @@ -14,57 +17,109 @@ namespace armnn namespace profiling { -using namespace std; -using boost::numeric_cast; - void PeriodicCounterSelectionCommandHandler::ParseData(const Packet& packet, CaptureData& captureData) { std::vector counterIds; - uint32_t sizeOfUint32 = numeric_cast(sizeof(uint32_t)); - uint32_t sizeOfUint16 = numeric_cast(sizeof(uint16_t)); + uint32_t sizeOfUint32 = boost::numeric_cast(sizeof(uint32_t)); + uint32_t sizeOfUint16 = boost::numeric_cast(sizeof(uint16_t)); uint32_t offset = 0; - if (packet.GetLength() > 0) + if (packet.GetLength() < 4) { - if (packet.GetLength() >= 4) - { - captureData.SetCapturePeriod(ReadUint32(reinterpret_cast(packet.GetData()), offset)); + // Insufficient packet size + return; + } - unsigned int counters = (packet.GetLength() - 4) / 2; + // Parse the capture period + uint32_t capturePeriod = ReadUint32(packet.GetData(), offset); - if (counters > 0) - { - counterIds.reserve(counters); - offset += sizeOfUint32; - for(unsigned int pos = 0; pos < counters; ++pos) - { - counterIds.emplace_back(ReadUint16(reinterpret_cast(packet.GetData()), - offset)); - offset += sizeOfUint16; - } - } + // Set the capture period + captureData.SetCapturePeriod(capturePeriod); - captureData.SetCounterIds(counterIds); + // Parse the counter ids + unsigned int counters = (packet.GetLength() - 4) / 2; + if (counters > 0) + { + counterIds.reserve(counters); + offset += sizeOfUint32; + for (unsigned int i = 0; i < counters; ++i) + { + // Parse the counter id + uint16_t counterId = ReadUint16(packet.GetData(), offset); + counterIds.emplace_back(counterId); + offset += sizeOfUint16; } } + + // Set the counter ids + captureData.SetCounterIds(counterIds); } void PeriodicCounterSelectionCommandHandler::operator()(const Packet& packet) { - CaptureData captureData; + ProfilingState currentState = m_StateMachine.GetCurrentState(); + switch (currentState) + { + case ProfilingState::Uninitialised: + case ProfilingState::NotConnected: + case ProfilingState::WaitingForAck: + throw RuntimeException(boost::str(boost::format("Periodic Counter Selection Command Handler invoked while in " + "an wrong state: %1%") + % GetProfilingStateName(currentState))); + case ProfilingState::Active: + { + // Process the packet + if (!(packet.GetPacketFamily() == 0u && packet.GetPacketId() == 4u)) + { + throw armnn::InvalidArgumentException(boost::str(boost::format("Expected Packet family = 0, id = 4 but " + "received family = %1%, id = %2%") + % packet.GetPacketFamily() + % packet.GetPacketId())); + } + + // Parse the packet to get the capture period and counter UIDs + CaptureData captureData; + ParseData(packet, captureData); - ParseData(packet, captureData); + // Get the capture data + const uint32_t capturePeriod = captureData.GetCapturePeriod(); + const std::vector& counterIds = captureData.GetCounterIds(); - vector counterIds = captureData.GetCounterIds(); + // Check whether the selected counter UIDs are valid + std::vector validCounterIds; + for (uint16_t counterId : counterIds) + { + // Check whether the counter is registered + if (!m_ReadCounterValues.IsCounterRegistered(counterId)) + { + // Invalid counter UID, ignore it and continue + continue; + } - m_CaptureDataHolder.SetCaptureData(captureData.GetCapturePeriod(), counterIds); + // The counter is valid + validCounterIds.push_back(counterId); + } - m_CaptureThread.Start(); + // Set the capture data with only the valid counter UIDs + m_CaptureDataHolder.SetCaptureData(capturePeriod, validCounterIds); - // Write packet to Counter Stream Buffer - m_SendCounterPacket.SendPeriodicCounterSelectionPacket(captureData.GetCapturePeriod(), captureData.GetCounterIds()); + // Echo back the Periodic Counter Selection packet to the Counter Stream Buffer + m_SendCounterPacket.SendPeriodicCounterSelectionPacket(capturePeriod, validCounterIds); + + // Notify the Send Thread that new data is available in the Counter Stream Buffer + m_SendCounterPacket.SetReadyToRead(); + + // Start the Period Counter Capture thread (if not running already) + m_PeriodicCounterCapture.Start(); + + break; + } + default: + throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1%") + % static_cast(currentState))); + } } } // namespace profiling -} // namespace armnn \ No newline at end of file +} // namespace armnn diff --git a/src/profiling/PeriodicCounterSelectionCommandHandler.hpp b/src/profiling/PeriodicCounterSelectionCommandHandler.hpp index e247e7773f..1da08e3c7a 100644 --- a/src/profiling/PeriodicCounterSelectionCommandHandler.hpp +++ b/src/profiling/PeriodicCounterSelectionCommandHandler.hpp @@ -10,10 +10,7 @@ #include "Holder.hpp" #include "SendCounterPacket.hpp" #include "IPeriodicCounterCapture.hpp" - -#include -#include -#include +#include "ICounterValues.hpp" namespace armnn { @@ -25,22 +22,30 @@ class PeriodicCounterSelectionCommandHandler : public CommandHandlerFunctor { public: - PeriodicCounterSelectionCommandHandler(uint32_t packetId, uint32_t version, Holder& captureDataHolder, - IPeriodicCounterCapture& captureThread, - ISendCounterPacket& sendCounterPacket) - : CommandHandlerFunctor(packetId, version), - m_CaptureDataHolder(captureDataHolder), - m_CaptureThread(captureThread), - m_SendCounterPacket(sendCounterPacket) + PeriodicCounterSelectionCommandHandler(uint32_t packetId, + uint32_t version, + Holder& captureDataHolder, + IPeriodicCounterCapture& periodicCounterCapture, + const IReadCounterValues& readCounterValue, + ISendCounterPacket& sendCounterPacket, + const ProfilingStateMachine& profilingStateMachine) + : CommandHandlerFunctor(packetId, version) + , m_CaptureDataHolder(captureDataHolder) + , m_PeriodicCounterCapture(periodicCounterCapture) + , m_ReadCounterValues(readCounterValue) + , m_SendCounterPacket(sendCounterPacket) + , m_StateMachine(profilingStateMachine) {} void operator()(const Packet& packet) override; - private: Holder& m_CaptureDataHolder; - IPeriodicCounterCapture& m_CaptureThread; + IPeriodicCounterCapture& m_PeriodicCounterCapture; + const IReadCounterValues& m_ReadCounterValues; ISendCounterPacket& m_SendCounterPacket; + const ProfilingStateMachine& m_StateMachine; + void ParseData(const Packet& packet, CaptureData& captureData); }; diff --git a/src/profiling/ProfilingService.cpp b/src/profiling/ProfilingService.cpp index 693f8337db..79184416cd 100644 --- a/src/profiling/ProfilingService.cpp +++ b/src/profiling/ProfilingService.cpp @@ -53,6 +53,9 @@ void ProfilingService::Update() // Stop the send thread (if running) m_SendCounterPacket.Stop(false); + // Stop the periodic counter capture thread (if running) + m_PeriodicCounterCapture.Stop(); + // Reset any existing profiling connection m_ProfilingConnection.reset(); @@ -90,6 +93,9 @@ void ProfilingService::Update() break; case ProfilingState::Active: + // The period counter capture thread is started by the Periodic Counter Selection command handler upon + // request by an external profiling service + break; default: throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1") @@ -112,9 +118,14 @@ uint16_t ProfilingService::GetCounterCount() const return m_CounterDirectory.GetCounterCount(); } +bool ProfilingService::IsCounterRegistered(uint16_t counterUid) const +{ + return counterUid < m_CounterIndex.size(); +} + uint32_t ProfilingService::GetCounterValue(uint16_t counterUid) const { - BOOST_ASSERT(counterUid < m_CounterIndex.size()); + CheckCounterUid(counterUid); std::atomic* counterValuePtr = m_CounterIndex.at(counterUid); BOOST_ASSERT(counterValuePtr); return counterValuePtr->load(std::memory_order::memory_order_relaxed); @@ -122,7 +133,7 @@ uint32_t ProfilingService::GetCounterValue(uint16_t counterUid) const void ProfilingService::SetCounterValue(uint16_t counterUid, uint32_t value) { - BOOST_ASSERT(counterUid < m_CounterIndex.size()); + CheckCounterUid(counterUid); std::atomic* counterValuePtr = m_CounterIndex.at(counterUid); BOOST_ASSERT(counterValuePtr); counterValuePtr->store(value, std::memory_order::memory_order_relaxed); @@ -130,7 +141,7 @@ void ProfilingService::SetCounterValue(uint16_t counterUid, uint32_t value) uint32_t ProfilingService::AddCounterValue(uint16_t counterUid, uint32_t value) { - BOOST_ASSERT(counterUid < m_CounterIndex.size()); + CheckCounterUid(counterUid); std::atomic* counterValuePtr = m_CounterIndex.at(counterUid); BOOST_ASSERT(counterValuePtr); return counterValuePtr->fetch_add(value, std::memory_order::memory_order_relaxed); @@ -138,7 +149,7 @@ uint32_t ProfilingService::AddCounterValue(uint16_t counterUid, uint32_t value) uint32_t ProfilingService::SubtractCounterValue(uint16_t counterUid, uint32_t value) { - BOOST_ASSERT(counterUid < m_CounterIndex.size()); + CheckCounterUid(counterUid); std::atomic* counterValuePtr = m_CounterIndex.at(counterUid); BOOST_ASSERT(counterValuePtr); return counterValuePtr->fetch_sub(value, std::memory_order::memory_order_relaxed); @@ -146,7 +157,7 @@ uint32_t ProfilingService::SubtractCounterValue(uint16_t counterUid, uint32_t va uint32_t ProfilingService::IncrementCounterValue(uint16_t counterUid) { - BOOST_ASSERT(counterUid < m_CounterIndex.size()); + CheckCounterUid(counterUid); std::atomic* counterValuePtr = m_CounterIndex.at(counterUid); BOOST_ASSERT(counterValuePtr); return counterValuePtr->operator++(std::memory_order::memory_order_relaxed); @@ -154,7 +165,7 @@ uint32_t ProfilingService::IncrementCounterValue(uint16_t counterUid) uint32_t ProfilingService::DecrementCounterValue(uint16_t counterUid) { - BOOST_ASSERT(counterUid < m_CounterIndex.size()); + CheckCounterUid(counterUid); std::atomic* counterValuePtr = m_CounterIndex.at(counterUid); BOOST_ASSERT(counterValuePtr); return counterValuePtr->operator--(std::memory_order::memory_order_relaxed); @@ -239,6 +250,7 @@ void ProfilingService::Reset() // First stop the threads (Command Handler first)... m_CommandHandler.Stop(); m_SendCounterPacket.Stop(false); + m_PeriodicCounterCapture.Stop(); // ...then destroy the profiling connection... m_ProfilingConnection.reset(); @@ -252,6 +264,14 @@ void ProfilingService::Reset() m_StateMachine.Reset(); } +inline void ProfilingService::CheckCounterUid(uint16_t counterUid) const +{ + if (!IsCounterRegistered(counterUid)) + { + throw InvalidArgumentException(boost::str(boost::format("Counter UID %1% is not registered") % counterUid)); + } +} + } // namespace profiling } // namespace armnn diff --git a/src/profiling/ProfilingService.hpp b/src/profiling/ProfilingService.hpp index 0e66924267..dd70af4b39 100644 --- a/src/profiling/ProfilingService.hpp +++ b/src/profiling/ProfilingService.hpp @@ -12,8 +12,10 @@ #include "CommandHandler.hpp" #include "BufferManager.hpp" #include "SendCounterPacket.hpp" +#include "PeriodicCounterCapture.hpp" #include "ConnectionAcknowledgedCommandHandler.hpp" #include "RequestCounterDirectoryCommandHandler.hpp" +#include "PeriodicCounterSelectionCommandHandler.hpp" namespace armnn { @@ -46,6 +48,7 @@ public: // Getters for the profiling service state const ICounterDirectory& GetCounterDirectory() const; ProfilingState GetCurrentState() const; + bool IsCounterRegistered(uint16_t counterUid) const override; uint16_t GetCounterCount() const override; uint32_t GetCounterValue(uint16_t counterUid) const override; @@ -68,6 +71,9 @@ private: void InitializeCounterValue(uint16_t counterUid); void Reset(); + // Helper function + void CheckCounterUid(uint16_t counterUid) const; + // Profiling service components ExternalProfilingOptions m_Options; CounterDirectory m_CounterDirectory; @@ -81,8 +87,11 @@ private: CommandHandler m_CommandHandler; BufferManager m_BufferManager; SendCounterPacket m_SendCounterPacket; + Holder m_Holder; + PeriodicCounterCapture m_PeriodicCounterCapture; ConnectionAcknowledgedCommandHandler m_ConnectionAcknowledgedCommandHandler; RequestCounterDirectoryCommandHandler m_RequestCounterDirectoryCommandHandler; + PeriodicCounterSelectionCommandHandler m_PeriodicCounterSelectionCommandHandler; protected: // Default constructor/destructor kept protected for testing @@ -102,6 +111,7 @@ protected: m_PacketVersionResolver) , m_BufferManager() , m_SendCounterPacket(m_StateMachine, m_BufferManager) + , m_PeriodicCounterCapture(m_Holder, m_SendCounterPacket, *this) , m_ConnectionAcknowledgedCommandHandler(1, m_PacketVersionResolver.ResolvePacketVersion(1).GetEncodedValue(), m_StateMachine) @@ -110,12 +120,22 @@ protected: m_CounterDirectory, m_SendCounterPacket, m_StateMachine) + , m_PeriodicCounterSelectionCommandHandler(4, + m_PacketVersionResolver.ResolvePacketVersion(4).GetEncodedValue(), + m_Holder, + m_PeriodicCounterCapture, + *this, + m_SendCounterPacket, + m_StateMachine) { // Register the "Connection Acknowledged" command handler m_CommandHandlerRegistry.RegisterFunctor(&m_ConnectionAcknowledgedCommandHandler); // Register the "Request Counter Directory" command handler m_CommandHandlerRegistry.RegisterFunctor(&m_RequestCounterDirectoryCommandHandler); + + // Register the "Periodic Counter Selection" command handler + m_CommandHandlerRegistry.RegisterFunctor(&m_PeriodicCounterSelectionCommandHandler); } ~ProfilingService() = default; @@ -138,6 +158,10 @@ protected: { instance.m_StateMachine.TransitionToState(newState); } + void WaitForPacketSent(ProfilingService& instance) + { + return instance.m_SendCounterPacket.WaitForPacketSent(); + } }; } // namespace profiling diff --git a/src/profiling/RequestCounterDirectoryCommandHandler.cpp b/src/profiling/RequestCounterDirectoryCommandHandler.cpp index e85acb4215..b8ac9d9426 100644 --- a/src/profiling/RequestCounterDirectoryCommandHandler.cpp +++ b/src/profiling/RequestCounterDirectoryCommandHandler.cpp @@ -21,7 +21,7 @@ void RequestCounterDirectoryCommandHandler::operator()(const Packet& packet) case ProfilingState::Uninitialised: case ProfilingState::NotConnected: case ProfilingState::WaitingForAck: - throw RuntimeException(boost::str(boost::format("Request Counter Directory Handler invoked while in an " + throw RuntimeException(boost::str(boost::format("Request Counter Directory Comand Handler invoked while in an " "wrong state: %1%") % GetProfilingStateName(currentState))); case ProfilingState::Active: diff --git a/src/profiling/SendCounterPacket.cpp b/src/profiling/SendCounterPacket.cpp index e48da3ed7c..41adf37244 100644 --- a/src/profiling/SendCounterPacket.cpp +++ b/src/profiling/SendCounterPacket.cpp @@ -1035,17 +1035,21 @@ void SendCounterPacket::Send(IProfilingConnection& profilingConnection) } // Ensure that all readable data got written to the profiling connection before the thread is stopped - FlushBuffer(profilingConnection); + // (do not notify any watcher in this case, as this is just to wrap up things before shutting down the send thread) + FlushBuffer(profilingConnection, false); // Mark the send thread as not running m_IsRunning.store(false); } -void SendCounterPacket::FlushBuffer(IProfilingConnection& profilingConnection) +void SendCounterPacket::FlushBuffer(IProfilingConnection& profilingConnection, bool notifyWatchers) { // Get the first available readable buffer std::unique_ptr packetBuffer = m_BufferManager.GetReadableBuffer(); + // Initialize the flag that indicates whether at least a packet has been sent + bool packetsSent = false; + while (packetBuffer != nullptr) { // Get the data to send from the buffer @@ -1066,6 +1070,9 @@ void SendCounterPacket::FlushBuffer(IProfilingConnection& profilingConnection) { // Write a packet to the profiling connection. Silently ignore any write error and continue profilingConnection.WritePacket(readBuffer, boost::numeric_cast(readBufferSize)); + + // Set the flag that indicates whether at least a packet has been sent + packetsSent = true; } // Mark the packet buffer as read @@ -1074,6 +1081,13 @@ void SendCounterPacket::FlushBuffer(IProfilingConnection& profilingConnection) // Get the next available readable buffer packetBuffer = m_BufferManager.GetReadableBuffer(); } + + // Check whether at least a packet has been sent + if (packetsSent && notifyWatchers) + { + // Notify to any watcher that something has been sent + m_PacketSentWaitCondition.notify_one(); + } } } // namespace profiling diff --git a/src/profiling/SendCounterPacket.hpp b/src/profiling/SendCounterPacket.hpp index 9361efbc74..e1a42aa496 100644 --- a/src/profiling/SendCounterPacket.hpp +++ b/src/profiling/SendCounterPacket.hpp @@ -65,6 +65,14 @@ public: void Stop(bool rethrowSendThreadExceptions = true); bool IsRunning() { return m_IsRunning.load(); } + void WaitForPacketSent() + { + std::unique_lock lock(m_PacketSentWaitMutex); + + // Blocks until notified that at least a packet has been sent + m_PacketSentWaitCondition.wait(lock); + } + private: void Send(IProfilingConnection& profilingConnection); @@ -93,7 +101,7 @@ private: throw ExceptionType(errorMessage); } - void FlushBuffer(IProfilingConnection& profilingConnection); + void FlushBuffer(IProfilingConnection& profilingConnection, bool notifyWatchers = true); ProfilingStateMachine& m_StateMachine; IBufferManager& m_BufferManager; @@ -104,6 +112,8 @@ private: std::atomic m_IsRunning; std::atomic m_KeepRunning; std::exception_ptr m_SendThreadException; + std::mutex m_PacketSentWaitMutex; + std::condition_variable m_PacketSentWaitCondition; protected: // Helper methods, protected for testing diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index 27bacf7145..554b7e1936 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -35,6 +35,7 @@ #include #include #include +#include using namespace armnn::profiling; @@ -1691,11 +1692,19 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData) void Stop() override {} }; + class TestReadCounterValues : public IReadCounterValues + { + bool IsCounterRegistered(uint16_t counterUid) const override { return true; } + uint16_t GetCounterCount() const override { return 0; } + uint32_t GetCounterValue(uint16_t counterUid) const override { return 0; } + }; + const uint32_t packetId = 0x40000; uint32_t version = 1; Holder holder; TestCaptureThread captureThread; + TestReadCounterValues readCounterValues; MockBufferManager mockBuffer(512); SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); @@ -1718,16 +1727,29 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData) Packet packetA(packetId, dataLength1, uniqueData1); - PeriodicCounterSelectionCommandHandler commandHandler(packetId, version, holder, captureThread, - sendCounterPacket); - commandHandler(packetA); + PeriodicCounterSelectionCommandHandler commandHandler(packetId, + version, + holder, + captureThread, + readCounterValues, + sendCounterPacket, + profilingStateMachine); - std::vector counterIds = holder.GetCaptureData().GetCounterIds(); + profilingStateMachine.TransitionToState(ProfilingState::Uninitialised); + BOOST_CHECK_THROW(commandHandler(packetA), armnn::RuntimeException); + profilingStateMachine.TransitionToState(ProfilingState::NotConnected); + BOOST_CHECK_THROW(commandHandler(packetA), armnn::RuntimeException); + profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck); + BOOST_CHECK_THROW(commandHandler(packetA), armnn::RuntimeException); + profilingStateMachine.TransitionToState(ProfilingState::Active); + BOOST_CHECK_NO_THROW(commandHandler(packetA)); + + const std::vector counterIdsA = holder.GetCaptureData().GetCounterIds(); BOOST_TEST(holder.GetCaptureData().GetCapturePeriod() == period1); - BOOST_TEST(counterIds.size() == 2); - BOOST_TEST(counterIds[0] == 4000); - BOOST_TEST(counterIds[1] == 5000); + BOOST_TEST(counterIdsA.size() == 2); + BOOST_TEST(counterIdsA[0] == 4000); + BOOST_TEST(counterIdsA[1] == 5000); auto readBuffer = mockBuffer.GetReadableBuffer(); @@ -1766,10 +1788,10 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData) commandHandler(packetB); - counterIds = holder.GetCaptureData().GetCounterIds(); + const std::vector counterIdsB = holder.GetCaptureData().GetCounterIds(); BOOST_TEST(holder.GetCaptureData().GetCapturePeriod() == period2); - BOOST_TEST(counterIds.size() == 0); + BOOST_TEST(counterIdsB.size() == 0); readBuffer = mockBuffer.GetReadableBuffer(); @@ -2024,35 +2046,40 @@ BOOST_AUTO_TEST_CASE(CheckPeriodicCounterCaptureThread) public: CaptureReader() {} + bool IsCounterRegistered(uint16_t counterUid) const override + { + return m_Data.find(counterUid) != m_Data.end(); + } + uint16_t GetCounterCount() const override { return boost::numeric_cast(m_Data.size()); } - uint32_t GetCounterValue(uint16_t index) const override + uint32_t GetCounterValue(uint16_t counterUid) const override { - if (m_Data.find(index) == m_Data.end()) + if (m_Data.find(counterUid) == m_Data.end()) { return 0; } - return m_Data.at(index); + return m_Data.at(counterUid).load(); } - void SetCounterValue(uint16_t index, uint32_t value) + void SetCounterValue(uint16_t counterUid, uint32_t value) { - if (m_Data.find(index) == m_Data.end()) + if (m_Data.find(counterUid) == m_Data.end()) { - m_Data.insert(std::pair(index, value)); + m_Data.insert(std::make_pair(counterUid, value)); } else { - m_Data.at(index) = value; + m_Data.at(counterUid).store(value); } } private: - std::unordered_map m_Data; + std::unordered_map> m_Data; }; ProfilingStateMachine profilingStateMachine; @@ -2261,19 +2288,23 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadConnectionAcknowledgedPacket) // Bring the profiling service to the "WaitingForAck" state BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); - profilingService.Update(); + profilingService.Update(); // Initialize the counter directory BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); - profilingService.Update(); - BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); - profilingService.Update(); - - // Wait for a bit to make sure that we get the packet - std::this_thread::sleep_for(std::chrono::milliseconds(100)); + profilingService.Update();// Create the profiling connection // Get the mock profiling connection MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection(); BOOST_CHECK(mockProfilingConnection); + // Remove the packets received so far + mockProfilingConnection->Clear(); + + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); + profilingService.Update(); + + // Wait for the Stream Metadata packet to be sent + helper.WaitForProfilingPacketsSent(); + // Check that the mock profiling connection contains one Stream Metadata packet const std::vector writtenData = mockProfilingConnection->GetWrittenData(); BOOST_TEST(writtenData.size() == 1); @@ -2330,19 +2361,23 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket) // Bring the profiling service to the "WaitingForAck" state BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); - profilingService.Update(); + profilingService.Update(); // Initialize the counter directory BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); - profilingService.Update(); - BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); - profilingService.Update(); - - // Wait for a bit to make sure that we get the packet - std::this_thread::sleep_for(std::chrono::milliseconds(100)); + profilingService.Update(); // Create the profiling connection // Get the mock profiling connection MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection(); BOOST_CHECK(mockProfilingConnection); + // Remove the packets received so far + mockProfilingConnection->Clear(); + + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); + profilingService.Update(); // Start the command handler and the send thread + + // Wait for the Stream Metadata packet to be sent + helper.WaitForProfilingPacketsSent(); + // Check that the mock profiling connection contains one Stream Metadata packet const std::vector writtenData = mockProfilingConnection->GetWrittenData(); BOOST_TEST(writtenData.size() == 1); @@ -2403,7 +2438,13 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadRequestCounterDirectoryPacket) BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); profilingService.Update(); // Create the profiling connection BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); - profilingService.Update(); // Start the threads + profilingService.Update(); // Start the command handler and the send thread + + // Wait for the Stream Metadata packet the be sent + // (we are not testing the connection acknowledgement here so it will be ignored by this test) + helper.WaitForProfilingPacketsSent(); + + // Force the profiling service to the "Active" state helper.ForceTransitionToState(ProfilingState::Active); BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); @@ -2411,6 +2452,9 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadRequestCounterDirectoryPacket) MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection(); BOOST_CHECK(mockProfilingConnection); + // Remove the packets received so far + mockProfilingConnection->Clear(); + // Write a valid "Request Counter Directory" packet into the mock profiling connection, to simulate a valid // reply from an external profiling service @@ -2437,7 +2481,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadRequestCounterDirectoryPacket) // Check that the expected error has occurred and logged to the standard output BOOST_CHECK(boost::contains(ss.str(), "Functor with requested PacketId=123 and Version=4194304 does not exist")); - // The Connection Acknowledged Command Handler should not have updated the profiling state + // The Request Counter Directory Command Handler should not have updated the profiling state BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); // Reset the profiling service to stop any running thread @@ -2462,7 +2506,13 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodRequestCounterDirectoryPacket) BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); profilingService.Update(); // Create the profiling connection BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); - profilingService.Update(); // Start the threads + profilingService.Update(); // Start the command handler and the send thread + + // Wait for the Stream Metadata packet the be sent + // (we are not testing the connection acknowledgement here so it will be ignored by this test) + helper.WaitForProfilingPacketsSent(); + + // Force the profiling service to the "Active" state helper.ForceTransitionToState(ProfilingState::Active); BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); @@ -2470,6 +2520,9 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodRequestCounterDirectoryPacket) MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection(); BOOST_CHECK(mockProfilingConnection); + // Remove the packets received so far + mockProfilingConnection->Clear(); + // Write a valid "Request Counter Directory" packet into the mock profiling connection, to simulate a valid // reply from an external profiling service @@ -2489,17 +2542,470 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodRequestCounterDirectoryPacket) // Write the packet to the mock profiling connection mockProfilingConnection->WritePacket(std::move(requestCounterDirectoryPacket)); + // Wait for the Counter Directory packet to be sent + helper.WaitForProfilingPacketsSent(); + + // Check that the mock profiling connection contains one Counter Directory packet + const std::vector writtenData = mockProfilingConnection->GetWrittenData(); + BOOST_TEST(writtenData.size() == 1); + BOOST_TEST(writtenData[0] == 416); // The size of the expected Counter Directory packet + + // The Request Counter Directory Command Handler should not have updated the profiling state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); + + // Reset the profiling service to stop any running thread + options.m_EnableProfiling = false; + profilingService.ResetExternalProfilingOptions(options, true); +} + +BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadPeriodicCounterSelectionPacket) +{ + // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output + LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning); + + // Swap the profiling connection factory in the profiling service instance with our mock one + SwapProfilingConnectionFactoryHelper helper; + + // Redirect the standard output to a local stream so that we can parse the warning message + std::stringstream ss; + StreamRedirector streamRedirector(std::cout, ss.rdbuf()); + + // Reset the profiling service to the uninitialized state + armnn::Runtime::CreationOptions::ExternalProfilingOptions options; + options.m_EnableProfiling = true; + ProfilingService& profilingService = ProfilingService::Instance(); + profilingService.ResetExternalProfilingOptions(options, true); + + // Bring the profiling service to the "Active" state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); + profilingService.Update(); // Initialize the counter directory + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); + profilingService.Update(); // Create the profiling connection + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); + profilingService.Update(); // Start the command handler and the send thread + + // Wait for the Stream Metadata packet the be sent + // (we are not testing the connection acknowledgement here so it will be ignored by this test) + helper.WaitForProfilingPacketsSent(); + + // Force the profiling service to the "Active" state + helper.ForceTransitionToState(ProfilingState::Active); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); + + // Get the mock profiling connection + MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection(); + BOOST_CHECK(mockProfilingConnection); + + // Remove the packets received so far + mockProfilingConnection->Clear(); + + // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an + // external profiling service + + // Periodic Counter Selection packet header: + // 26:31 [6] packet_family: Control Packet Family, value 0b000000 + // 16:25 [10] packet_id: Packet identifier, value 0b0000000100 + // 8:15 [8] reserved: Reserved, value 0b00000000 + // 0:7 [8] reserved: Reserved, value 0b00000000 + uint32_t packetFamily = 0; + uint32_t packetId = 999; // Wrong packet id!!! + uint32_t header = ((packetFamily & 0x0000003F) << 26) | + ((packetId & 0x000003FF) << 16); + + // Create the Periodic Counter Selection packet + Packet periodicCounterSelectionPacket(header); // Length == 0, this will disable the collection of counters + + // Write the packet to the mock profiling connection + mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket)); + // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that - // the Create the Request Counter packet gets processed by the profiling service + // the Periodic Counter Selection packet gets processed by the profiling service std::this_thread::sleep_for(std::chrono::seconds(2)); - // The Connection Acknowledged Command Handler should not have updated the profiling state + // Check that the expected error has occurred and logged to the standard output + BOOST_CHECK(boost::contains(ss.str(), "Functor with requested PacketId=999 and Version=4194304 does not exist")); + + // The Periodic Counter Selection Handler should not have updated the profiling state BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); - // Check that the mock profiling connection contains one Counter Directory packet + // Reset the profiling service to stop any running thread + options.m_EnableProfiling = false; + profilingService.ResetExternalProfilingOptions(options, true); +} + +BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadPeriodicCounterSelectionPacketInvalidCounterUid) +{ + // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output + LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning); + + // Swap the profiling connection factory in the profiling service instance with our mock one + SwapProfilingConnectionFactoryHelper helper; + + // Reset the profiling service to the uninitialized state + armnn::Runtime::CreationOptions::ExternalProfilingOptions options; + options.m_EnableProfiling = true; + ProfilingService& profilingService = ProfilingService::Instance(); + profilingService.ResetExternalProfilingOptions(options, true); + + // Bring the profiling service to the "Active" state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); + profilingService.Update(); // Initialize the counter directory + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); + profilingService.Update(); // Create the profiling connection + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); + profilingService.Update(); // Start the command handler and the send thread + + // Wait for the Stream Metadata packet the be sent + // (we are not testing the connection acknowledgement here so it will be ignored by this test) + helper.WaitForProfilingPacketsSent(); + + // Force the profiling service to the "Active" state + helper.ForceTransitionToState(ProfilingState::Active); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); + + // Get the mock profiling connection + MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection(); + BOOST_CHECK(mockProfilingConnection); + + // Remove the packets received so far + mockProfilingConnection->Clear(); + + // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an + // external profiling service + + // Periodic Counter Selection packet header: + // 26:31 [6] packet_family: Control Packet Family, value 0b000000 + // 16:25 [10] packet_id: Packet identifier, value 0b0000000100 + // 8:15 [8] reserved: Reserved, value 0b00000000 + // 0:7 [8] reserved: Reserved, value 0b00000000 + uint32_t packetFamily = 0; + uint32_t packetId = 4; + uint32_t header = ((packetFamily & 0x0000003F) << 26) | + ((packetId & 0x000003FF) << 16); + + uint32_t capturePeriod = 123456; // Some capture period (microseconds) + + // Get the first valid counter UID + const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory(); + const Counters& counters = counterDirectory.GetCounters(); + BOOST_CHECK(counters.size() > 1); + uint16_t counterUidA = counters.begin()->first; // First valid counter UID + uint16_t counterUidB = 9999; // Second invalid counter UID + + uint32_t length = 8; + + auto data = std::make_unique(length); + WriteUint32(data.get(), 0, capturePeriod); + WriteUint16(data.get(), 4, counterUidA); + WriteUint16(data.get(), 6, counterUidB); + + // Create the Periodic Counter Selection packet + Packet periodicCounterSelectionPacket(header, length, data); // Length > 0, this will start the Period Counter + // Capture thread + + // Write the packet to the mock profiling connection + mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket)); + + // Expecting one Periodic Counter Selection packet and at least one Periodic Counter Capture packet + int expectedPackets = 2; + std::vector receivedPackets; + + // Keep waiting until all the expected packets have been received + do + { + helper.WaitForProfilingPacketsSent(); + const std::vector writtenData = mockProfilingConnection->GetWrittenData(); + if (writtenData.empty()) + { + BOOST_ERROR("Packets should be available for reading at this point"); + return; + } + receivedPackets.insert(receivedPackets.end(), writtenData.begin(), writtenData.end()); + expectedPackets -= boost::numeric_cast(writtenData.size()); + } + while (expectedPackets > 0); + BOOST_TEST(!receivedPackets.empty()); + + // The size of the expected Periodic Counter Selection packet + BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 14) != receivedPackets.end())); + // The size of the expected Periodic Counter Capture packet + BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 22) != receivedPackets.end())); + + // The Periodic Counter Selection Handler should not have updated the profiling state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); + + // Reset the profiling service to stop any running thread + options.m_EnableProfiling = false; + profilingService.ResetExternalProfilingOptions(options, true); +} + +BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodPeriodicCounterSelectionPacketNoCounters) +{ + // Swap the profiling connection factory in the profiling service instance with our mock one + SwapProfilingConnectionFactoryHelper helper; + + // Reset the profiling service to the uninitialized state + armnn::Runtime::CreationOptions::ExternalProfilingOptions options; + options.m_EnableProfiling = true; + ProfilingService& profilingService = ProfilingService::Instance(); + profilingService.ResetExternalProfilingOptions(options, true); + + // Bring the profiling service to the "Active" state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); + profilingService.Update(); // Initialize the counter directory + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); + profilingService.Update(); // Create the profiling connection + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); + profilingService.Update(); // Start the command handler and the send thread + + // Wait for the Stream Metadata packet the be sent + // (we are not testing the connection acknowledgement here so it will be ignored by this test) + helper.WaitForProfilingPacketsSent(); + + // Force the profiling service to the "Active" state + helper.ForceTransitionToState(ProfilingState::Active); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); + + // Get the mock profiling connection + MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection(); + BOOST_CHECK(mockProfilingConnection); + + // Remove the packets received so far + mockProfilingConnection->Clear(); + + // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an + // external profiling service + + // Periodic Counter Selection packet header: + // 26:31 [6] packet_family: Control Packet Family, value 0b000000 + // 16:25 [10] packet_id: Packet identifier, value 0b0000000100 + // 8:15 [8] reserved: Reserved, value 0b00000000 + // 0:7 [8] reserved: Reserved, value 0b00000000 + uint32_t packetFamily = 0; + uint32_t packetId = 4; + uint32_t header = ((packetFamily & 0x0000003F) << 26) | + ((packetId & 0x000003FF) << 16); + + // Create the Periodic Counter Selection packet + Packet periodicCounterSelectionPacket(header); // Length == 0, this will disable the collection of counters + + // Write the packet to the mock profiling connection + mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket)); + + // Wait for the Periodic Counter Selection packet to be sent + helper.WaitForProfilingPacketsSent(); + + // The Periodic Counter Selection Handler should not have updated the profiling state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); + + // Check that the mock profiling connection contains one Periodic Counter Selection const std::vector writtenData = mockProfilingConnection->GetWrittenData(); - BOOST_TEST(writtenData.size() == 1); - BOOST_TEST(writtenData[0] == 416); // The size of a valid Counter Directory packet + BOOST_TEST(writtenData.size() == 1); // Only one packet is expected (no Periodic Counter packets) + BOOST_TEST(writtenData[0] == 12); // The size of the expected Periodic Counter Selection (echos the sent one) + + // Reset the profiling service to stop any running thread + options.m_EnableProfiling = false; + profilingService.ResetExternalProfilingOptions(options, true); +} + +BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodPeriodicCounterSelectionPacketSingleCounter) +{ + // Swap the profiling connection factory in the profiling service instance with our mock one + SwapProfilingConnectionFactoryHelper helper; + + // Reset the profiling service to the uninitialized state + armnn::Runtime::CreationOptions::ExternalProfilingOptions options; + options.m_EnableProfiling = true; + ProfilingService& profilingService = ProfilingService::Instance(); + profilingService.ResetExternalProfilingOptions(options, true); + + // Bring the profiling service to the "Active" state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); + profilingService.Update(); // Initialize the counter directory + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); + profilingService.Update(); // Create the profiling connection + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); + profilingService.Update(); // Start the command handler and the send thread + + // Wait for the Stream Metadata packet the be sent + // (we are not testing the connection acknowledgement here so it will be ignored by this test) + helper.WaitForProfilingPacketsSent(); + + // Force the profiling service to the "Active" state + helper.ForceTransitionToState(ProfilingState::Active); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); + + // Get the mock profiling connection + MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection(); + BOOST_CHECK(mockProfilingConnection); + + // Remove the packets received so far + mockProfilingConnection->Clear(); + + // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an + // external profiling service + + // Periodic Counter Selection packet header: + // 26:31 [6] packet_family: Control Packet Family, value 0b000000 + // 16:25 [10] packet_id: Packet identifier, value 0b0000000100 + // 8:15 [8] reserved: Reserved, value 0b00000000 + // 0:7 [8] reserved: Reserved, value 0b00000000 + uint32_t packetFamily = 0; + uint32_t packetId = 4; + uint32_t header = ((packetFamily & 0x0000003F) << 26) | + ((packetId & 0x000003FF) << 16); + + uint32_t capturePeriod = 123456; // Some capture period (microseconds) + + // Get the first valid counter UID + const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory(); + const Counters& counters = counterDirectory.GetCounters(); + BOOST_CHECK(!counters.empty()); + uint16_t counterUid = counters.begin()->first; // Valid counter UID + + uint32_t length = 6; + + auto data = std::make_unique(length); + WriteUint32(data.get(), 0, capturePeriod); + WriteUint16(data.get(), 4, counterUid); + + // Create the Periodic Counter Selection packet + Packet periodicCounterSelectionPacket(header, length, data); // Length > 0, this will start the Period Counter + // Capture thread + + // Write the packet to the mock profiling connection + mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket)); + + // Expecting one Periodic Counter Selection packet and at least one Periodic Counter Capture packet + int expectedPackets = 2; + std::vector receivedPackets; + + // Keep waiting until all the expected packets have been received + do + { + helper.WaitForProfilingPacketsSent(); + const std::vector writtenData = mockProfilingConnection->GetWrittenData(); + if (writtenData.empty()) + { + BOOST_ERROR("Packets should be available for reading at this point"); + return; + } + receivedPackets.insert(receivedPackets.end(), writtenData.begin(), writtenData.end()); + expectedPackets -= boost::numeric_cast(writtenData.size()); + } + while (expectedPackets > 0); + BOOST_TEST(!receivedPackets.empty()); + + // The size of the expected Periodic Counter Selection packet (echos the sent one) + BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 14) != receivedPackets.end())); + // The size of the expected Periodic Counter Capture packet + BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 22) != receivedPackets.end())); + + // The Periodic Counter Selection Handler should not have updated the profiling state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); + + // Reset the profiling service to stop any running thread + options.m_EnableProfiling = false; + profilingService.ResetExternalProfilingOptions(options, true); +} + +BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodPeriodicCounterSelectionPacketMultipleCounters) +{ + // Swap the profiling connection factory in the profiling service instance with our mock one + SwapProfilingConnectionFactoryHelper helper; + + // Reset the profiling service to the uninitialized state + armnn::Runtime::CreationOptions::ExternalProfilingOptions options; + options.m_EnableProfiling = true; + ProfilingService& profilingService = ProfilingService::Instance(); + profilingService.ResetExternalProfilingOptions(options, true); + + // Bring the profiling service to the "Active" state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); + profilingService.Update(); // Initialize the counter directory + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); + profilingService.Update(); // Create the profiling connection + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); + profilingService.Update(); // Start the command handler and the send thread + + // Wait for the Stream Metadata packet the be sent + // (we are not testing the connection acknowledgement here so it will be ignored by this test) + helper.WaitForProfilingPacketsSent(); + + // Force the profiling service to the "Active" state + helper.ForceTransitionToState(ProfilingState::Active); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); + + // Get the mock profiling connection + MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection(); + BOOST_CHECK(mockProfilingConnection); + + // Remove the packets received so far + mockProfilingConnection->Clear(); + + // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an + // external profiling service + + // Periodic Counter Selection packet header: + // 26:31 [6] packet_family: Control Packet Family, value 0b000000 + // 16:25 [10] packet_id: Packet identifier, value 0b0000000100 + // 8:15 [8] reserved: Reserved, value 0b00000000 + // 0:7 [8] reserved: Reserved, value 0b00000000 + uint32_t packetFamily = 0; + uint32_t packetId = 4; + uint32_t header = ((packetFamily & 0x0000003F) << 26) | + ((packetId & 0x000003FF) << 16); + + uint32_t capturePeriod = 123456; // Some capture period (microseconds) + + // Get the first valid counter UID + const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory(); + const Counters& counters = counterDirectory.GetCounters(); + BOOST_CHECK(counters.size() > 1); + uint16_t counterUidA = counters.begin()->first; // First valid counter UID + uint16_t counterUidB = (counters.begin()++)->first; // Second valid counter UID + + uint32_t length = 8; + + auto data = std::make_unique(length); + WriteUint32(data.get(), 0, capturePeriod); + WriteUint16(data.get(), 4, counterUidA); + WriteUint16(data.get(), 6, counterUidB); + + // Create the Periodic Counter Selection packet + Packet periodicCounterSelectionPacket(header, length, data); // Length > 0, this will start the Period Counter + // Capture thread + + // Write the packet to the mock profiling connection + mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket)); + + // Expecting one Periodic Counter Selection packet and at least one Periodic Counter Capture packet + int expectedPackets = 2; + std::vector receivedPackets; + + // Keep waiting until all the expected packets have been received + do + { + helper.WaitForProfilingPacketsSent(); + const std::vector writtenData = mockProfilingConnection->GetWrittenData(); + if (writtenData.empty()) + { + BOOST_ERROR("Packets should be available for reading at this point"); + return; + } + receivedPackets.insert(receivedPackets.end(), writtenData.begin(), writtenData.end()); + expectedPackets -= boost::numeric_cast(writtenData.size()); + } + while (expectedPackets > 0); + BOOST_TEST(!receivedPackets.empty()); + + // The size of the expected Periodic Counter Selection packet (echos the sent one) + BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 16) != receivedPackets.end())); + // The size of the expected Periodic Counter Capture packet + BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 28) != receivedPackets.end())); + + // The Periodic Counter Selection Handler should not have updated the profiling state + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); // Reset the profiling service to stop any running thread options.m_EnableProfiling = false; diff --git a/src/profiling/test/ProfilingTests.hpp b/src/profiling/test/ProfilingTests.hpp index 4d2f974344..21c98723be 100644 --- a/src/profiling/test/ProfilingTests.hpp +++ b/src/profiling/test/ProfilingTests.hpp @@ -9,14 +9,12 @@ #include #include -#include #include #include #include #include -#include #include namespace armnn @@ -137,15 +135,6 @@ class TestFunctorC : public TestFunctorA using TestFunctorA::TestFunctorA; }; -class MockProfilingConnectionFactory : public IProfilingConnectionFactory -{ -public: - IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override - { - return std::make_unique(); - } -}; - class SwapProfilingConnectionFactoryHelper : public ProfilingService { public: @@ -182,6 +171,11 @@ public: TransitionToState(ProfilingService::Instance(), newState); } + void WaitForProfilingPacketsSent() + { + return WaitForPacketSent(ProfilingService::Instance()); + } + private: MockProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory; IProfilingConnectionFactory* m_BackupProfilingConnectionFactory; diff --git a/src/profiling/test/SendCounterPacketTests.hpp b/src/profiling/test/SendCounterPacketTests.hpp index 871ca74124..73fc39b437 100644 --- a/src/profiling/test/SendCounterPacketTests.hpp +++ b/src/profiling/test/SendCounterPacketTests.hpp @@ -7,6 +7,7 @@ #include #include +#include #include #include @@ -74,11 +75,13 @@ public: return std::move(m_Packet); } - const std::vector GetWrittenData() const + const std::vector GetWrittenData() { std::lock_guard lock(m_Mutex); - return m_WrittenData; + std::vector writtenData = m_WrittenData; + m_WrittenData.clear(); + return writtenData; } void Clear() @@ -95,6 +98,15 @@ private: mutable std::mutex m_Mutex; }; +class MockProfilingConnectionFactory : public IProfilingConnectionFactory +{ +public: + IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override + { + return std::make_unique(); + } +}; + class MockPacketBuffer : public IPacketBuffer { public: -- cgit v1.2.1