From 3896b47a3532aadcde43a3e7fed760a0f4a29e6b Mon Sep 17 00:00:00 2001 From: Sadik Armagan Date: Mon, 10 Feb 2020 12:24:15 +0000 Subject: IVGCVSW-4328 BufferManager running out of buffers crashes application * Refactored SendCounterPacket classes, separated SendCounterPacket from Send thread * Created ISendThread.hpp, IConsumer, SendThread.hpp and SendThread.cpp * Injected IConsumer to BufferManager to notify SendThread when packet is ready to read Signed-off-by: Sadik Armagan Change-Id: I80f0bb8b8401c6bfd1611f7760217c6fe35d7ad8 --- Android.mk | 1 + CMakeLists.txt | 4 + src/profiling/BufferManager.cpp | 21 +- src/profiling/BufferManager.hpp | 13 +- .../ConnectionAcknowledgedCommandHandler.cpp | 3 - src/profiling/IBufferManager.hpp | 7 +- src/profiling/IConsumer.hpp | 26 ++ src/profiling/ISendCounterPacket.hpp | 4 - src/profiling/ISendThread.hpp | 31 ++ src/profiling/PeriodicCounterCapture.cpp | 3 - .../PeriodicCounterSelectionCommandHandler.cpp | 3 - .../PeriodicCounterSelectionCommandHandler.hpp | 1 + src/profiling/ProfilingService.cpp | 6 +- src/profiling/ProfilingService.hpp | 7 +- .../RequestCounterDirectoryCommandHandler.cpp | 3 - src/profiling/SendCounterPacket.cpp | 238 +------------ src/profiling/SendCounterPacket.hpp | 47 +-- src/profiling/SendThread.cpp | 278 +++++++++++++++ src/profiling/SendThread.hpp | 75 ++++ src/profiling/test/ProfilingTests.cpp | 19 +- src/profiling/test/SendCounterPacketTests.cpp | 384 +++++++-------------- src/profiling/test/SendCounterPacketTests.hpp | 66 +++- 22 files changed, 667 insertions(+), 573 deletions(-) create mode 100644 src/profiling/IConsumer.hpp create mode 100644 src/profiling/ISendThread.hpp create mode 100644 src/profiling/SendThread.cpp create mode 100644 src/profiling/SendThread.hpp diff --git a/Android.mk b/Android.mk index ba7f3b52cf..82a96cda53 100644 --- a/Android.mk +++ b/Android.mk @@ -202,6 +202,7 @@ LOCAL_SRC_FILES := \ src/profiling/RegisterBackendCounters.cpp \ src/profiling/RequestCounterDirectoryCommandHandler.cpp \ src/profiling/SendCounterPacket.cpp \ + src/profiling/SendThread.cpp \ src/profiling/SendTimelinePacket.cpp \ src/profiling/SocketProfilingConnection.cpp \ src/profiling/TimelinePacketWriterFactory.cpp \ diff --git a/CMakeLists.txt b/CMakeLists.txt index 763c07e275..d534c0a6d5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -469,10 +469,12 @@ list(APPEND armnn_sources src/profiling/Holder.cpp src/profiling/Holder.hpp src/profiling/IBufferManager.hpp + src/profiling/IConsumer.hpp src/profiling/ICounterDirectory.hpp src/profiling/ICounterRegistry.hpp src/profiling/ICounterValues.hpp src/profiling/ISendCounterPacket.hpp + src/profiling/ISendThread.hpp src/profiling/IPacketBuffer.hpp src/profiling/IPeriodicCounterCapture.hpp src/profiling/IProfilingConnection.hpp @@ -507,6 +509,8 @@ list(APPEND armnn_sources src/profiling/RequestCounterDirectoryCommandHandler.hpp src/profiling/SendCounterPacket.cpp src/profiling/SendCounterPacket.hpp + src/profiling/SendThread.cpp + src/profiling/SendThread.hpp src/profiling/SendTimelinePacket.cpp src/profiling/SendTimelinePacket.hpp src/profiling/SocketProfilingConnection.cpp diff --git a/src/profiling/BufferManager.cpp b/src/profiling/BufferManager.cpp index b24bf4b5b0..f5ab729259 100644 --- a/src/profiling/BufferManager.cpp +++ b/src/profiling/BufferManager.cpp @@ -40,13 +40,18 @@ IPacketBufferPtr BufferManager::Reserve(unsigned int requestedSize, unsigned int return buffer; } -void BufferManager::Commit(IPacketBufferPtr& packetBuffer, unsigned int size) +void BufferManager::Commit(IPacketBufferPtr& packetBuffer, unsigned int size, bool notifyConsumer) { std::unique_lock readableListLock(m_ReadableMutex, std::defer_lock); packetBuffer->Commit(size); readableListLock.lock(); m_ReadableList.push_back(std::move(packetBuffer)); readableListLock.unlock(); + + if (notifyConsumer) + { + FlushReadList(); + } } void BufferManager::Initialize() @@ -103,6 +108,20 @@ void BufferManager::MarkRead(IPacketBufferPtr& packetBuffer) availableListLock.unlock(); } +void BufferManager::SetConsumer(IConsumer* consumer) +{ + m_Consumer = consumer; +} + +void BufferManager::FlushReadList() +{ + // notify consumer that packet is ready to read + if (m_Consumer != nullptr) + { + m_Consumer->SetReadyToRead(); + } +} + } // namespace profiling } // namespace armnn diff --git a/src/profiling/BufferManager.hpp b/src/profiling/BufferManager.hpp index 495b113867..d678cd3ec0 100644 --- a/src/profiling/BufferManager.hpp +++ b/src/profiling/BufferManager.hpp @@ -6,6 +6,7 @@ #pragma once #include "IBufferManager.hpp" +#include "IConsumer.hpp" #include #include @@ -28,7 +29,7 @@ public: void Reset(); - void Commit(IPacketBufferPtr& packetBuffer, unsigned int size) override; + void Commit(IPacketBufferPtr& packetBuffer, unsigned int size, bool notifyConsumer = true) override; void Release(IPacketBufferPtr& packetBuffer) override; @@ -36,6 +37,13 @@ public: void MarkRead(IPacketBufferPtr& packetBuffer) override; + /// Set Consumer on the buffer manager to be notified when there is a Commit + /// Can only be one consumer + void SetConsumer(IConsumer* consumer) override; + + /// Notify the Consumer buffer can be read + void FlushReadList() override; + private: void Initialize(); @@ -55,6 +63,9 @@ private: // Mutex for readable packet buffer list std::mutex m_ReadableMutex; + + // Consumer thread to notify packet is ready to read + IConsumer* m_Consumer = nullptr; }; } // namespace profiling diff --git a/src/profiling/ConnectionAcknowledgedCommandHandler.cpp b/src/profiling/ConnectionAcknowledgedCommandHandler.cpp index a2a045d1e0..72d09b3da7 100644 --- a/src/profiling/ConnectionAcknowledgedCommandHandler.cpp +++ b/src/profiling/ConnectionAcknowledgedCommandHandler.cpp @@ -41,9 +41,6 @@ void ConnectionAcknowledgedCommandHandler::operator()(const Packet& packet) m_SendCounterPacket.SendCounterDirectoryPacket(m_CounterDirectory); m_SendTimelinePacket.SendTimelineMessageDirectoryPackage(); - // Notify the Send Thread that new data is available in the Counter Stream Buffer - m_SendCounterPacket.SetReadyToRead(); - break; case ProfilingState::Active: return; // NOP diff --git a/src/profiling/IBufferManager.hpp b/src/profiling/IBufferManager.hpp index 2b497da585..01ecb8222d 100644 --- a/src/profiling/IBufferManager.hpp +++ b/src/profiling/IBufferManager.hpp @@ -5,6 +5,7 @@ #pragma once +#include "IConsumer.hpp" #include "IPacketBuffer.hpp" #include @@ -24,13 +25,17 @@ public: virtual IPacketBufferPtr Reserve(unsigned int requestedSize, unsigned int& reservedSize) = 0; - virtual void Commit(IPacketBufferPtr& packetBuffer, unsigned int size) = 0; + virtual void Commit(IPacketBufferPtr& packetBuffer, unsigned int size, bool notifyConsumer = true) = 0; virtual void Release(IPacketBufferPtr& packetBuffer) = 0; virtual IPacketBufferPtr GetReadableBuffer() = 0; virtual void MarkRead(IPacketBufferPtr& packetBuffer) = 0; + + virtual void SetConsumer(IConsumer* consumer) = 0; + + virtual void FlushReadList() = 0; }; } // namespace profiling diff --git a/src/profiling/IConsumer.hpp b/src/profiling/IConsumer.hpp new file mode 100644 index 0000000000..f00f17458b --- /dev/null +++ b/src/profiling/IConsumer.hpp @@ -0,0 +1,26 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +namespace armnn +{ + +namespace profiling +{ + +class IConsumer +{ +public: + virtual ~IConsumer() {} + + /// Set a "ready to read" flag in the buffer to notify the reading thread to start reading it. + virtual void SetReadyToRead() = 0; +}; + +} // namespace profiling + +} // namespace armnn + diff --git a/src/profiling/ISendCounterPacket.hpp b/src/profiling/ISendCounterPacket.hpp index d666f8bc36..5c8e6b8d46 100644 --- a/src/profiling/ISendCounterPacket.hpp +++ b/src/profiling/ISendCounterPacket.hpp @@ -32,10 +32,6 @@ public: /// Create and write a PeriodicCounterSelectionPacket from the parameters to the buffer. virtual void SendPeriodicCounterSelectionPacket(uint32_t capturePeriod, const std::vector& selectedCounterIds) = 0; - - /// Set a "ready to read" flag in the buffer to notify the reading thread to start reading it. - virtual void SetReadyToRead() = 0; - }; } // namespace profiling diff --git a/src/profiling/ISendThread.hpp b/src/profiling/ISendThread.hpp new file mode 100644 index 0000000000..c5e05b183c --- /dev/null +++ b/src/profiling/ISendThread.hpp @@ -0,0 +1,31 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "IProfilingConnection.hpp" + +namespace armnn +{ + +namespace profiling +{ + +class ISendThread +{ +public: + virtual ~ISendThread() {} + + /// Start the thread + virtual void Start(IProfilingConnection& profilingConnection) = 0; + + /// Stop the thread + virtual void Stop(bool rethrowSendThreadExceptions = true) = 0; +}; + +} // namespace profiling + +} // namespace armnn + diff --git a/src/profiling/PeriodicCounterCapture.cpp b/src/profiling/PeriodicCounterCapture.cpp index f3bb5e9202..d60cbd7d15 100644 --- a/src/profiling/PeriodicCounterCapture.cpp +++ b/src/profiling/PeriodicCounterCapture.cpp @@ -99,9 +99,6 @@ void PeriodicCounterCapture::Capture(const IReadCounterValues& readCounterValues // Write a Periodic Counter Capture packet to the Counter Stream Buffer m_SendCounterPacket.SendPeriodicCounterCapturePacket(timestamp, values); - // Notify the Send Thread that new data is available in the Counter Stream Buffer - m_SendCounterPacket.SetReadyToRead(); - // Wait the indicated capture period (microseconds) std::this_thread::sleep_for(std::chrono::microseconds(currentCaptureData.GetCapturePeriod())); diff --git a/src/profiling/PeriodicCounterSelectionCommandHandler.cpp b/src/profiling/PeriodicCounterSelectionCommandHandler.cpp index a6b6a050ad..4a051b8d60 100644 --- a/src/profiling/PeriodicCounterSelectionCommandHandler.cpp +++ b/src/profiling/PeriodicCounterSelectionCommandHandler.cpp @@ -112,9 +112,6 @@ void PeriodicCounterSelectionCommandHandler::operator()(const Packet& packet) // Echo back the Periodic Counter Selection packet to the Counter Stream Buffer m_SendCounterPacket.SendPeriodicCounterSelectionPacket(capturePeriod, validCounterIds); - // Notify the Send Thread that new data is available in the Counter Stream Buffer - m_SendCounterPacket.SetReadyToRead(); - if (capturePeriod == 0 || validCounterIds.empty()) { // No data capture stop the thread diff --git a/src/profiling/PeriodicCounterSelectionCommandHandler.hpp b/src/profiling/PeriodicCounterSelectionCommandHandler.hpp index e2738f8643..c97474759a 100644 --- a/src/profiling/PeriodicCounterSelectionCommandHandler.hpp +++ b/src/profiling/PeriodicCounterSelectionCommandHandler.hpp @@ -8,6 +8,7 @@ #include "Packet.hpp" #include "CommandHandlerFunctor.hpp" #include "Holder.hpp" +#include "ProfilingStateMachine.hpp" #include "SendCounterPacket.hpp" #include "IPeriodicCounterCapture.hpp" #include "ICounterValues.hpp" diff --git a/src/profiling/ProfilingService.cpp b/src/profiling/ProfilingService.cpp index 5edb2b4026..c73f3b29ec 100644 --- a/src/profiling/ProfilingService.cpp +++ b/src/profiling/ProfilingService.cpp @@ -109,7 +109,7 @@ void ProfilingService::Update() m_CommandHandler.Stop(); // Stop the send thread (if running) - m_SendCounterPacket.Stop(false); + m_SendThread.Stop(false); // Stop the periodic counter capture thread (if running) m_PeriodicCounterCapture.Stop(); @@ -143,7 +143,7 @@ void ProfilingService::Update() // Start the send thread, while in "WaitingForAck" state it'll send out a "Stream MetaData" packet waiting for // a valid "Connection Acknowledged" packet confirming the connection - m_SendCounterPacket.Start(*m_ProfilingConnection); + m_SendThread.Start(*m_ProfilingConnection); // The connection acknowledged command handler will automatically transition the state to "Active" once a // valid "Connection Acknowledged" packet has been received @@ -419,7 +419,7 @@ void ProfilingService::Stop() m_CommandHandler.Stop(); m_PeriodicCounterCapture.Stop(); // The the consuming thread - m_SendCounterPacket.Stop(false); + m_SendThread.Stop(false); // ...then close and destroy the profiling connection... if (m_ProfilingConnection != nullptr && m_ProfilingConnection->IsOpen()) diff --git a/src/profiling/ProfilingService.hpp b/src/profiling/ProfilingService.hpp index 17099b1247..27166b362e 100644 --- a/src/profiling/ProfilingService.hpp +++ b/src/profiling/ProfilingService.hpp @@ -20,6 +20,7 @@ #include "ProfilingStateMachine.hpp" #include "RequestCounterDirectoryCommandHandler.hpp" #include "SendCounterPacket.hpp" +#include "SendThread.hpp" #include "SendTimelinePacket.hpp" #include "TimelinePacketWriterFactory.hpp" #include @@ -134,6 +135,7 @@ private: CommandHandler m_CommandHandler; BufferManager m_BufferManager; SendCounterPacket m_SendCounterPacket; + SendThread m_SendThread; SendTimelinePacket m_SendTimelinePacket; Holder m_Holder; PeriodicCounterCapture m_PeriodicCounterCapture; @@ -163,7 +165,8 @@ protected: m_CommandHandlerRegistry, m_PacketVersionResolver) , m_BufferManager() - , m_SendCounterPacket(m_StateMachine, m_BufferManager) + , m_SendCounterPacket(m_BufferManager) + , m_SendThread(m_StateMachine, m_BufferManager, m_SendCounterPacket) , m_SendTimelinePacket(m_BufferManager) , m_PeriodicCounterCapture(m_Holder, m_SendCounterPacket, *this) , m_ConnectionAcknowledgedCommandHandler(0, @@ -229,7 +232,7 @@ protected: } bool WaitForPacketSent(ProfilingService& instance, uint32_t timeout = 1000) { - return instance.m_SendCounterPacket.WaitForPacketSent(timeout); + return instance.m_SendThread.WaitForPacketSent(timeout); } BufferManager& GetBufferManager(ProfilingService& instance) diff --git a/src/profiling/RequestCounterDirectoryCommandHandler.cpp b/src/profiling/RequestCounterDirectoryCommandHandler.cpp index 2dbab3c1d5..5521a25f20 100644 --- a/src/profiling/RequestCounterDirectoryCommandHandler.cpp +++ b/src/profiling/RequestCounterDirectoryCommandHandler.cpp @@ -38,9 +38,6 @@ void RequestCounterDirectoryCommandHandler::operator()(const Packet& packet) m_SendCounterPacket.SendCounterDirectoryPacket(m_CounterDirectory); m_SendTimelinePacket.SendTimelineMessageDirectoryPackage(); - // Notify the Send Thread that new data is available in the Counter Stream Buffer - m_SendCounterPacket.SetReadyToRead(); - break; default: throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1%") diff --git a/src/profiling/SendCounterPacket.cpp b/src/profiling/SendCounterPacket.cpp index 4d305af951..942caec295 100644 --- a/src/profiling/SendCounterPacket.cpp +++ b/src/profiling/SendCounterPacket.cpp @@ -5,7 +5,6 @@ #include "SendCounterPacket.hpp" #include "EncodeVersion.hpp" -#include "ProfilingUtils.hpp" #include #include @@ -169,7 +168,7 @@ void SendCounterPacket::SendStreamMetaDataPacket() CancelOperationAndThrow(writeBuffer, "Error processing packet."); } - m_BufferManager.Commit(writeBuffer, totalSize); + m_BufferManager.Commit(writeBuffer, totalSize, false); } bool SendCounterPacket::CreateCategoryRecord(const CategoryPtr& category, @@ -903,241 +902,6 @@ void SendCounterPacket::SendPeriodicCounterSelectionPacket(uint32_t capturePerio m_BufferManager.Commit(writeBuffer, totalSize); } -void SendCounterPacket::SetReadyToRead() -{ - // We need to wait for the send thread to release its mutex - { - std::lock_guard lck(m_WaitMutex); - m_ReadyToRead = true; - } - // Signal the send thread that there's something to read in the buffer - m_WaitCondition.notify_one(); -} - -void SendCounterPacket::Start(IProfilingConnection& profilingConnection) -{ - // Check if the send thread is already running - if (m_IsRunning.load()) - { - // The send thread is already running - return; - } - - if (m_SendThread.joinable()) - { - m_SendThread.join(); - } - - // Mark the send thread as running - m_IsRunning.store(true); - - // Keep the send procedure going until the send thread is signalled to stop - m_KeepRunning.store(true); - - // Make sure the send thread will not flush the buffer until signaled to do so - // no need for a mutex as the send thread can not be running at this point - m_ReadyToRead = false; - - m_PacketSent = false; - - // Start the send thread - m_SendThread = std::thread(&SendCounterPacket::Send, this, std::ref(profilingConnection)); -} - -void SendCounterPacket::Stop(bool rethrowSendThreadExceptions) -{ - // Signal the send thread to stop - m_KeepRunning.store(false); - - // Check that the send thread is running - if (m_SendThread.joinable()) - { - // Kick the send thread out of the wait condition - SetReadyToRead(); - // Wait for the send thread to complete operations - m_SendThread.join(); - } - - // Check if the send thread exception has to be rethrown - if (!rethrowSendThreadExceptions) - { - // No need to rethrow the send thread exception, return immediately - return; - } - - // Check if there's an exception to rethrow - if (m_SendThreadException) - { - // Rethrow the send thread exception - std::rethrow_exception(m_SendThreadException); - - // Nullify the exception as it has been rethrown - m_SendThreadException = nullptr; - } -} - -void SendCounterPacket::Send(IProfilingConnection& profilingConnection) -{ - // Run once and keep the sending procedure looping until the thread is signalled to stop - do - { - // Check the current state of the profiling service - ProfilingState currentState = m_StateMachine.GetCurrentState(); - switch (currentState) - { - case ProfilingState::Uninitialised: - case ProfilingState::NotConnected: - - // The send thread cannot be running when the profiling service is uninitialized or not connected, - // stop the thread immediately - m_KeepRunning.store(false); - m_IsRunning.store(false); - - // An exception should be thrown here, save it to be rethrown later from the main thread so that - // it can be caught by the consumer - m_SendThreadException = - std::make_exception_ptr(RuntimeException("The send thread should not be running with the " - "profiling service not yet initialized or connected")); - - return; - case ProfilingState::WaitingForAck: - - // Send out a StreamMetadata packet and wait for the profiling connection to be acknowledged. - // When a ConnectionAcknowledged packet is received, the profiling service state will be automatically - // updated by the command handler - - // Prepare a StreamMetadata packet and write it to the Counter Stream buffer - SendStreamMetaDataPacket(); - - // Flush the buffer manually to send the packet - FlushBuffer(profilingConnection); - - // Wait for a connection ack from the remote server. We should expect a response within timeout value. - // If not, drop back to the start of the loop and detect somebody closing the thread. Then send the - // StreamMetadata again. - - // Wait condition lock scope - Begin - { - std::unique_lock lock(m_WaitMutex); - - bool timeout = m_WaitCondition.wait_for(lock, - std::chrono::milliseconds(m_Timeout), - [&]{ return m_ReadyToRead; }); - // If we get notified we need to flush the buffer again - if(timeout) - { - // Otherwise if we just timed out don't flush the buffer - continue; - } - //reset condition variable predicate for next use - m_ReadyToRead = false; - } - // Wait condition lock scope - End - break; - case ProfilingState::Active: - default: - // Wait condition lock scope - Begin - { - std::unique_lock lock(m_WaitMutex); - - // Normal working state for the send thread - // Check if the send thread is required to enforce a timeout wait policy - if (m_Timeout < 0) - { - // Wait indefinitely until notified that something to read has become available in the buffer - m_WaitCondition.wait(lock, [&] { return m_ReadyToRead; }); - } - else - { - // Wait until the thread is notified of something to read from the buffer, - // or check anyway after the specified number of milliseconds - m_WaitCondition.wait_for(lock, std::chrono::milliseconds(m_Timeout), [&] { return m_ReadyToRead; }); - } - - //reset condition variable predicate for next use - m_ReadyToRead = false; - } - // Wait condition lock scope - End - break; - } - - // Send all the available packets in the buffer - FlushBuffer(profilingConnection); - } while (m_KeepRunning.load()); - - // Ensure that all readable data got written to the profiling connection before the thread is stopped - // (do not notify any watcher in this case, as this is just to wrap up things before shutting down the send thread) - FlushBuffer(profilingConnection, false); - - // Mark the send thread as not running - m_IsRunning.store(false); -} - -void SendCounterPacket::FlushBuffer(IProfilingConnection& profilingConnection, bool notifyWatchers) -{ - // Get the first available readable buffer - IPacketBufferPtr packetBuffer = m_BufferManager.GetReadableBuffer(); - - // Initialize the flag that indicates whether at least a packet has been sent - bool packetsSent = false; - - while (packetBuffer != nullptr) - { - // Get the data to send from the buffer - const unsigned char* readBuffer = packetBuffer->GetReadableData(); - unsigned int readBufferSize = packetBuffer->GetSize(); - - if (readBuffer == nullptr || readBufferSize == 0) - { - // Nothing to send, get the next available readable buffer and continue - m_BufferManager.MarkRead(packetBuffer); - packetBuffer = m_BufferManager.GetReadableBuffer(); - - continue; - } - - // Check that the profiling connection is open, silently drop the data and continue if it's closed - if (profilingConnection.IsOpen()) - { - // Write a packet to the profiling connection. Silently ignore any write error and continue - profilingConnection.WritePacket(readBuffer, boost::numeric_cast(readBufferSize)); - - // Set the flag that indicates whether at least a packet has been sent - packetsSent = true; - } - - // Mark the packet buffer as read - m_BufferManager.MarkRead(packetBuffer); - - // Get the next available readable buffer - packetBuffer = m_BufferManager.GetReadableBuffer(); - } - // Check whether at least a packet has been sent - if (packetsSent && notifyWatchers) - { - // Wait for the parent thread to release its mutex if necessary - { - std::lock_guard lck(m_PacketSentWaitMutex); - m_PacketSent = true; - } - // Notify to any watcher that something has been sent - m_PacketSentWaitCondition.notify_one(); - } -} - -bool SendCounterPacket::WaitForPacketSent(uint32_t timeout = 1000) -{ - std::unique_lock lock(m_PacketSentWaitMutex); - // Blocks until notified that at least a packet has been sent or until timeout expires. - bool timedOut = m_PacketSentWaitCondition.wait_for(lock, - std::chrono::milliseconds(timeout), - [&] { return m_PacketSent; }); - - m_PacketSent = false; - - return timedOut; -} - } // namespace profiling } // namespace armnn diff --git a/src/profiling/SendCounterPacket.hpp b/src/profiling/SendCounterPacket.hpp index 80d6f8437a..5a10711e1e 100644 --- a/src/profiling/SendCounterPacket.hpp +++ b/src/profiling/SendCounterPacket.hpp @@ -8,14 +8,8 @@ #include "IBufferManager.hpp" #include "ICounterDirectory.hpp" #include "ISendCounterPacket.hpp" -#include "IProfilingConnection.hpp" -#include "ProfilingStateMachine.hpp" #include "ProfilingUtils.hpp" -#include -#include -#include -#include #include namespace armnn @@ -33,19 +27,9 @@ public: using EventRecord = std::vector; using IndexValuePairsVector = std::vector>; - SendCounterPacket(ProfilingStateMachine& profilingStateMachine, IBufferManager& buffer, int timeout = 1000) - : m_StateMachine(profilingStateMachine) - , m_BufferManager(buffer) - , m_Timeout(timeout) - , m_IsRunning(false) - , m_KeepRunning(false) - , m_SendThreadException(nullptr) + SendCounterPacket(IBufferManager& buffer) + : m_BufferManager(buffer) {} - ~SendCounterPacket() - { - // Don't rethrow when destructing the object - Stop(false); - } void SendStreamMetaDataPacket() override; @@ -56,18 +40,9 @@ public: void SendPeriodicCounterSelectionPacket(uint32_t capturePeriod, const std::vector& selectedCounterIds) override; - void SetReadyToRead() override; - static const unsigned int PIPE_MAGIC = 0x45495434; - void Start(IProfilingConnection& profilingConnection); - void Stop(bool rethrowSendThreadExceptions = true); - bool IsRunning() { return m_IsRunning.load(); } - bool WaitForPacketSent(uint32_t timeout); - private: - void Send(IProfilingConnection& profilingConnection); - template void CancelOperationAndThrow(const std::string& errorMessage) { @@ -80,7 +55,7 @@ private: { if (std::is_same::value) { - SetReadyToRead(); + m_BufferManager.FlushReadList(); } if (writerBuffer != nullptr) @@ -93,23 +68,7 @@ private: throw ExceptionType(errorMessage); } - void FlushBuffer(IProfilingConnection& profilingConnection, bool notifyWatchers = true); - - ProfilingStateMachine& m_StateMachine; IBufferManager& m_BufferManager; - int m_Timeout; - std::mutex m_WaitMutex; - std::condition_variable m_WaitCondition; - std::thread m_SendThread; - std::atomic m_IsRunning; - std::atomic m_KeepRunning; - // m_ReadyToRead will be protected by m_WaitMutex - bool m_ReadyToRead; - // m_PacketSent will be protected by m_PacketSentWaitMutex - bool m_PacketSent; - std::exception_ptr m_SendThreadException; - std::mutex m_PacketSentWaitMutex; - std::condition_variable m_PacketSentWaitCondition; protected: // Helper methods, protected for testing diff --git a/src/profiling/SendThread.cpp b/src/profiling/SendThread.cpp new file mode 100644 index 0000000000..d595c9d4a5 --- /dev/null +++ b/src/profiling/SendThread.cpp @@ -0,0 +1,278 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "SendThread.hpp" +#include "EncodeVersion.hpp" +#include "ProfilingUtils.hpp" + +#include +#include +#include + +#include +#include +#include + +#include + +namespace armnn +{ + +namespace profiling +{ + +using boost::numeric_cast; + +SendThread::SendThread(armnn::profiling::ProfilingStateMachine& profilingStateMachine, + armnn::profiling::IBufferManager& buffer, armnn::profiling::ISendCounterPacket& sendCounterPacket, int timeout) + : m_StateMachine(profilingStateMachine) + , m_BufferManager(buffer) + , m_SendCounterPacket(sendCounterPacket) + , m_Timeout(timeout) + , m_IsRunning(false) + , m_KeepRunning(false) + , m_SendThreadException(nullptr) +{ + m_BufferManager.SetConsumer(this); +} + +void SendThread::SetReadyToRead() +{ + // We need to wait for the send thread to release its mutex + { + std::lock_guard lck(m_WaitMutex); + m_ReadyToRead = true; + } + // Signal the send thread that there's something to read in the buffer + m_WaitCondition.notify_one(); +} + +void SendThread::Start(IProfilingConnection& profilingConnection) +{ + // Check if the send thread is already running + if (m_IsRunning.load()) + { + // The send thread is already running + return; + } + + if (m_SendThread.joinable()) + { + m_SendThread.join(); + } + + // Mark the send thread as running + m_IsRunning.store(true); + + // Keep the send procedure going until the send thread is signalled to stop + m_KeepRunning.store(true); + + // Make sure the send thread will not flush the buffer until signaled to do so + // no need for a mutex as the send thread can not be running at this point + m_ReadyToRead = false; + + m_PacketSent = false; + + // Start the send thread + m_SendThread = std::thread(&SendThread::Send, this, std::ref(profilingConnection)); +} + +void SendThread::Stop(bool rethrowSendThreadExceptions) +{ + // Signal the send thread to stop + m_KeepRunning.store(false); + + // Check that the send thread is running + if (m_SendThread.joinable()) + { + // Kick the send thread out of the wait condition + SetReadyToRead(); + // Wait for the send thread to complete operations + m_SendThread.join(); + } + + // Check if the send thread exception has to be rethrown + if (!rethrowSendThreadExceptions) + { + // No need to rethrow the send thread exception, return immediately + return; + } + + // Check if there's an exception to rethrow + if (m_SendThreadException) + { + // Rethrow the send thread exception + std::rethrow_exception(m_SendThreadException); + + // Nullify the exception as it has been rethrown + m_SendThreadException = nullptr; + } +} + +void SendThread::Send(IProfilingConnection& profilingConnection) +{ + // Run once and keep the sending procedure looping until the thread is signalled to stop + do + { + // Check the current state of the profiling service + ProfilingState currentState = m_StateMachine.GetCurrentState(); + switch (currentState) + { + case ProfilingState::Uninitialised: + case ProfilingState::NotConnected: + + // The send thread cannot be running when the profiling service is uninitialized or not connected, + // stop the thread immediately + m_KeepRunning.store(false); + m_IsRunning.store(false); + + // An exception should be thrown here, save it to be rethrown later from the main thread so that + // it can be caught by the consumer + m_SendThreadException = + std::make_exception_ptr(RuntimeException("The send thread should not be running with the " + "profiling service not yet initialized or connected")); + + return; + case ProfilingState::WaitingForAck: + + // Send out a StreamMetadata packet and wait for the profiling connection to be acknowledged. + // When a ConnectionAcknowledged packet is received, the profiling service state will be automatically + // updated by the command handler + + // Prepare a StreamMetadata packet and write it to the Counter Stream buffer + m_SendCounterPacket.SendStreamMetaDataPacket(); + + // Flush the buffer manually to send the packet + FlushBuffer(profilingConnection); + + // Wait for a connection ack from the remote server. We should expect a response within timeout value. + // If not, drop back to the start of the loop and detect somebody closing the thread. Then send the + // StreamMetadata again. + + // Wait condition lock scope - Begin + { + std::unique_lock lock(m_WaitMutex); + + bool timeout = m_WaitCondition.wait_for(lock, + std::chrono::milliseconds(m_Timeout), + [&]{ return m_ReadyToRead; }); + // If we get notified we need to flush the buffer again + if(timeout) + { + // Otherwise if we just timed out don't flush the buffer + continue; + } + //reset condition variable predicate for next use + m_ReadyToRead = false; + } + // Wait condition lock scope - End + break; + case ProfilingState::Active: + default: + // Wait condition lock scope - Begin + { + std::unique_lock lock(m_WaitMutex); + + // Normal working state for the send thread + // Check if the send thread is required to enforce a timeout wait policy + if (m_Timeout < 0) + { + // Wait indefinitely until notified that something to read has become available in the buffer + m_WaitCondition.wait(lock, [&] { return m_ReadyToRead; }); + } + else + { + // Wait until the thread is notified of something to read from the buffer, + // or check anyway after the specified number of milliseconds + m_WaitCondition.wait_for(lock, std::chrono::milliseconds(m_Timeout), [&] { return m_ReadyToRead; }); + } + + //reset condition variable predicate for next use + m_ReadyToRead = false; + } + // Wait condition lock scope - End + break; + } + + // Send all the available packets in the buffer + FlushBuffer(profilingConnection); + } while (m_KeepRunning.load()); + + // Ensure that all readable data got written to the profiling connection before the thread is stopped + // (do not notify any watcher in this case, as this is just to wrap up things before shutting down the send thread) + FlushBuffer(profilingConnection, false); + + // Mark the send thread as not running + m_IsRunning.store(false); +} + +void SendThread::FlushBuffer(IProfilingConnection& profilingConnection, bool notifyWatchers) +{ + // Get the first available readable buffer + IPacketBufferPtr packetBuffer = m_BufferManager.GetReadableBuffer(); + + // Initialize the flag that indicates whether at least a packet has been sent + bool packetsSent = false; + + while (packetBuffer != nullptr) + { + // Get the data to send from the buffer + const unsigned char* readBuffer = packetBuffer->GetReadableData(); + unsigned int readBufferSize = packetBuffer->GetSize(); + + if (readBuffer == nullptr || readBufferSize == 0) + { + // Nothing to send, get the next available readable buffer and continue + m_BufferManager.MarkRead(packetBuffer); + packetBuffer = m_BufferManager.GetReadableBuffer(); + + continue; + } + + // Check that the profiling connection is open, silently drop the data and continue if it's closed + if (profilingConnection.IsOpen()) + { + // Write a packet to the profiling connection. Silently ignore any write error and continue + profilingConnection.WritePacket(readBuffer, boost::numeric_cast(readBufferSize)); + + // Set the flag that indicates whether at least a packet has been sent + packetsSent = true; + } + + // Mark the packet buffer as read + m_BufferManager.MarkRead(packetBuffer); + + // Get the next available readable buffer + packetBuffer = m_BufferManager.GetReadableBuffer(); + } + // Check whether at least a packet has been sent + if (packetsSent && notifyWatchers) + { + // Wait for the parent thread to release its mutex if necessary + { + std::lock_guard lck(m_PacketSentWaitMutex); + m_PacketSent = true; + } + // Notify to any watcher that something has been sent + m_PacketSentWaitCondition.notify_one(); + } +} + +bool SendThread::WaitForPacketSent(uint32_t timeout = 1000) +{ + std::unique_lock lock(m_PacketSentWaitMutex); + // Blocks until notified that at least a packet has been sent or until timeout expires. + bool timedOut = m_PacketSentWaitCondition.wait_for(lock, + std::chrono::milliseconds(timeout), + [&] { return m_PacketSent; }); + + m_PacketSent = false; + + return timedOut; +} + +} // namespace profiling + +} // namespace armnn diff --git a/src/profiling/SendThread.hpp b/src/profiling/SendThread.hpp new file mode 100644 index 0000000000..af1a72bce5 --- /dev/null +++ b/src/profiling/SendThread.hpp @@ -0,0 +1,75 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "IBufferManager.hpp" +#include "IConsumer.hpp" +#include "ICounterDirectory.hpp" +#include "ISendCounterPacket.hpp" +#include "ISendThread.hpp" +#include "IProfilingConnection.hpp" +#include "ProfilingStateMachine.hpp" +#include "ProfilingUtils.hpp" + +#include +#include +#include +#include +#include + +namespace armnn +{ + +namespace profiling +{ + +class SendThread : public ISendThread, public IConsumer +{ +public: + SendThread(ProfilingStateMachine& profilingStateMachine, + IBufferManager& buffer, ISendCounterPacket& sendCounterPacket, int timeout= 1000); + ~SendThread() + { + // Don't rethrow when destructing the object + Stop(false); + } + void Start(IProfilingConnection& profilingConnection) override; + + void Stop(bool rethrowSendThreadExceptions = true) override; + + void SetReadyToRead() override; + + bool IsRunning() { return m_IsRunning.load(); } + + bool WaitForPacketSent(uint32_t timeout); + +private: + void Send(IProfilingConnection& profilingConnection); + + void FlushBuffer(IProfilingConnection& profilingConnection, bool notifyWatchers = true); + + ProfilingStateMachine& m_StateMachine; + IBufferManager& m_BufferManager; + ISendCounterPacket& m_SendCounterPacket; + int m_Timeout; + std::mutex m_WaitMutex; + std::condition_variable m_WaitCondition; + std::thread m_SendThread; + std::atomic m_IsRunning; + std::atomic m_KeepRunning; + // m_ReadyToRead will be protected by m_WaitMutex + bool m_ReadyToRead; + // m_PacketSent will be protected by m_PacketSentWaitMutex + bool m_PacketSent; + std::exception_ptr m_SendThreadException; + std::mutex m_PacketSentWaitMutex; + std::condition_variable m_PacketSentWaitCondition; + +}; + +} // namespace profiling + +} // namespace armnn diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index b15ddf7885..0bad66fb1c 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -135,7 +136,8 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandler) TestProfilingConnectionArmnnError testProfilingConnectionArmnnError; CounterDirectory counterDirectory; MockBufferManager mockBuffer(1024); - SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); + SendCounterPacket sendCounterPacket(mockBuffer); + SendThread sendThread(profilingStateMachine, mockBuffer, sendCounterPacket); SendTimelinePacket sendTimelinePacket(mockBuffer); ConnectionAcknowledgedCommandHandler connectionAcknowledgedCommandHandler(0, 1, 4194304, counterDirectory, @@ -1766,7 +1768,8 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData) TestCaptureThread captureThread; TestReadCounterValues readCounterValues; MockBufferManager mockBuffer(512); - SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); + SendCounterPacket sendCounterPacket(mockBuffer); + SendThread sendThread(profilingStateMachine, mockBuffer, sendCounterPacket); uint32_t sizeOfUint32 = numeric_cast(sizeof(uint32_t)); uint32_t sizeOfUint16 = numeric_cast(sizeof(uint16_t)); @@ -1896,7 +1899,8 @@ BOOST_AUTO_TEST_CASE(CheckConnectionAcknowledged) BOOST_CHECK(profilingState.GetCurrentState() == ProfilingState::Uninitialised); CounterDirectory counterDirectory; MockBufferManager mockBuffer(1024); - SendCounterPacket sendCounterPacket(profilingState, mockBuffer); + SendCounterPacket sendCounterPacket(mockBuffer); + SendThread sendThread(profilingState, mockBuffer, sendCounterPacket); SendTimelinePacket sendTimelinePacket(mockBuffer); ConnectionAcknowledgedCommandHandler commandHandler(packetFamilyId, connectionPacketId, version, counterDirectory, @@ -2158,7 +2162,8 @@ BOOST_AUTO_TEST_CASE(CheckPeriodicCounterCaptureThread) std::vector captureIds2; MockBufferManager mockBuffer(512); - SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); + SendCounterPacket sendCounterPacket(mockBuffer); + SendThread sendThread(profilingStateMachine, mockBuffer, sendCounterPacket); std::vector counterIds; CaptureReader captureReader(2); @@ -2216,7 +2221,8 @@ BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest1) ProfilingStateMachine profilingStateMachine; CounterDirectory counterDirectory; MockBufferManager mockBuffer1(1024); - SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer1); + SendCounterPacket sendCounterPacket(mockBuffer1); + SendThread sendThread(profilingStateMachine, mockBuffer1, sendCounterPacket); MockBufferManager mockBuffer2(1024); SendTimelinePacket sendTimelinePacket(mockBuffer2); RequestCounterDirectoryCommandHandler commandHandler(familyId, packetId, version, counterDirectory, @@ -2277,7 +2283,8 @@ BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest2) ProfilingStateMachine profilingStateMachine; CounterDirectory counterDirectory; MockBufferManager mockBuffer1(1024); - SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer1); + SendCounterPacket sendCounterPacket(mockBuffer1); + SendThread sendThread(profilingStateMachine, mockBuffer1, sendCounterPacket); MockBufferManager mockBuffer2(1024); SendTimelinePacket sendTimelinePacket(mockBuffer2); RequestCounterDirectoryCommandHandler commandHandler(familyId, packetId, version, counterDirectory, diff --git a/src/profiling/test/SendCounterPacketTests.cpp b/src/profiling/test/SendCounterPacketTests.cpp index 83bffe4686..9ec24e539f 100644 --- a/src/profiling/test/SendCounterPacketTests.cpp +++ b/src/profiling/test/SendCounterPacketTests.cpp @@ -146,11 +146,9 @@ BOOST_AUTO_TEST_CASE(MockSendCounterPacketTest) BOOST_AUTO_TEST_CASE(SendPeriodicCounterSelectionPacketTest) { - ProfilingStateMachine profilingStateMachine; - // Error no space left in buffer MockBufferManager mockBuffer1(10); - SendCounterPacket sendPacket1(profilingStateMachine, mockBuffer1); + SendCounterPacket sendPacket1(mockBuffer1); uint32_t capturePeriod = 1000; std::vector selectedCounterIds; @@ -159,7 +157,7 @@ BOOST_AUTO_TEST_CASE(SendPeriodicCounterSelectionPacketTest) // Packet without any counters MockBufferManager mockBuffer2(512); - SendCounterPacket sendPacket2(profilingStateMachine, mockBuffer2); + SendCounterPacket sendPacket2(mockBuffer2); sendPacket2.SendPeriodicCounterSelectionPacket(capturePeriod, selectedCounterIds); auto readBuffer2 = mockBuffer2.GetReadableBuffer(); @@ -175,7 +173,7 @@ BOOST_AUTO_TEST_CASE(SendPeriodicCounterSelectionPacketTest) // Full packet message MockBufferManager mockBuffer3(512); - SendCounterPacket sendPacket3(profilingStateMachine, mockBuffer3); + SendCounterPacket sendPacket3(mockBuffer3); selectedCounterIds.reserve(5); selectedCounterIds.emplace_back(100); @@ -213,7 +211,7 @@ BOOST_AUTO_TEST_CASE(SendPeriodicCounterCapturePacketTest) // Error no space left in buffer MockBufferManager mockBuffer1(10); - SendCounterPacket sendPacket1(profilingStateMachine, mockBuffer1); + SendCounterPacket sendPacket1(mockBuffer1); auto captureTimestamp = std::chrono::steady_clock::now(); uint64_t time = static_cast(captureTimestamp.time_since_epoch().count()); @@ -224,7 +222,7 @@ BOOST_AUTO_TEST_CASE(SendPeriodicCounterCapturePacketTest) // Packet without any counters MockBufferManager mockBuffer2(512); - SendCounterPacket sendPacket2(profilingStateMachine, mockBuffer2); + SendCounterPacket sendPacket2(mockBuffer2); sendPacket2.SendPeriodicCounterCapturePacket(time, indexValuePairs); auto readBuffer2 = mockBuffer2.GetReadableBuffer(); @@ -241,7 +239,7 @@ BOOST_AUTO_TEST_CASE(SendPeriodicCounterCapturePacketTest) // Full packet message MockBufferManager mockBuffer3(512); - SendCounterPacket sendPacket3(profilingStateMachine, mockBuffer3); + SendCounterPacket sendPacket3(mockBuffer3); indexValuePairs.reserve(5); indexValuePairs.emplace_back(std::make_pair(0, 100)); @@ -290,11 +288,9 @@ BOOST_AUTO_TEST_CASE(SendStreamMetaDataPacketTest) uint32_t sizeUint32 = numeric_cast(sizeof(uint32_t)); - ProfilingStateMachine profilingStateMachine; - // Error no space left in buffer MockBufferManager mockBuffer1(10); - SendCounterPacket sendPacket1(profilingStateMachine, mockBuffer1); + SendCounterPacket sendPacket1(mockBuffer1); BOOST_CHECK_THROW(sendPacket1.SendStreamMetaDataPacket(), armnn::profiling::BufferExhaustion); // Full metadata packet @@ -313,7 +309,7 @@ BOOST_AUTO_TEST_CASE(SendStreamMetaDataPacketTest) uint32_t packetEntries = 6; MockBufferManager mockBuffer2(512); - SendCounterPacket sendPacket2(profilingStateMachine, mockBuffer2); + SendCounterPacket sendPacket2(mockBuffer2); sendPacket2.SendStreamMetaDataPacket(); auto readBuffer2 = mockBuffer2.GetReadableBuffer(); @@ -408,10 +404,8 @@ BOOST_AUTO_TEST_CASE(SendStreamMetaDataPacketTest) BOOST_AUTO_TEST_CASE(CreateDeviceRecordTest) { - ProfilingStateMachine profilingStateMachine; - MockBufferManager mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer); + SendCounterPacketTest sendCounterPacketTest(mockBuffer); // Create a device for testing uint16_t deviceUid = 27; @@ -442,10 +436,8 @@ BOOST_AUTO_TEST_CASE(CreateDeviceRecordTest) BOOST_AUTO_TEST_CASE(CreateInvalidDeviceRecordTest) { - ProfilingStateMachine profilingStateMachine; - MockBufferManager mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer); + SendCounterPacketTest sendCounterPacketTest(mockBuffer); // Create a device for testing uint16_t deviceUid = 27; @@ -465,10 +457,8 @@ BOOST_AUTO_TEST_CASE(CreateInvalidDeviceRecordTest) BOOST_AUTO_TEST_CASE(CreateCounterSetRecordTest) { - ProfilingStateMachine profilingStateMachine; - MockBufferManager mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer); + SendCounterPacketTest sendCounterPacketTest(mockBuffer); // Create a counter set for testing uint16_t counterSetUid = 27; @@ -499,10 +489,8 @@ BOOST_AUTO_TEST_CASE(CreateCounterSetRecordTest) BOOST_AUTO_TEST_CASE(CreateInvalidCounterSetRecordTest) { - ProfilingStateMachine profilingStateMachine; - MockBufferManager mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer); + SendCounterPacketTest sendCounterPacketTest(mockBuffer); // Create a counter set for testing uint16_t counterSetUid = 27; @@ -522,10 +510,8 @@ BOOST_AUTO_TEST_CASE(CreateInvalidCounterSetRecordTest) BOOST_AUTO_TEST_CASE(CreateEventRecordTest) { - ProfilingStateMachine profilingStateMachine; - MockBufferManager mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer); + SendCounterPacketTest sendCounterPacketTest(mockBuffer); // Create a counter for testing uint16_t counterUid = 7256; @@ -645,10 +631,8 @@ BOOST_AUTO_TEST_CASE(CreateEventRecordTest) BOOST_AUTO_TEST_CASE(CreateEventRecordNoUnitsTest) { - ProfilingStateMachine profilingStateMachine; - MockBufferManager mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer); + SendCounterPacketTest sendCounterPacketTest(mockBuffer); // Create a counter for testing uint16_t counterUid = 44312; @@ -751,10 +735,8 @@ BOOST_AUTO_TEST_CASE(CreateEventRecordNoUnitsTest) BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest1) { - ProfilingStateMachine profilingStateMachine; - MockBufferManager mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer); + SendCounterPacketTest sendCounterPacketTest(mockBuffer); // Create a counter for testing uint16_t counterUid = 7256; @@ -792,10 +774,8 @@ BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest1) BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest2) { - ProfilingStateMachine profilingStateMachine; - MockBufferManager mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer); + SendCounterPacketTest sendCounterPacketTest(mockBuffer); // Create a counter for testing uint16_t counterUid = 7256; @@ -833,10 +813,8 @@ BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest2) BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest3) { - ProfilingStateMachine profilingStateMachine; - MockBufferManager mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer); + SendCounterPacketTest sendCounterPacketTest(mockBuffer); // Create a counter for testing uint16_t counterUid = 7256; @@ -874,10 +852,8 @@ BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest3) BOOST_AUTO_TEST_CASE(CreateCategoryRecordTest) { - ProfilingStateMachine profilingStateMachine; - MockBufferManager mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer); + SendCounterPacketTest sendCounterPacketTest(mockBuffer); // Create a category for testing const std::string categoryName = "some_category"; @@ -1080,10 +1056,8 @@ BOOST_AUTO_TEST_CASE(CreateCategoryRecordTest) BOOST_AUTO_TEST_CASE(CreateInvalidCategoryRecordTest1) { - ProfilingStateMachine profilingStateMachine; - MockBufferManager mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer); + SendCounterPacketTest sendCounterPacketTest(mockBuffer); // Create a category for testing const std::string categoryName = "some invalid category"; @@ -1105,10 +1079,8 @@ BOOST_AUTO_TEST_CASE(CreateInvalidCategoryRecordTest1) BOOST_AUTO_TEST_CASE(CreateInvalidCategoryRecordTest2) { - ProfilingStateMachine profilingStateMachine; - MockBufferManager mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer); + SendCounterPacketTest sendCounterPacketTest(mockBuffer); // Create a category for testing const std::string categoryName = "some_category"; @@ -1148,8 +1120,6 @@ BOOST_AUTO_TEST_CASE(CreateInvalidCategoryRecordTest2) BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest1) { - ProfilingStateMachine profilingStateMachine; - // The counter directory used for testing CounterDirectory counterDirectory; @@ -1169,15 +1139,13 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest1) // Buffer with not enough space MockBufferManager mockBuffer(10); - SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); + SendCounterPacket sendCounterPacket(mockBuffer); BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::profiling::BufferExhaustion); } BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) { - ProfilingStateMachine profilingStateMachine; - // The counter directory used for testing CounterDirectory counterDirectory; @@ -1269,7 +1237,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) // Buffer with enough space MockBufferManager mockBuffer(1024); - SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); + SendCounterPacket sendCounterPacket(mockBuffer); BOOST_CHECK_NO_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory)); // Get the readable buffer @@ -1658,8 +1626,6 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest3) { - ProfilingStateMachine profilingStateMachine; - // Using a mock counter directory that allows to register invalid objects MockCounterDirectory counterDirectory; @@ -1672,14 +1638,12 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest3) // Buffer with enough space MockBufferManager mockBuffer(1024); - SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); + SendCounterPacket sendCounterPacket(mockBuffer); BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::RuntimeException); } BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest4) { - ProfilingStateMachine profilingStateMachine; - // Using a mock counter directory that allows to register invalid objects MockCounterDirectory counterDirectory; @@ -1692,14 +1656,12 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest4) // Buffer with enough space MockBufferManager mockBuffer(1024); - SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); + SendCounterPacket sendCounterPacket(mockBuffer); BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::RuntimeException); } BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest5) { - ProfilingStateMachine profilingStateMachine; - // Using a mock counter directory that allows to register invalid objects MockCounterDirectory counterDirectory; @@ -1712,14 +1674,12 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest5) // Buffer with enough space MockBufferManager mockBuffer(1024); - SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); + SendCounterPacket sendCounterPacket(mockBuffer); BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::RuntimeException); } BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest6) { - ProfilingStateMachine profilingStateMachine; - // Using a mock counter directory that allows to register invalid objects MockCounterDirectory counterDirectory; @@ -1748,14 +1708,12 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest6) // Buffer with enough space MockBufferManager mockBuffer(1024); - SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); + SendCounterPacket sendCounterPacket(mockBuffer); BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::RuntimeException); } BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest7) { - ProfilingStateMachine profilingStateMachine; - // Using a mock counter directory that allows to register invalid objects MockCounterDirectory counterDirectory; @@ -1801,7 +1759,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest7) // Buffer with enough space MockBufferManager mockBuffer(1024); - SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); + SendCounterPacket sendCounterPacket(mockBuffer); BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::RuntimeException); } @@ -1812,20 +1770,21 @@ BOOST_AUTO_TEST_CASE(SendThreadTest0) MockProfilingConnection mockProfilingConnection; MockStreamCounterBuffer mockStreamCounterBuffer(0); - SendCounterPacket sendCounterPacket(profilingStateMachine, mockStreamCounterBuffer); + SendCounterPacket sendCounterPacket(mockStreamCounterBuffer); + SendThread sendThread(profilingStateMachine, mockStreamCounterBuffer, sendCounterPacket); // Try to start the send thread many times, it must only start once - sendCounterPacket.Start(mockProfilingConnection); - BOOST_CHECK(sendCounterPacket.IsRunning()); - sendCounterPacket.Start(mockProfilingConnection); - sendCounterPacket.Start(mockProfilingConnection); - sendCounterPacket.Start(mockProfilingConnection); - sendCounterPacket.Start(mockProfilingConnection); - BOOST_CHECK(sendCounterPacket.IsRunning()); + sendThread.Start(mockProfilingConnection); + BOOST_CHECK(sendThread.IsRunning()); + sendThread.Start(mockProfilingConnection); + sendThread.Start(mockProfilingConnection); + sendThread.Start(mockProfilingConnection); + sendThread.Start(mockProfilingConnection); + BOOST_CHECK(sendThread.IsRunning()); - sendCounterPacket.Stop(); - BOOST_CHECK(!sendCounterPacket.IsRunning()); + sendThread.Stop(); + BOOST_CHECK(!sendThread.IsRunning()); } BOOST_AUTO_TEST_CASE(SendThreadTest1) @@ -1837,8 +1796,9 @@ BOOST_AUTO_TEST_CASE(SendThreadTest1) MockProfilingConnection mockProfilingConnection; MockStreamCounterBuffer mockStreamCounterBuffer(1024); - SendCounterPacket sendCounterPacket(profilingStateMachine, mockStreamCounterBuffer); - sendCounterPacket.Start(mockProfilingConnection); + SendCounterPacket sendCounterPacket(mockStreamCounterBuffer); + SendThread sendThread(profilingStateMachine, mockStreamCounterBuffer, sendCounterPacket); + sendThread.Start(mockProfilingConnection); // Interleaving writes and reads to/from the buffer with pauses to test that the send thread actually waits for // something to become available for reading @@ -1854,7 +1814,7 @@ BOOST_AUTO_TEST_CASE(SendThreadTest1) unsigned int streamMetadataPacketsize = 118 + processNameSize; totalWrittenSize += streamMetadataPacketsize; - sendCounterPacket.SetReadyToRead(); + sendThread.SetReadyToRead(); std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_UNTIL_READABLE_MS)); @@ -1864,7 +1824,7 @@ BOOST_AUTO_TEST_CASE(SendThreadTest1) unsigned int counterDirectoryPacketSize = 32; totalWrittenSize += counterDirectoryPacketSize; - sendCounterPacket.SetReadyToRead(); + sendThread.SetReadyToRead(); std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_UNTIL_READABLE_MS)); @@ -1878,7 +1838,7 @@ BOOST_AUTO_TEST_CASE(SendThreadTest1) unsigned int periodicCounterCapturePacketSize = 28; totalWrittenSize += periodicCounterCapturePacketSize; - sendCounterPacket.SetReadyToRead(); + sendThread.SetReadyToRead(); std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_UNTIL_READABLE_MS)); @@ -1916,7 +1876,7 @@ BOOST_AUTO_TEST_CASE(SendThreadTest1) periodicCounterCapturePacketSize = 40; totalWrittenSize += periodicCounterCapturePacketSize; - sendCounterPacket.SetReadyToRead(); + sendThread.SetReadyToRead(); std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_UNTIL_READABLE_MS)); @@ -1926,13 +1886,13 @@ BOOST_AUTO_TEST_CASE(SendThreadTest1) periodicCounterCapturePacketSize = 30; totalWrittenSize += periodicCounterCapturePacketSize; - sendCounterPacket.SetReadyToRead(); + sendThread.SetReadyToRead(); // To test an exact value of the "read size" in the mock buffer, wait to allow the send thread to // read all what's remaining in the buffer std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_UNTIL_READABLE_MS)); - sendCounterPacket.Stop(); + sendThread.Stop(); BOOST_CHECK(mockStreamCounterBuffer.GetCommittedSize() == totalWrittenSize); BOOST_CHECK(mockStreamCounterBuffer.GetReadableSize() == totalWrittenSize); @@ -1948,15 +1908,16 @@ BOOST_AUTO_TEST_CASE(SendThreadTest2) MockProfilingConnection mockProfilingConnection; MockStreamCounterBuffer mockStreamCounterBuffer(1024); - SendCounterPacket sendCounterPacket(profilingStateMachine, mockStreamCounterBuffer); - sendCounterPacket.Start(mockProfilingConnection); + SendCounterPacket sendCounterPacket(mockStreamCounterBuffer); + SendThread sendThread(profilingStateMachine, mockStreamCounterBuffer, sendCounterPacket); + sendThread.Start(mockProfilingConnection); // Adding many spurious "ready to read" signals throughout the test to check that the send thread is // capable of handling unnecessary read requests std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_UNTIL_READABLE_MS)); - sendCounterPacket.SetReadyToRead(); + sendThread.SetReadyToRead(); CounterDirectory counterDirectory; sendCounterPacket.SendStreamMetaDataPacket(); @@ -1967,7 +1928,7 @@ BOOST_AUTO_TEST_CASE(SendThreadTest2) unsigned int streamMetadataPacketsize = 118 + processNameSize; totalWrittenSize += streamMetadataPacketsize; - sendCounterPacket.SetReadyToRead(); + sendThread.SetReadyToRead(); std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_UNTIL_READABLE_MS)); @@ -1977,8 +1938,8 @@ BOOST_AUTO_TEST_CASE(SendThreadTest2) unsigned int counterDirectoryPacketSize = 32; totalWrittenSize += counterDirectoryPacketSize; - sendCounterPacket.SetReadyToRead(); - sendCounterPacket.SetReadyToRead(); + sendThread.SetReadyToRead(); + sendThread.SetReadyToRead(); std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_UNTIL_READABLE_MS)); @@ -1992,17 +1953,17 @@ BOOST_AUTO_TEST_CASE(SendThreadTest2) unsigned int periodicCounterCapturePacketSize = 28; totalWrittenSize += periodicCounterCapturePacketSize; - sendCounterPacket.SetReadyToRead(); + sendThread.SetReadyToRead(); std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_UNTIL_READABLE_MS)); - sendCounterPacket.SetReadyToRead(); - sendCounterPacket.SetReadyToRead(); - sendCounterPacket.SetReadyToRead(); + sendThread.SetReadyToRead(); + sendThread.SetReadyToRead(); + sendThread.SetReadyToRead(); std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_UNTIL_READABLE_MS)); - sendCounterPacket.SetReadyToRead(); + sendThread.SetReadyToRead(); sendCounterPacket.SendPeriodicCounterCapturePacket(44u, { { 211u, 923u } @@ -2025,7 +1986,7 @@ BOOST_AUTO_TEST_CASE(SendThreadTest2) periodicCounterCapturePacketSize = 46; totalWrittenSize += periodicCounterCapturePacketSize; - sendCounterPacket.SetReadyToRead(); + sendThread.SetReadyToRead(); sendCounterPacket.SendPeriodicCounterCapturePacket(997u, { { 88u, 11u }, @@ -2038,8 +1999,8 @@ BOOST_AUTO_TEST_CASE(SendThreadTest2) periodicCounterCapturePacketSize = 40; totalWrittenSize += periodicCounterCapturePacketSize; - sendCounterPacket.SetReadyToRead(); - sendCounterPacket.SetReadyToRead(); + sendThread.SetReadyToRead(); + sendThread.SetReadyToRead(); std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_UNTIL_READABLE_MS)); @@ -2049,11 +2010,11 @@ BOOST_AUTO_TEST_CASE(SendThreadTest2) periodicCounterCapturePacketSize = 30; totalWrittenSize += periodicCounterCapturePacketSize; - sendCounterPacket.SetReadyToRead(); + sendThread.SetReadyToRead(); // To test an exact value of the "read size" in the mock buffer, wait to allow the send thread to // read all what's remaining in the buffer - sendCounterPacket.Stop(); + sendThread.Stop(); BOOST_CHECK(mockStreamCounterBuffer.GetCommittedSize() == totalWrittenSize); BOOST_CHECK(mockStreamCounterBuffer.GetReadableSize() == totalWrittenSize); @@ -2069,12 +2030,13 @@ BOOST_AUTO_TEST_CASE(SendThreadTest3) MockProfilingConnection mockProfilingConnection; MockStreamCounterBuffer mockStreamCounterBuffer(1024); - SendCounterPacket sendCounterPacket(profilingStateMachine, mockStreamCounterBuffer); - sendCounterPacket.Start(mockProfilingConnection); + SendCounterPacket sendCounterPacket(mockStreamCounterBuffer); + SendThread sendThread(profilingStateMachine, mockStreamCounterBuffer, sendCounterPacket); + sendThread.Start(mockProfilingConnection); // Not using pauses or "grace periods" to stress test the send thread - sendCounterPacket.SetReadyToRead(); + sendThread.SetReadyToRead(); CounterDirectory counterDirectory; sendCounterPacket.SendStreamMetaDataPacket(); @@ -2085,15 +2047,15 @@ BOOST_AUTO_TEST_CASE(SendThreadTest3) unsigned int streamMetadataPacketsize = 118 + processNameSize; totalWrittenSize += streamMetadataPacketsize; - sendCounterPacket.SetReadyToRead(); + sendThread.SetReadyToRead(); sendCounterPacket.SendCounterDirectoryPacket(counterDirectory); // Get the size of the Counter Directory Packet unsigned int counterDirectoryPacketSize =32; totalWrittenSize += counterDirectoryPacketSize; - sendCounterPacket.SetReadyToRead(); - sendCounterPacket.SetReadyToRead(); + sendThread.SetReadyToRead(); + sendThread.SetReadyToRead(); sendCounterPacket.SendPeriodicCounterCapturePacket(123u, { { 1u, 23u }, @@ -2104,11 +2066,11 @@ BOOST_AUTO_TEST_CASE(SendThreadTest3) unsigned int periodicCounterCapturePacketSize = 28; totalWrittenSize += periodicCounterCapturePacketSize; - sendCounterPacket.SetReadyToRead(); - sendCounterPacket.SetReadyToRead(); - sendCounterPacket.SetReadyToRead(); - sendCounterPacket.SetReadyToRead(); - sendCounterPacket.SetReadyToRead(); + sendThread.SetReadyToRead(); + sendThread.SetReadyToRead(); + sendThread.SetReadyToRead(); + sendThread.SetReadyToRead(); + sendThread.SetReadyToRead(); sendCounterPacket.SendPeriodicCounterCapturePacket(44u, { { 211u, 923u } @@ -2131,8 +2093,8 @@ BOOST_AUTO_TEST_CASE(SendThreadTest3) periodicCounterCapturePacketSize = 46; totalWrittenSize += periodicCounterCapturePacketSize; - sendCounterPacket.SetReadyToRead(); - sendCounterPacket.SetReadyToRead(); + sendThread.SetReadyToRead(); + sendThread.SetReadyToRead(); sendCounterPacket.SendPeriodicCounterCapturePacket(997u, { { 88u, 11u }, @@ -2145,19 +2107,19 @@ BOOST_AUTO_TEST_CASE(SendThreadTest3) periodicCounterCapturePacketSize = 40; totalWrittenSize += periodicCounterCapturePacketSize; - sendCounterPacket.SetReadyToRead(); - sendCounterPacket.SetReadyToRead(); + sendThread.SetReadyToRead(); + sendThread.SetReadyToRead(); sendCounterPacket.SendPeriodicCounterSelectionPacket(1000u, { 1345u, 254u, 4536u, 408u, 54u, 6323u, 428u, 1u, 6u }); // Get the size of the Periodic Counter Capture Packet periodicCounterCapturePacketSize = 30; totalWrittenSize += periodicCounterCapturePacketSize; - sendCounterPacket.SetReadyToRead(); + sendThread.SetReadyToRead(); // Abruptly terminating the send thread, the amount of data sent may be less that the amount written (the send // thread is not guaranteed to flush the buffer) - sendCounterPacket.Stop(); + sendThread.Stop(); BOOST_CHECK(mockStreamCounterBuffer.GetCommittedSize() == totalWrittenSize); BOOST_CHECK(mockStreamCounterBuffer.GetReadableSize() <= totalWrittenSize); @@ -2166,98 +2128,40 @@ BOOST_AUTO_TEST_CASE(SendThreadTest3) BOOST_CHECK(mockStreamCounterBuffer.GetReadSize() <= mockStreamCounterBuffer.GetCommittedSize()); } -BOOST_AUTO_TEST_CASE(SendThreadBufferTest) +BOOST_AUTO_TEST_CASE(SendCounterPacketTestWithSendThread) { ProfilingStateMachine profilingStateMachine; - SetActiveProfilingState(profilingStateMachine); + SetWaitingForAckProfilingState(profilingStateMachine); MockProfilingConnection mockProfilingConnection; BufferManager bufferManager(1, 1024); - SendCounterPacket sendCounterPacket(profilingStateMachine, bufferManager, -1); - sendCounterPacket.Start(mockProfilingConnection); - - // Interleaving writes and reads to/from the buffer with pauses to test that the send thread actually waits for - // something to become available for reading - std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_UNTIL_READABLE_MS)); - - // SendStreamMetaDataPacket - sendCounterPacket.SendStreamMetaDataPacket(); - - // Read data from the buffer - // Buffer should become readable after commit by SendStreamMetaDataPacket - auto packetBuffer = bufferManager.GetReadableBuffer(); - BOOST_TEST(packetBuffer.get()); + SendCounterPacket sendCounterPacket(bufferManager); + SendThread sendThread(profilingStateMachine, bufferManager, sendCounterPacket, -1); + sendThread.Start(mockProfilingConnection); std::string processName = GetProcessName().substr(0, 60); unsigned int processNameSize = processName.empty() ? 0 : boost::numeric_cast(processName.size()) + 1; unsigned int streamMetadataPacketsize = 118 + processNameSize; - BOOST_TEST(packetBuffer->GetSize() == streamMetadataPacketsize); - - // Buffer is not available when SendStreamMetaDataPacket already occupied the buffer. - unsigned int reservedSize = 0; - auto reservedBuffer = bufferManager.Reserve(512, reservedSize); - BOOST_TEST(!reservedBuffer.get()); - // Recommit to be read by sendCounterPacket - bufferManager.Commit(packetBuffer, streamMetadataPacketsize); - - sendCounterPacket.SetReadyToRead(); - - // Join the send thread to make sure it has read the buffer - sendCounterPacket.Stop(); - sendCounterPacket.Start(mockProfilingConnection); + sendThread.Stop(); - // The buffer is read by the send thread so it should not be in the readable buffer. - auto readBuffer = bufferManager.GetReadableBuffer(); - BOOST_TEST(!readBuffer); + // check for packet in ProfilingConnection + BOOST_CHECK(mockProfilingConnection.CheckForPacket({PacketType::StreamMetaData, streamMetadataPacketsize}) == 1); - // Successfully reserved the buffer with requested size - reservedBuffer = bufferManager.Reserve(512, reservedSize); - BOOST_TEST(reservedSize == 512); - BOOST_TEST(reservedBuffer.get()); - - // Release the buffer to be used by sendCounterPacket - bufferManager.Release(reservedBuffer); + SetActiveProfilingState(profilingStateMachine); + sendThread.Start(mockProfilingConnection); // SendCounterDirectoryPacket CounterDirectory counterDirectory; sendCounterPacket.SendCounterDirectoryPacket(counterDirectory); - // Read data from the buffer - // Buffer should become readable after commit by SendCounterDirectoryPacket - auto counterDirectoryPacketBuffer = bufferManager.GetReadableBuffer(); - BOOST_TEST(counterDirectoryPacketBuffer.get()); - - // Get the size of the Counter Directory Packet + sendThread.Stop(); unsigned int counterDirectoryPacketSize = 32; - BOOST_TEST(counterDirectoryPacketBuffer->GetSize() == counterDirectoryPacketSize); - - // Buffer is not available when SendCounterDirectoryPacket already occupied the buffer. - reservedSize = 0; - reservedBuffer = bufferManager.Reserve(512, reservedSize); - BOOST_TEST(reservedSize == 0); - BOOST_TEST(!reservedBuffer.get()); - - // Recommit to be read by sendCounterPacket - bufferManager.Commit(counterDirectoryPacketBuffer, counterDirectoryPacketSize); - - sendCounterPacket.SetReadyToRead(); + // check for packet in ProfilingConnection + BOOST_CHECK(mockProfilingConnection.CheckForPacket( + {PacketType::CounterDirectory, counterDirectoryPacketSize}) == 1); - // Join the send thread to make sure it has read the buffer - sendCounterPacket.Stop(); - sendCounterPacket.Start(mockProfilingConnection); - - // The buffer is read by the send thread so it should not be in the readable buffer. - readBuffer = bufferManager.GetReadableBuffer(); - BOOST_TEST(!readBuffer); - - // Successfully reserved the buffer with requested size - reservedBuffer = bufferManager.Reserve(512, reservedSize); - BOOST_TEST(reservedSize == 512); - BOOST_TEST(reservedBuffer.get()); - - // Release the buffer to be used by sendCounterPacket - bufferManager.Release(reservedBuffer); + sendThread.Start(mockProfilingConnection); // SendPeriodicCounterCapturePacket sendCounterPacket.SendPeriodicCounterCapturePacket(123u, @@ -2266,51 +2170,23 @@ BOOST_AUTO_TEST_CASE(SendThreadBufferTest) { 33u, 1207623u } }); - // Read data from the buffer - // Buffer should become readable after commit by SendPeriodicCounterCapturePacket - auto periodicCounterCapturePacketBuffer = bufferManager.GetReadableBuffer(); - BOOST_TEST(periodicCounterCapturePacketBuffer.get()); + sendThread.Stop(); - // Get the size of the Periodic Counter Capture Packet unsigned int periodicCounterCapturePacketSize = 28; - BOOST_TEST(periodicCounterCapturePacketBuffer->GetSize() == periodicCounterCapturePacketSize); - - // Buffer is not available when SendPeriodicCounterCapturePacket already occupied the buffer. - reservedSize = 0; - reservedBuffer = bufferManager.Reserve(512, reservedSize); - BOOST_TEST(reservedSize == 0); - BOOST_TEST(!reservedBuffer.get()); - - // Recommit to be read by sendCounterPacket - bufferManager.Commit(periodicCounterCapturePacketBuffer, periodicCounterCapturePacketSize); - - sendCounterPacket.SetReadyToRead(); - - // Join the send thread to make sure it has read the buffer - sendCounterPacket.Stop(); - sendCounterPacket.Start(mockProfilingConnection); - - // The buffer is read by the send thread so it should not be in the readable buffer. - readBuffer = bufferManager.GetReadableBuffer(); - BOOST_TEST(!readBuffer); - - // Successfully reserved the buffer with requested size - reservedBuffer = bufferManager.Reserve(512, reservedSize); - BOOST_TEST(reservedSize == 512); - BOOST_TEST(reservedBuffer.get()); - - sendCounterPacket.Stop(); + BOOST_CHECK(mockProfilingConnection.CheckForPacket( + {PacketType::PeriodicCounterCapture, periodicCounterCapturePacketSize}) == 1); } -BOOST_AUTO_TEST_CASE(SendThreadBufferTest1) +BOOST_AUTO_TEST_CASE(SendThreadBufferTest) { ProfilingStateMachine profilingStateMachine; SetActiveProfilingState(profilingStateMachine); MockProfilingConnection mockProfilingConnection; BufferManager bufferManager(3, 1024); - SendCounterPacket sendCounterPacket(profilingStateMachine, bufferManager, -1); - sendCounterPacket.Start(mockProfilingConnection); + SendCounterPacket sendCounterPacket(bufferManager); + SendThread sendThread(profilingStateMachine, bufferManager, sendCounterPacket, -1); + sendThread.Start(mockProfilingConnection); // SendStreamMetaDataPacket sendCounterPacket.SendStreamMetaDataPacket(); @@ -2328,14 +2204,10 @@ BOOST_AUTO_TEST_CASE(SendThreadBufferTest1) // Recommit to be read by sendCounterPacket bufferManager.Commit(packetBuffer, streamMetadataPacketsize); - sendCounterPacket.SetReadyToRead(); - // SendCounterDirectoryPacket CounterDirectory counterDirectory; sendCounterPacket.SendCounterDirectoryPacket(counterDirectory); - sendCounterPacket.SetReadyToRead(); - // SendPeriodicCounterCapturePacket sendCounterPacket.SendPeriodicCounterCapturePacket(123u, { @@ -2343,9 +2215,7 @@ BOOST_AUTO_TEST_CASE(SendThreadBufferTest1) { 33u, 1207623u } }); - sendCounterPacket.SetReadyToRead(); - - sendCounterPacket.Stop(); + sendThread.Stop(); // The buffer is read by the send thread so it should not be in the readable buffer. auto readBuffer = bufferManager.GetReadableBuffer(); @@ -2374,11 +2244,12 @@ BOOST_AUTO_TEST_CASE(SendThreadSendStreamMetadataPacket1) MockProfilingConnection mockProfilingConnection; BufferManager bufferManager(3, 1024); - SendCounterPacket sendCounterPacket(profilingStateMachine, bufferManager); - sendCounterPacket.Start(mockProfilingConnection); + SendCounterPacket sendCounterPacket(bufferManager); + SendThread sendThread(profilingStateMachine, bufferManager, sendCounterPacket); + sendThread.Start(mockProfilingConnection); // The profiling state is set to "Uninitialized", so the send thread should throw an exception - BOOST_CHECK_THROW(sendCounterPacket.Stop(), armnn::RuntimeException); + BOOST_CHECK_THROW(sendThread.Stop(), armnn::RuntimeException); } BOOST_AUTO_TEST_CASE(SendThreadSendStreamMetadataPacket2) @@ -2388,11 +2259,12 @@ BOOST_AUTO_TEST_CASE(SendThreadSendStreamMetadataPacket2) MockProfilingConnection mockProfilingConnection; BufferManager bufferManager(3, 1024); - SendCounterPacket sendCounterPacket(profilingStateMachine, bufferManager); - sendCounterPacket.Start(mockProfilingConnection); + SendCounterPacket sendCounterPacket(bufferManager); + SendThread sendThread(profilingStateMachine, bufferManager, sendCounterPacket); + sendThread.Start(mockProfilingConnection); // The profiling state is set to "NotConnected", so the send thread should throw an exception - BOOST_CHECK_THROW(sendCounterPacket.Stop(), armnn::RuntimeException); + BOOST_CHECK_THROW(sendThread.Stop(), armnn::RuntimeException); } BOOST_AUTO_TEST_CASE(SendThreadSendStreamMetadataPacket3) @@ -2407,12 +2279,13 @@ BOOST_AUTO_TEST_CASE(SendThreadSendStreamMetadataPacket3) MockProfilingConnection mockProfilingConnection; BufferManager bufferManager(3, 1024); - SendCounterPacket sendCounterPacket(profilingStateMachine, bufferManager); - sendCounterPacket.Start(mockProfilingConnection); + SendCounterPacket sendCounterPacket(bufferManager); + SendThread sendThread(profilingStateMachine, bufferManager, sendCounterPacket); + sendThread.Start(mockProfilingConnection); // The profiling state is set to "WaitingForAck", so the send thread should send a Stream Metadata packet - // Wait for sendCounterPacket to join - BOOST_CHECK_NO_THROW(sendCounterPacket.Stop()); + // Wait for sendThread to join + BOOST_CHECK_NO_THROW(sendThread.Stop()); // Check that the buffer contains at least one Stream Metadata packet and no other packets const auto writtenDataSize = mockProfilingConnection.GetWrittenDataSize(); @@ -2434,14 +2307,15 @@ BOOST_AUTO_TEST_CASE(SendThreadSendStreamMetadataPacket4) MockProfilingConnection mockProfilingConnection; BufferManager bufferManager(3, 1024); - SendCounterPacket sendCounterPacket(profilingStateMachine, bufferManager); - sendCounterPacket.Start(mockProfilingConnection); + SendCounterPacket sendCounterPacket(bufferManager); + SendThread sendThread(profilingStateMachine, bufferManager, sendCounterPacket); + sendThread.Start(mockProfilingConnection); // The profiling state is set to "WaitingForAck", so the send thread should send a Stream Metadata packet - // Wait for sendCounterPacket to join - sendCounterPacket.Stop(); + // Wait for sendThread to join + sendThread.Stop(); - sendCounterPacket.Start(mockProfilingConnection); + sendThread.Start(mockProfilingConnection); // Check that the profiling state is still "WaitingForAck" BOOST_TEST((profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck)); @@ -2450,14 +2324,14 @@ BOOST_AUTO_TEST_CASE(SendThreadSendStreamMetadataPacket4) mockProfilingConnection.Clear(); - sendCounterPacket.Stop(); - sendCounterPacket.Start(mockProfilingConnection); + sendThread.Stop(); + sendThread.Start(mockProfilingConnection); // Try triggering a new buffer read - sendCounterPacket.SetReadyToRead(); + sendThread.SetReadyToRead(); - // Wait for sendCounterPacket to join - BOOST_CHECK_NO_THROW(sendCounterPacket.Stop()); + // Wait for sendThread to join + BOOST_CHECK_NO_THROW(sendThread.Stop()); // Check that the profiling state is still "WaitingForAck" BOOST_TEST((profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck)); diff --git a/src/profiling/test/SendCounterPacketTests.hpp b/src/profiling/test/SendCounterPacketTests.hpp index 4d517e5829..c7fc7b84ac 100644 --- a/src/profiling/test/SendCounterPacketTests.hpp +++ b/src/profiling/test/SendCounterPacketTests.hpp @@ -6,6 +6,7 @@ #pragma once #include +#include #include #include @@ -17,6 +18,11 @@ #include #include +#include +#include +#include +#include + namespace armnn { @@ -210,10 +216,15 @@ public: return std::move(m_Buffer); } - void Commit(IPacketBufferPtr& packetBuffer, unsigned int size) override + void Commit(IPacketBufferPtr& packetBuffer, unsigned int size, bool notifyConsumer = true) override { packetBuffer->Commit(size); m_Buffer = std::move(packetBuffer); + + if (notifyConsumer) + { + FlushReadList(); + } } IPacketBufferPtr GetReadableBuffer() override @@ -233,9 +244,27 @@ public: m_Buffer = std::move(packetBuffer); } + void SetConsumer(IConsumer* consumer) override + { + if (consumer != nullptr) + { + m_Consumer = consumer; + } + } + + void FlushReadList() override + { + // notify consumer that packet is ready to read + if (m_Consumer != nullptr) + { + m_Consumer->SetReadyToRead(); + } + } + private: unsigned int m_BufferSize; IPacketBufferPtr m_Buffer; + IConsumer* m_Consumer = nullptr; }; class MockStreamCounterBuffer : public IBufferManager @@ -264,13 +293,18 @@ public: return std::make_unique(requestedSize); } - void Commit(IPacketBufferPtr& packetBuffer, unsigned int size) override + void Commit(IPacketBufferPtr& packetBuffer, unsigned int size, bool notifyConsumer = true) override { std::lock_guard lock(m_Mutex); packetBuffer->Commit(size); m_BufferList.push_back(std::move(packetBuffer)); m_CommittedSize += size; + + if (notifyConsumer) + { + FlushReadList(); + } } void Release(IPacketBufferPtr& packetBuffer) override @@ -302,6 +336,23 @@ public: packetBuffer->MarkRead(); } + void SetConsumer(IConsumer* consumer) override + { + if (consumer != nullptr) + { + m_Consumer = consumer; + } + } + + void FlushReadList() override + { + // notify consumer that packet is ready to read + if (m_Consumer != nullptr) + { + m_Consumer->SetReadyToRead(); + } + } + unsigned int GetCommittedSize() const { return m_CommittedSize; } unsigned int GetReadableSize() const { return m_ReadableSize; } unsigned int GetReadSize() const { return m_ReadSize; } @@ -324,6 +375,9 @@ private: // The total size of the buffers that has already been read unsigned int m_ReadSize; + + // Consumer thread to notify packet is ready to read + IConsumer* m_Consumer = nullptr; }; class MockSendCounterPacket : public ISendCounterPacket @@ -337,7 +391,7 @@ public: unsigned int reserved = 0; IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved); memcpy(buffer->GetWritableData(), message.c_str(), static_cast(message.size()) + 1); - m_BufferManager.Commit(buffer, reserved); + m_BufferManager.Commit(buffer, reserved, false); } void SendCounterDirectoryPacket(const ICounterDirectory& counterDirectory) override @@ -375,8 +429,6 @@ public: m_BufferManager.Commit(buffer, reserved); } - void SetReadyToRead() override {} - private: IBufferManager& m_BufferManager; }; @@ -596,8 +648,8 @@ private: class SendCounterPacketTest : public SendCounterPacket { public: - SendCounterPacketTest(ProfilingStateMachine& profilingStateMachine, IBufferManager& buffer) - : SendCounterPacket(profilingStateMachine, buffer) + SendCounterPacketTest(IBufferManager& buffer) + : SendCounterPacket(buffer) {} bool CreateDeviceRecordTest(const DevicePtr& device, -- cgit v1.2.1