diff options
-rw-r--r-- | src/profiling/IPeriodicCounterCapture.hpp | 6 | ||||
-rw-r--r-- | src/profiling/IProfilingConnection.hpp | 2 | ||||
-rw-r--r-- | src/profiling/Packet.hpp | 15 | ||||
-rw-r--r-- | src/profiling/SendCounterPacket.cpp | 74 | ||||
-rw-r--r-- | src/profiling/SendCounterPacket.hpp | 31 | ||||
-rw-r--r-- | src/profiling/SocketProfilingConnection.cpp | 120 | ||||
-rw-r--r-- | src/profiling/SocketProfilingConnection.hpp | 3 | ||||
-rw-r--r-- | src/profiling/test/ProfilingTests.cpp | 15 | ||||
-rw-r--r-- | src/profiling/test/SendCounterPacketTests.cpp | 438 | ||||
-rw-r--r-- | src/profiling/test/SendCounterPacketTests.hpp | 157 |
10 files changed, 738 insertions, 123 deletions
diff --git a/src/profiling/IPeriodicCounterCapture.hpp b/src/profiling/IPeriodicCounterCapture.hpp index 93df9d2541..edc034d20e 100644 --- a/src/profiling/IPeriodicCounterCapture.hpp +++ b/src/profiling/IPeriodicCounterCapture.hpp @@ -7,15 +7,15 @@ namespace armnn { - namespace profiling { + class IPeriodicCounterCapture { public: virtual void Start() = 0; - virtual ~IPeriodicCounterCapture() {}; + virtual ~IPeriodicCounterCapture() {} }; } // namespace profiling -} // namespace armnn
\ No newline at end of file +} // namespace armnn diff --git a/src/profiling/IProfilingConnection.hpp b/src/profiling/IProfilingConnection.hpp index 160a4fa31d..97f7b55477 100644 --- a/src/profiling/IProfilingConnection.hpp +++ b/src/profiling/IProfilingConnection.hpp @@ -24,7 +24,7 @@ public: virtual void Close() = 0; - virtual bool WritePacket(const char* buffer, uint32_t length) = 0; + virtual bool WritePacket(const unsigned char* buffer, uint32_t length) = 0; virtual Packet ReadPacket(uint32_t timeout) = 0; }; diff --git a/src/profiling/Packet.hpp b/src/profiling/Packet.hpp index 1e047a6511..7d70a48366 100644 --- a/src/profiling/Packet.hpp +++ b/src/profiling/Packet.hpp @@ -17,8 +17,16 @@ namespace profiling class Packet { public: + Packet() + : m_Header(0) + , m_Length(0) + , m_Data(nullptr) + {} + Packet(uint32_t header, uint32_t length, std::unique_ptr<char[]>& data) - : m_Header(header), m_Length(length), m_Data(std::move(data)) + : m_Header(header) + , m_Length(length) + , m_Data(std::move(data)) { m_PacketId = ((header >> 16) & 1023); m_PacketFamily = (header >> 26); @@ -34,7 +42,8 @@ public: m_PacketFamily(other.m_PacketFamily), m_PacketId(other.m_PacketId), m_Length(other.m_Length), - m_Data(std::move(other.m_Data)){}; + m_Data(std::move(other.m_Data)) + {} Packet(const Packet& other) = delete; Packet& operator=(const Packet&) = delete; @@ -48,6 +57,8 @@ public: uint32_t GetPacketClass() const; uint32_t GetPacketType() const; + bool IsEmpty() { return m_Header == 0 && m_Length == 0; } + private: uint32_t m_Header; uint32_t m_PacketFamily; diff --git a/src/profiling/SendCounterPacket.cpp b/src/profiling/SendCounterPacket.cpp index b41f53ca24..b222270546 100644 --- a/src/profiling/SendCounterPacket.cpp +++ b/src/profiling/SendCounterPacket.cpp @@ -919,7 +919,79 @@ void SendCounterPacket::SendPeriodicCounterSelectionPacket(uint32_t capturePerio void SendCounterPacket::SetReadyToRead() { - m_ReadyToRead = true; + // Signal the send thread that there's something to read in the buffer + m_WaitCondition.notify_one(); +} + +void SendCounterPacket::Start() +{ + // Check is the send thread is already running + if (m_IsRunning.load()) + { + // The send thread is already running + return; + } + + // Mark the send thread as running + m_IsRunning.store(true); + + // Keep the send procedure going until the the send thread is signalled to stop + m_KeepRunning.store(true); + + // Start the send thread + m_SendThread = std::thread(&SendCounterPacket::Send, this); +} + +void SendCounterPacket::Stop() +{ + // Signal the send thread to stop + m_KeepRunning.store(false); + + // Check that the send thread is running + if (m_SendThread.joinable()) + { + // Kick the send thread out of the wait condition + m_WaitCondition.notify_one(); + + // Wait for the send thread to complete operations + m_SendThread.join(); + } +} + +void SendCounterPacket::Send() +{ + // Keep the sending procedure looping until the thread is signalled to stop + while (m_KeepRunning.load()) + { + // Wait condition lock scope - Begin + { + // Lock the mutex to wait on it + std::unique_lock<std::mutex> lock(m_WaitMutex); + + // Wait until the thread is notified of something to read from the buffer, or check anyway after a second + m_WaitCondition.wait_for(lock, std::chrono::seconds(1)); + } + // Wait condition lock scope - End + + // Get the data to send from the buffer + unsigned int readBufferSize = 0; + const unsigned char* readBuffer = m_Buffer.GetReadBuffer(readBufferSize); + if (readBuffer == nullptr || readBufferSize == 0) + { + // Nothing to send, ignore and continue + continue; + } + + // Check that the profiling connection is open, silently drop the data and continue if it's closed + if (m_ProfilingConnection.IsOpen()) + { + // Write a packet to the profiling connection. Silently ignore any write error and continue + m_ProfilingConnection.WritePacket(readBuffer, boost::numeric_cast<uint32_t>(readBufferSize)); + } + } + + // Mark the send thread as not running + m_IsRunning.store(false); } } // namespace profiling diff --git a/src/profiling/SendCounterPacket.hpp b/src/profiling/SendCounterPacket.hpp index 2a2d5d4313..8dd44ecd81 100644 --- a/src/profiling/SendCounterPacket.hpp +++ b/src/profiling/SendCounterPacket.hpp @@ -7,7 +7,13 @@ #include "IBufferWrapper.hpp" #include "ISendCounterPacket.hpp" -#include "CounterDirectory.hpp" +#include "ICounterDirectory.hpp" +#include "IProfilingConnection.hpp" + +#include <atomic> +#include <mutex> +#include <thread> +#include <condition_variable> namespace armnn { @@ -25,10 +31,13 @@ public: using IndexValuePairsVector = std::vector<std::pair<uint16_t, uint32_t>>; - SendCounterPacket(IBufferWrapper& buffer) - : m_Buffer(buffer), - m_ReadyToRead(false) + SendCounterPacket(IProfilingConnection& profilingConnection, IBufferWrapper& buffer) + : m_ProfilingConnection(profilingConnection) + , m_Buffer(buffer) + , m_IsRunning(false) + , m_KeepRunning(false) {} + ~SendCounterPacket() { Stop(); } void SendStreamMetaDataPacket() override; @@ -44,7 +53,13 @@ public: static const unsigned int PIPE_MAGIC = 0x45495434; static const unsigned int MAX_METADATA_PACKET_LENGTH = 4096; + void Start(); + void Stop(); + bool IsRunning() { return m_IsRunning.load(); } + private: + void Send(); + template <typename ExceptionType> void CancelOperationAndThrow(const std::string& errorMessage) { @@ -55,8 +70,13 @@ private: throw ExceptionType(errorMessage); } + IProfilingConnection& m_ProfilingConnection; IBufferWrapper& m_Buffer; - bool m_ReadyToRead; + std::mutex m_WaitMutex; + std::condition_variable m_WaitCondition; + std::thread m_SendThread; + std::atomic<bool> m_IsRunning; + std::atomic<bool> m_KeepRunning; protected: // Helper methods, protected for testing @@ -78,4 +98,3 @@ protected: } // namespace profiling } // namespace armnn - diff --git a/src/profiling/SocketProfilingConnection.cpp b/src/profiling/SocketProfilingConnection.cpp index 91d57cc9bd..47fc62f7f0 100644 --- a/src/profiling/SocketProfilingConnection.cpp +++ b/src/profiling/SocketProfilingConnection.cpp @@ -23,7 +23,7 @@ SocketProfilingConnection::SocketProfilingConnection() m_Socket[0].fd = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0); if (m_Socket[0].fd == -1) { - throw armnn::Exception(std::string(": Socket construction failed: ") + strerror(errno)); + throw armnn::RuntimeException(std::string("Socket construction failed: ") + strerror(errno)); } // Connect to the named unix domain socket. @@ -35,7 +35,7 @@ SocketProfilingConnection::SocketProfilingConnection() if (0 != connect(m_Socket[0].fd, reinterpret_cast<const sockaddr*>(&server), sizeof(sockaddr_un))) { close(m_Socket[0].fd); - throw armnn::Exception(std::string(": Cannot connect to stream socket: ") + strerror(errno)); + throw armnn::RuntimeException(std::string("Cannot connect to stream socket: ") + strerror(errno)); } // Our socket will only be interested in polling reads. @@ -46,96 +46,92 @@ SocketProfilingConnection::SocketProfilingConnection() if (0 != fcntl(m_Socket[0].fd, F_SETFL, currentFlags | O_NONBLOCK)) { close(m_Socket[0].fd); - throw armnn::Exception(std::string(": Failed to set socket as non blocking: ") + strerror(errno)); + throw armnn::RuntimeException(std::string("Failed to set socket as non blocking: ") + strerror(errno)); } } bool SocketProfilingConnection::IsOpen() { - if (m_Socket[0].fd > 0) - { - return true; - } - return false; + return m_Socket[0].fd > 0; } void SocketProfilingConnection::Close() { - if (0 == close(m_Socket[0].fd)) - { - memset(m_Socket, 0, sizeof(m_Socket)); - } - else + if (close(m_Socket[0].fd) != 0) { - throw armnn::Exception(std::string(": Cannot close stream socket: ") + strerror(errno)); + throw armnn::RuntimeException(std::string("Cannot close stream socket: ") + strerror(errno)); } + + memset(m_Socket, 0, sizeof(m_Socket)); } -bool SocketProfilingConnection::WritePacket(const char* buffer, uint32_t length) +bool SocketProfilingConnection::WritePacket(const unsigned char* buffer, uint32_t length) { - if (-1 == write(m_Socket[0].fd, buffer, length)) + if (buffer == nullptr || length == 0) { return false; } - return true; + + return write(m_Socket[0].fd, buffer, length) != -1; } Packet SocketProfilingConnection::ReadPacket(uint32_t timeout) { - // Poll for data on the socket or until timeout. + // Poll for data on the socket or until timeout occurs int pollResult = poll(m_Socket, 1, static_cast<int>(timeout)); - if (pollResult > 0) + + switch (pollResult) { - // Normal poll return but it could still contain an error signal. + case -1: // Error + throw armnn::RuntimeException(std::string("Read failure from socket: ") + strerror(errno)); + + case 0: // Timeout + throw armnn::RuntimeException("Timeout while reading from socket"); + + default: // Normal poll return but it could still contain an error signal + + // Check if the socket reported an error if (m_Socket[0].revents & (POLLNVAL | POLLERR | POLLHUP)) { - throw armnn::Exception(std::string(": Read failure from socket: ") + strerror(errno)); + throw armnn::Exception(std::string("Socket 0 reported an error: ") + strerror(errno)); + } + + // Check if there is data to read + if (!(m_Socket[0].revents & (POLLIN))) + { + // No data to read from the socket. Silently ignore and continue + return Packet(); } - else if (m_Socket[0].revents & (POLLIN)) // There is data to read. + + // There is data to read, read the header first + char header[8] = {}; + if (8 != recv(m_Socket[0].fd, &header, sizeof(header), 0)) { - // Read the header first. - char header[8]; - if (8 != recv(m_Socket[0].fd, &header, sizeof header, 0)) - { - // What do we do here if there's not a valid 8 byte header to read? - throw armnn::Exception(": Received packet did not contains a valid MIPE header. "); - } - // stream_metadata_identifier is the first 4 bytes. - uint32_t metadataIdentifier = static_cast<uint32_t>(header[0]) << 24 | - static_cast<uint32_t>(header[1]) << 16 | - static_cast<uint32_t>(header[2]) << 8 | - static_cast<uint32_t>(header[3]); - // data_length is the next 4 bytes. - uint32_t dataLength = static_cast<uint32_t>(header[4]) << 24 | - static_cast<uint32_t>(header[5]) << 16 | - static_cast<uint32_t>(header[6]) << 8 | - static_cast<uint32_t>(header[7]); - - std::unique_ptr<char[]> packetData; - if (dataLength > 0) - { - packetData = std::make_unique<char[]>(dataLength); - } - - if (dataLength != recv(m_Socket[0].fd, packetData.get(), dataLength, 0)) - { - // What do we do here if we can't read in a full packet? - throw armnn::Exception(": Invalid MIPE packet."); - } - return {metadataIdentifier, dataLength, packetData}; + // What do we do here if there's not a valid 8 byte header to read? + throw armnn::RuntimeException("The received packet did not contains a valid MIPE header"); } - else // Some unknown return signal. + + // stream_metadata_identifier is the first 4 bytes + uint32_t metadataIdentifier = 0; + std::memcpy(&metadataIdentifier, header, sizeof(metadataIdentifier)); + + // data_length is the next 4 bytes + uint32_t dataLength = 0; + std::memcpy(&dataLength, header + 4u, sizeof(dataLength)); + + std::unique_ptr<char[]> packetData; + if (dataLength > 0) { - throw armnn::Exception(": Poll returned an unexpected event." ); + packetData = std::make_unique<char[]>(dataLength); } - } - else if (pollResult == -1) - { - throw armnn::Exception(std::string(": Read failure from socket: ") + strerror(errno)); - } - else // it's 0 so a timeout. - { - throw armnn::TimeoutException(": Timeout while reading from socket."); + + if (dataLength != recv(m_Socket[0].fd, packetData.get(), dataLength, 0)) + { + // What do we do here if we can't read in a full packet? + throw armnn::RuntimeException("Invalid MIPE packet"); + } + + return Packet(metadataIdentifier, dataLength, packetData); } } diff --git a/src/profiling/SocketProfilingConnection.hpp b/src/profiling/SocketProfilingConnection.hpp index f58e1f4d5d..1ae9f17f7e 100644 --- a/src/profiling/SocketProfilingConnection.hpp +++ b/src/profiling/SocketProfilingConnection.hpp @@ -21,8 +21,9 @@ public: SocketProfilingConnection(); bool IsOpen() final; void Close() final; - bool WritePacket(const char* buffer, uint32_t length) final; + bool WritePacket(const unsigned char* buffer, uint32_t length) final; Packet ReadPacket(uint32_t timeout) final; + private: // To indicate we want to use an abstract UDS ensure the first character of the address is 0. const char* m_GatorNamespace = "\0gatord_namespace"; diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index b90d469424..1741160f96 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -36,6 +36,7 @@ #include <map> #include <random> #include <thread> +#include <chrono> BOOST_AUTO_TEST_SUITE(ExternalProfiling) @@ -101,7 +102,7 @@ public: void Close(){} - bool WritePacket(const char* buffer, uint32_t length) + bool WritePacket(const unsigned char* buffer, uint32_t length) { return false; } @@ -1754,8 +1755,9 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData) uint32_t version = 1; Holder holder; TestCaptureThread captureThread; + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer(512); - SendCounterPacket sendCounterPacket(mockBuffer); + SendCounterPacket sendCounterPacket(mockProfilingConnection, mockBuffer); uint32_t sizeOfUint32 = numeric_cast<uint32_t>(sizeof(uint32_t)); uint32_t sizeOfUint16 = numeric_cast<uint32_t>(sizeof(uint16_t)); @@ -2112,8 +2114,9 @@ BOOST_AUTO_TEST_CASE(CheckPeriodicCounterCaptureThread) std::vector<uint16_t> captureIds1 = { 0, 1 }; std::vector<uint16_t> captureIds2; + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer(512); - SendCounterPacket sendCounterPacket(mockBuffer); + SendCounterPacket sendCounterPacket(mockProfilingConnection, mockBuffer); std::vector<uint16_t> counterIds; CaptureReader captureReader; @@ -2183,8 +2186,9 @@ BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest0) Packet packetA(packetId, 0, packetData); + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer(1024); - SendCounterPacket sendCounterPacket(mockBuffer); + SendCounterPacket sendCounterPacket(mockProfilingConnection, mockBuffer); CounterDirectory counterDirectory; @@ -2217,8 +2221,9 @@ BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest1) Packet packetA(packetId, 0, packetData); + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer(1024); - SendCounterPacket sendCounterPacket(mockBuffer); + SendCounterPacket sendCounterPacket(mockProfilingConnection, mockBuffer); CounterDirectory counterDirectory; const Device* device = counterDirectory.RegisterDevice("deviceA", 1); diff --git a/src/profiling/test/SendCounterPacketTests.cpp b/src/profiling/test/SendCounterPacketTests.cpp index 90bc9225a0..3dda2e7b37 100644 --- a/src/profiling/test/SendCounterPacketTests.cpp +++ b/src/profiling/test/SendCounterPacketTests.cpp @@ -8,6 +8,7 @@ #include <EncodeVersion.hpp> #include <ProfilingUtils.hpp> #include <SendCounterPacket.hpp> +#include <CounterDirectory.hpp> #include <armnn/Exceptions.hpp> #include <armnn/Conversion.hpp> @@ -16,10 +17,22 @@ #include <boost/numeric/conversion/cast.hpp> #include <chrono> -#include <iostream> using namespace armnn::profiling; +size_t GetDataLength(const MockStreamCounterBuffer& mockStreamCounterBuffer, size_t packetOffset) +{ + // The data length is the written in the second byte + return ReadUint32(mockStreamCounterBuffer.GetBuffer(), + boost::numeric_cast<unsigned int>(packetOffset + sizeof(uint32_t))); +} + +size_t GetPacketSize(const MockStreamCounterBuffer& mockStreamCounterBuffer, size_t packetOffset) +{ + // The packet size is the data length plus the size of the packet header (always two words big) + return GetDataLength(mockStreamCounterBuffer, packetOffset) + 2 * sizeof(uint32_t); +} + BOOST_AUTO_TEST_SUITE(SendCounterPacketTests) BOOST_AUTO_TEST_CASE(MockSendCounterPacketTest) @@ -56,17 +69,18 @@ BOOST_AUTO_TEST_CASE(MockSendCounterPacketTest) BOOST_AUTO_TEST_CASE(SendPeriodicCounterSelectionPacketTest) { // Error no space left in buffer + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer1(10); - SendCounterPacket sendPacket1(mockBuffer1); + SendCounterPacket sendPacket1(mockProfilingConnection, mockBuffer1); uint32_t capturePeriod = 1000; std::vector<uint16_t> selectedCounterIds; BOOST_CHECK_THROW(sendPacket1.SendPeriodicCounterSelectionPacket(capturePeriod, selectedCounterIds), - armnn::profiling::BufferExhaustion); + BufferExhaustion); // Packet without any counters MockBuffer mockBuffer2(512); - SendCounterPacket sendPacket2(mockBuffer2); + SendCounterPacket sendPacket2(mockProfilingConnection, mockBuffer2); sendPacket2.SendPeriodicCounterSelectionPacket(capturePeriod, selectedCounterIds); unsigned int sizeRead = 0; @@ -83,7 +97,7 @@ BOOST_AUTO_TEST_CASE(SendPeriodicCounterSelectionPacketTest) // Full packet message MockBuffer mockBuffer3(512); - SendCounterPacket sendPacket3(mockBuffer3); + SendCounterPacket sendPacket3(mockProfilingConnection, mockBuffer3); selectedCounterIds.reserve(5); selectedCounterIds.emplace_back(100); @@ -119,8 +133,9 @@ BOOST_AUTO_TEST_CASE(SendPeriodicCounterSelectionPacketTest) BOOST_AUTO_TEST_CASE(SendPeriodicCounterCapturePacketTest) { // Error no space left in buffer + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer1(10); - SendCounterPacket sendPacket1(mockBuffer1); + SendCounterPacket sendPacket1(mockProfilingConnection, mockBuffer1); auto captureTimestamp = std::chrono::steady_clock::now(); uint64_t time = static_cast<uint64_t >(captureTimestamp.time_since_epoch().count()); @@ -131,7 +146,7 @@ BOOST_AUTO_TEST_CASE(SendPeriodicCounterCapturePacketTest) // Packet without any counters MockBuffer mockBuffer2(512); - SendCounterPacket sendPacket2(mockBuffer2); + SendCounterPacket sendPacket2(mockProfilingConnection, mockBuffer2); sendPacket2.SendPeriodicCounterCapturePacket(time, indexValuePairs); unsigned int sizeRead = 0; @@ -149,7 +164,7 @@ BOOST_AUTO_TEST_CASE(SendPeriodicCounterCapturePacketTest) // Full packet message MockBuffer mockBuffer3(512); - SendCounterPacket sendPacket3(mockBuffer3); + SendCounterPacket sendPacket3(mockProfilingConnection, mockBuffer3); indexValuePairs.reserve(5); indexValuePairs.emplace_back(std::make_pair<uint16_t, uint32_t >(0, 100)); @@ -200,8 +215,9 @@ BOOST_AUTO_TEST_CASE(SendStreamMetaDataPacketTest) uint32_t sizeUint32 = numeric_cast<uint32_t>(sizeof(uint32_t)); // Error no space left in buffer + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer1(10); - SendCounterPacket sendPacket1(mockBuffer1); + SendCounterPacket sendPacket1(mockProfilingConnection, mockBuffer1); BOOST_CHECK_THROW(sendPacket1.SendStreamMetaDataPacket(), armnn::profiling::BufferExhaustion); // Full metadata packet @@ -220,7 +236,7 @@ BOOST_AUTO_TEST_CASE(SendStreamMetaDataPacketTest) uint32_t packetEntries = 6; MockBuffer mockBuffer2(512); - SendCounterPacket sendPacket2(mockBuffer2); + SendCounterPacket sendPacket2(mockProfilingConnection, mockBuffer2); sendPacket2.SendStreamMetaDataPacket(); unsigned int sizeRead = 0; const unsigned char* readBuffer2 = mockBuffer2.GetReadBuffer(sizeRead); @@ -313,8 +329,9 @@ BOOST_AUTO_TEST_CASE(SendStreamMetaDataPacketTest) BOOST_AUTO_TEST_CASE(CreateDeviceRecordTest) { + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(mockBuffer); + SendCounterPacketTest sendCounterPacketTest(mockProfilingConnection, mockBuffer); // Create a device for testing uint16_t deviceUid = 27; @@ -345,8 +362,9 @@ BOOST_AUTO_TEST_CASE(CreateDeviceRecordTest) BOOST_AUTO_TEST_CASE(CreateInvalidDeviceRecordTest) { + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(mockBuffer); + SendCounterPacketTest sendCounterPacketTest(mockProfilingConnection, mockBuffer); // Create a device for testing uint16_t deviceUid = 27; @@ -366,8 +384,9 @@ BOOST_AUTO_TEST_CASE(CreateInvalidDeviceRecordTest) BOOST_AUTO_TEST_CASE(CreateCounterSetRecordTest) { + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(mockBuffer); + SendCounterPacketTest sendCounterPacketTest(mockProfilingConnection, mockBuffer); // Create a counter set for testing uint16_t counterSetUid = 27; @@ -398,8 +417,9 @@ BOOST_AUTO_TEST_CASE(CreateCounterSetRecordTest) BOOST_AUTO_TEST_CASE(CreateInvalidCounterSetRecordTest) { + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(mockBuffer); + SendCounterPacketTest sendCounterPacketTest(mockProfilingConnection, mockBuffer); // Create a counter set for testing uint16_t counterSetUid = 27; @@ -419,8 +439,9 @@ BOOST_AUTO_TEST_CASE(CreateInvalidCounterSetRecordTest) BOOST_AUTO_TEST_CASE(CreateEventRecordTest) { + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(mockBuffer); + SendCounterPacketTest sendCounterPacketTest(mockProfilingConnection, mockBuffer); // Create a counter for testing uint16_t counterUid = 7256; @@ -539,8 +560,9 @@ BOOST_AUTO_TEST_CASE(CreateEventRecordTest) BOOST_AUTO_TEST_CASE(CreateEventRecordNoUnitsTest) { + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(mockBuffer); + SendCounterPacketTest sendCounterPacketTest(mockProfilingConnection, mockBuffer); // Create a counter for testing uint16_t counterUid = 44312; @@ -642,8 +664,9 @@ BOOST_AUTO_TEST_CASE(CreateEventRecordNoUnitsTest) BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest1) { + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(mockBuffer); + SendCounterPacketTest sendCounterPacketTest(mockProfilingConnection, mockBuffer); // Create a counter for testing uint16_t counterUid = 7256; @@ -680,8 +703,9 @@ BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest1) BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest2) { + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(mockBuffer); + SendCounterPacketTest sendCounterPacketTest(mockProfilingConnection, mockBuffer); // Create a counter for testing uint16_t counterUid = 7256; @@ -718,8 +742,9 @@ BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest2) BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest3) { + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(mockBuffer); + SendCounterPacketTest sendCounterPacketTest(mockProfilingConnection, mockBuffer); // Create a counter for testing uint16_t counterUid = 7256; @@ -756,8 +781,9 @@ BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest3) BOOST_AUTO_TEST_CASE(CreateCategoryRecordTest) { + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(mockBuffer); + SendCounterPacketTest sendCounterPacketTest(mockProfilingConnection, mockBuffer); // Create a category for testing const std::string categoryName = "some_category"; @@ -957,8 +983,9 @@ BOOST_AUTO_TEST_CASE(CreateCategoryRecordTest) BOOST_AUTO_TEST_CASE(CreateInvalidCategoryRecordTest1) { + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(mockBuffer); + SendCounterPacketTest sendCounterPacketTest(mockProfilingConnection, mockBuffer); // Create a category for testing const std::string categoryName = "some invalid category"; @@ -980,8 +1007,9 @@ BOOST_AUTO_TEST_CASE(CreateInvalidCategoryRecordTest1) BOOST_AUTO_TEST_CASE(CreateInvalidCategoryRecordTest2) { + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(mockBuffer); + SendCounterPacketTest sendCounterPacketTest(mockProfilingConnection, mockBuffer); // Create a category for testing const std::string categoryName = "some_category"; @@ -1038,8 +1066,9 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest1) BOOST_CHECK(device2); // Buffer with not enough space + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer(10); - SendCounterPacket sendCounterPacket(mockBuffer); + SendCounterPacket sendCounterPacket(mockProfilingConnection, mockBuffer); BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::profiling::BufferExhaustion); } @@ -1130,8 +1159,9 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) BOOST_CHECK(counter3); // Buffer with enough space + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer(1024); - SendCounterPacket sendCounterPacket(mockBuffer); + SendCounterPacket sendCounterPacket(mockProfilingConnection, mockBuffer); BOOST_CHECK_NO_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory)); // Get the read buffer @@ -1529,8 +1559,9 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest3) BOOST_CHECK(device); // Buffer with enough space + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer(1024); - SendCounterPacket sendCounterPacket(mockBuffer); + SendCounterPacket sendCounterPacket(mockProfilingConnection, mockBuffer); BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::RuntimeException); } @@ -1547,8 +1578,9 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest4) BOOST_CHECK(counterSet); // Buffer with enough space + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer(1024); - SendCounterPacket sendCounterPacket(mockBuffer); + SendCounterPacket sendCounterPacket(mockProfilingConnection, mockBuffer); BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::RuntimeException); } @@ -1565,8 +1597,9 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest5) BOOST_CHECK(category); // Buffer with enough space + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer(1024); - SendCounterPacket sendCounterPacket(mockBuffer); + SendCounterPacket sendCounterPacket(mockProfilingConnection, mockBuffer); BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::RuntimeException); } @@ -1599,8 +1632,9 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest6) BOOST_CHECK(category); // Buffer with enough space + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer(1024); - SendCounterPacket sendCounterPacket(mockBuffer); + SendCounterPacket sendCounterPacket(mockProfilingConnection, mockBuffer); BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::RuntimeException); } @@ -1648,9 +1682,355 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest7) BOOST_CHECK(counter); // Buffer with enough space + MockProfilingConnection mockProfilingConnection; MockBuffer mockBuffer(1024); - SendCounterPacket sendCounterPacket(mockBuffer); + SendCounterPacket sendCounterPacket(mockProfilingConnection, mockBuffer); BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::RuntimeException); } +BOOST_AUTO_TEST_CASE(SendThreadTest0) +{ + MockProfilingConnection mockProfilingConnection; + MockStreamCounterBuffer mockStreamCounterBuffer(0); + SendCounterPacket sendCounterPacket(mockProfilingConnection, mockStreamCounterBuffer); + + // Try to start the send thread many times, it must only start once + + sendCounterPacket.Start(); + BOOST_CHECK(sendCounterPacket.IsRunning()); + sendCounterPacket.Start(); + sendCounterPacket.Start(); + sendCounterPacket.Start(); + sendCounterPacket.Start(); + BOOST_CHECK(sendCounterPacket.IsRunning()); + + std::this_thread::sleep_for(std::chrono::seconds(1)); + + sendCounterPacket.Stop(); + BOOST_CHECK(!sendCounterPacket.IsRunning()); +} + +BOOST_AUTO_TEST_CASE(SendThreadTest1) +{ + size_t totalWrittenSize = 0; + + MockProfilingConnection mockProfilingConnection; + MockStreamCounterBuffer mockStreamCounterBuffer(100); + SendCounterPacket sendCounterPacket(mockProfilingConnection, mockStreamCounterBuffer); + sendCounterPacket.Start(); + + // Interleaving writes and reads to/from the buffer with pauses to test that the send thread actually waits for + // something to become available for reading + + std::this_thread::sleep_for(std::chrono::seconds(1)); + + CounterDirectory counterDirectory; + sendCounterPacket.SendStreamMetaDataPacket(); + + // Get the size of the Stream Metadata Packet + size_t streamMetadataPacketsize = GetPacketSize(mockStreamCounterBuffer, totalWrittenSize); + totalWrittenSize += streamMetadataPacketsize; + + sendCounterPacket.SetReadyToRead(); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + sendCounterPacket.SendCounterDirectoryPacket(counterDirectory); + + // Get the size of the Counter Directory Packet + size_t counterDirectoryPacketSize = GetPacketSize(mockStreamCounterBuffer, totalWrittenSize); + totalWrittenSize += counterDirectoryPacketSize; + + sendCounterPacket.SetReadyToRead(); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + sendCounterPacket.SendPeriodicCounterCapturePacket(123u, + { + { 1u, 23u }, + { 33u, 1207623u } + }); + + // Get the size of the Periodic Counter Capture Packet + size_t periodicCounterCapturePacketSize = GetPacketSize(mockStreamCounterBuffer, totalWrittenSize); + totalWrittenSize += periodicCounterCapturePacketSize; + + sendCounterPacket.SetReadyToRead(); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + sendCounterPacket.SendPeriodicCounterCapturePacket(44u, + { + { 211u, 923u } + }); + + // Get the size of the Periodic Counter Capture Packet + periodicCounterCapturePacketSize = GetPacketSize(mockStreamCounterBuffer, totalWrittenSize); + totalWrittenSize += periodicCounterCapturePacketSize; + + sendCounterPacket.SendPeriodicCounterCapturePacket(1234u, + { + { 555u, 23u }, + { 556u, 6u }, + { 557u, 893454u }, + { 558u, 1456623u }, + { 559u, 571090u } + }); + + // Get the size of the Periodic Counter Capture Packet + periodicCounterCapturePacketSize = GetPacketSize(mockStreamCounterBuffer, totalWrittenSize); + totalWrittenSize += periodicCounterCapturePacketSize; + + sendCounterPacket.SendPeriodicCounterCapturePacket(997u, + { + { 88u, 11u }, + { 96u, 22u }, + { 97u, 33u }, + { 999u, 444u } + }); + + // Get the size of the Periodic Counter Capture Packet + periodicCounterCapturePacketSize = GetPacketSize(mockStreamCounterBuffer, totalWrittenSize); + totalWrittenSize += periodicCounterCapturePacketSize; + + sendCounterPacket.SetReadyToRead(); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + sendCounterPacket.SendPeriodicCounterSelectionPacket(1000u, { 1345u, 254u, 4536u, 408u, 54u, 6323u, 428u, 1u, 6u }); + + // Get the size of the Periodic Counter Capture Packet + periodicCounterCapturePacketSize = GetPacketSize(mockStreamCounterBuffer, totalWrittenSize); + totalWrittenSize += periodicCounterCapturePacketSize; + + sendCounterPacket.SetReadyToRead(); + + // To test an exact value of the "read size" in the mock buffer, wait a second to allow the send thread to + // read all what's remaining in the buffer + std::this_thread::sleep_for(std::chrono::seconds(1)); + + sendCounterPacket.Stop(); + + BOOST_CHECK(mockStreamCounterBuffer.GetBufferSize() == totalWrittenSize); + BOOST_CHECK(mockStreamCounterBuffer.GetCommittedSize() == totalWrittenSize); + BOOST_CHECK(mockStreamCounterBuffer.GetReadSize() == totalWrittenSize); +} + +BOOST_AUTO_TEST_CASE(SendThreadTest2) +{ + size_t totalWrittenSize = 0; + + MockProfilingConnection mockProfilingConnection; + MockStreamCounterBuffer mockStreamCounterBuffer(100); + SendCounterPacket sendCounterPacket(mockProfilingConnection, mockStreamCounterBuffer); + sendCounterPacket.Start(); + + // Adding many spurious "ready to read" signals throughout the test to check that the send thread is + // capable of handling unnecessary read requests + + std::this_thread::sleep_for(std::chrono::seconds(1)); + + sendCounterPacket.SetReadyToRead(); + + CounterDirectory counterDirectory; + sendCounterPacket.SendStreamMetaDataPacket(); + + // Get the size of the Stream Metadata Packet + size_t streamMetadataPacketsize = GetPacketSize(mockStreamCounterBuffer, totalWrittenSize); + totalWrittenSize += streamMetadataPacketsize; + + sendCounterPacket.SetReadyToRead(); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + sendCounterPacket.SendCounterDirectoryPacket(counterDirectory); + + // Get the size of the Counter Directory Packet + size_t counterDirectoryPacketSize = GetPacketSize(mockStreamCounterBuffer, totalWrittenSize); + totalWrittenSize += counterDirectoryPacketSize; + + sendCounterPacket.SetReadyToRead(); + sendCounterPacket.SetReadyToRead(); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + sendCounterPacket.SendPeriodicCounterCapturePacket(123u, + { + { 1u, 23u }, + { 33u, 1207623u } + }); + + // Get the size of the Periodic Counter Capture Packet + size_t periodicCounterCapturePacketSize = GetPacketSize(mockStreamCounterBuffer, totalWrittenSize); + totalWrittenSize += periodicCounterCapturePacketSize; + + sendCounterPacket.SetReadyToRead(); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + sendCounterPacket.SetReadyToRead(); + sendCounterPacket.SetReadyToRead(); + sendCounterPacket.SetReadyToRead(); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + sendCounterPacket.SetReadyToRead(); + sendCounterPacket.SendPeriodicCounterCapturePacket(44u, + { + { 211u, 923u } + }); + + // Get the size of the Periodic Counter Capture Packet + periodicCounterCapturePacketSize = GetPacketSize(mockStreamCounterBuffer, totalWrittenSize); + totalWrittenSize += periodicCounterCapturePacketSize; + + sendCounterPacket.SendPeriodicCounterCapturePacket(1234u, + { + { 555u, 23u }, + { 556u, 6u }, + { 557u, 893454u }, + { 558u, 1456623u }, + { 559u, 571090u } + }); + + // Get the size of the Periodic Counter Capture Packet + periodicCounterCapturePacketSize = GetPacketSize(mockStreamCounterBuffer, totalWrittenSize); + totalWrittenSize += periodicCounterCapturePacketSize; + + sendCounterPacket.SetReadyToRead(); + sendCounterPacket.SendPeriodicCounterCapturePacket(997u, + { + { 88u, 11u }, + { 96u, 22u }, + { 97u, 33u }, + { 999u, 444u } + }); + + // Get the size of the Periodic Counter Capture Packet + periodicCounterCapturePacketSize = GetPacketSize(mockStreamCounterBuffer, totalWrittenSize); + totalWrittenSize += periodicCounterCapturePacketSize; + + sendCounterPacket.SetReadyToRead(); + sendCounterPacket.SetReadyToRead(); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + sendCounterPacket.SendPeriodicCounterSelectionPacket(1000u, { 1345u, 254u, 4536u, 408u, 54u, 6323u, 428u, 1u, 6u }); + + // Get the size of the Periodic Counter Capture Packet + periodicCounterCapturePacketSize = GetPacketSize(mockStreamCounterBuffer, totalWrittenSize); + totalWrittenSize += periodicCounterCapturePacketSize; + + sendCounterPacket.SetReadyToRead(); + + // To test an exact value of the "read size" in the mock buffer, wait a second to allow the send thread to + // read all what's remaining in the buffer + std::this_thread::sleep_for(std::chrono::seconds(1)); + + sendCounterPacket.Stop(); + + BOOST_CHECK(mockStreamCounterBuffer.GetBufferSize() == totalWrittenSize); + BOOST_CHECK(mockStreamCounterBuffer.GetCommittedSize() == totalWrittenSize); + BOOST_CHECK(mockStreamCounterBuffer.GetReadSize() == totalWrittenSize); +} + +BOOST_AUTO_TEST_CASE(SendThreadTest3) +{ + size_t totalWrittenSize = 0; + + MockProfilingConnection mockProfilingConnection; + MockStreamCounterBuffer mockStreamCounterBuffer(100); + SendCounterPacket sendCounterPacket(mockProfilingConnection, mockStreamCounterBuffer); + sendCounterPacket.Start(); + + // Not using pauses or "grace periods" to stress test the send thread + + sendCounterPacket.SetReadyToRead(); + + CounterDirectory counterDirectory; + sendCounterPacket.SendStreamMetaDataPacket(); + + // Get the size of the Stream Metadata Packet + size_t streamMetadataPacketsize = GetPacketSize(mockStreamCounterBuffer, totalWrittenSize); + totalWrittenSize += streamMetadataPacketsize; + + sendCounterPacket.SetReadyToRead(); + sendCounterPacket.SendCounterDirectoryPacket(counterDirectory); + + // Get the size of the Counter Directory Packet + size_t counterDirectoryPacketSize = GetPacketSize(mockStreamCounterBuffer, totalWrittenSize); + totalWrittenSize += counterDirectoryPacketSize; + + sendCounterPacket.SetReadyToRead(); + sendCounterPacket.SetReadyToRead(); + sendCounterPacket.SendPeriodicCounterCapturePacket(123u, + { + { 1u, 23u }, + { 33u, 1207623u } + }); + + // Get the size of the Periodic Counter Capture Packet + size_t periodicCounterCapturePacketSize = GetPacketSize(mockStreamCounterBuffer, totalWrittenSize); + totalWrittenSize += periodicCounterCapturePacketSize; + + sendCounterPacket.SetReadyToRead(); + sendCounterPacket.SetReadyToRead(); + sendCounterPacket.SetReadyToRead(); + sendCounterPacket.SetReadyToRead(); + sendCounterPacket.SetReadyToRead(); + sendCounterPacket.SendPeriodicCounterCapturePacket(44u, + { + { 211u, 923u } + }); + + // Get the size of the Periodic Counter Capture Packet + periodicCounterCapturePacketSize = GetPacketSize(mockStreamCounterBuffer, totalWrittenSize); + totalWrittenSize += periodicCounterCapturePacketSize; + + sendCounterPacket.SendPeriodicCounterCapturePacket(1234u, + { + { 555u, 23u }, + { 556u, 6u }, + { 557u, 893454u }, + { 558u, 1456623u }, + { 559u, 571090u } + }); + + // Get the size of the Periodic Counter Capture Packet + periodicCounterCapturePacketSize = GetPacketSize(mockStreamCounterBuffer, totalWrittenSize); + totalWrittenSize += periodicCounterCapturePacketSize; + + sendCounterPacket.SetReadyToRead(); + sendCounterPacket.SetReadyToRead(); + sendCounterPacket.SendPeriodicCounterCapturePacket(997u, + { + { 88u, 11u }, + { 96u, 22u }, + { 97u, 33u }, + { 999u, 444u } + }); + + // Get the size of the Periodic Counter Capture Packet + periodicCounterCapturePacketSize = GetPacketSize(mockStreamCounterBuffer, totalWrittenSize); + totalWrittenSize += periodicCounterCapturePacketSize; + + sendCounterPacket.SetReadyToRead(); + sendCounterPacket.SetReadyToRead(); + sendCounterPacket.SendPeriodicCounterSelectionPacket(1000u, { 1345u, 254u, 4536u, 408u, 54u, 6323u, 428u, 1u, 6u }); + + // Get the size of the Periodic Counter Capture Packet + periodicCounterCapturePacketSize = GetPacketSize(mockStreamCounterBuffer, totalWrittenSize); + totalWrittenSize += periodicCounterCapturePacketSize; + + sendCounterPacket.SetReadyToRead(); + + // Abruptly terminating the send thread, the amount of data sent may be less that the amount written (the send + // thread is not guaranteed to flush the buffer) + sendCounterPacket.Stop(); + + BOOST_CHECK(mockStreamCounterBuffer.GetBufferSize() == totalWrittenSize); + BOOST_CHECK(mockStreamCounterBuffer.GetCommittedSize() == totalWrittenSize); + BOOST_CHECK(mockStreamCounterBuffer.GetReadSize() <= totalWrittenSize); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/src/profiling/test/SendCounterPacketTests.hpp b/src/profiling/test/SendCounterPacketTests.hpp index 5d5dfd14c7..3616816ae2 100644 --- a/src/profiling/test/SendCounterPacketTests.hpp +++ b/src/profiling/test/SendCounterPacketTests.hpp @@ -5,24 +5,50 @@ #pragma once -#include "../SendCounterPacket.hpp" -#include "../ProfilingUtils.hpp" +#include <SendCounterPacket.hpp> +#include <ProfilingUtils.hpp> #include <armnn/Exceptions.hpp> +#include <armnn/Optional.hpp> +#include <armnn/Conversion.hpp> -#include <boost/test/unit_test.hpp> +#include <boost/numeric/conversion/cast.hpp> -#include <chrono> -#include <iostream> +namespace armnn +{ -using namespace armnn::profiling; +namespace profiling +{ + +class MockProfilingConnection : public IProfilingConnection +{ +public: + MockProfilingConnection() + : m_IsOpen(true) + {} + + bool IsOpen() override { return m_IsOpen; } + + void Close() override { m_IsOpen = false; } + + bool WritePacket(const unsigned char* buffer, uint32_t length) override + { + return buffer != nullptr && length > 0; + } + + Packet ReadPacket(uint32_t timeout) override { return Packet(); } + +private: + bool m_IsOpen; +}; class MockBuffer : public IBufferWrapper { public: MockBuffer(unsigned int size) - : m_BufferSize(size), - m_Buffer(std::make_unique<unsigned char[]>(size)) {} + : m_BufferSize(size) + , m_Buffer(std::make_unique<unsigned char[]>(size)) + {} unsigned char* Reserve(unsigned int requestedSize, unsigned int& reservedSize) override { @@ -46,13 +72,115 @@ public: return m_Buffer.get(); } - void Release( unsigned int size) override {} + void Release(unsigned int size) override {} private: unsigned int m_BufferSize; std::unique_ptr<unsigned char[]> m_Buffer; }; +class MockStreamCounterBuffer : public IBufferWrapper +{ +public: + MockStreamCounterBuffer(unsigned int size) + : m_Buffer(size, 0) + , m_CommittedSize(0) + , m_ReadSize(0) + {} + + unsigned char* Reserve(unsigned int requestedSize, unsigned int& reservedSize) override + { + std::unique_lock<std::mutex>(m_Mutex); + + // Get the buffer size and the available size in the buffer past the committed size + size_t bufferSize = m_Buffer.size(); + size_t availableSize = bufferSize - m_CommittedSize; + + // Check whether the buffer needs to be resized + if (requestedSize > availableSize) + { + // Resize the buffer + size_t newSize = m_CommittedSize + requestedSize; + m_Buffer.resize(newSize, 0); + } + + // Set the reserved size + reservedSize = requestedSize; + + // Get a pointer to the beginning of the part of buffer available for writing + return m_Buffer.data() + m_CommittedSize; + } + + void Commit(unsigned int size) override + { + std::unique_lock<std::mutex>(m_Mutex); + + // Update the committed size + m_CommittedSize += size; + } + + const unsigned char* GetReadBuffer(unsigned int& size) override + { + std::unique_lock<std::mutex>(m_Mutex); + + // Get the size available for reading + size = boost::numeric_cast<unsigned int>(m_CommittedSize - m_ReadSize); + + // Get a pointer to the beginning of the part of buffer available for reading + const unsigned char* readBuffer = m_Buffer.data() + m_ReadSize; + + // Update the read size + m_ReadSize = m_CommittedSize; + + return readBuffer; + } + + void Release(unsigned int size) override + { + std::unique_lock<std::mutex>(m_Mutex); + + if (size == 0) + { + // Nothing to release + return; + } + + // Get the buffer size + size_t bufferSize = m_Buffer.size(); + + // Remove the last "size" bytes from the buffer + if (size < bufferSize) + { + // Resize the buffer + size_t newSize = bufferSize - size; + m_Buffer.resize(newSize); + } + else + { + // Clear the whole buffer + m_Buffer.clear(); + } + } + + size_t GetBufferSize() const { return m_Buffer.size(); } + size_t GetCommittedSize() const { return m_CommittedSize; } + size_t GetReadSize() const { return m_ReadSize; } + const unsigned char* GetBuffer() const { return m_Buffer.data(); } + +private: + // This mock uses an ever-expanding vector to simulate a counter stream buffer + std::vector<unsigned char> m_Buffer; + + // The size of the buffer that has been committed for reading + size_t m_CommittedSize; + + // The size of the buffer that has already been read + size_t m_ReadSize; + + // This mock buffer provides basic synchronization + std::mutex m_Mutex; +}; + class MockSendCounterPacket : public ISendCounterPacket { public: @@ -93,8 +221,7 @@ public: m_Buffer.Commit(reserved); } - void SetReadyToRead() override - {} + void SetReadyToRead() override {} private: IBufferWrapper& m_Buffer; @@ -307,8 +434,8 @@ private: class SendCounterPacketTest : public SendCounterPacket { public: - SendCounterPacketTest(IBufferWrapper& buffer) - : SendCounterPacket(buffer) + SendCounterPacketTest(IProfilingConnection& profilingconnection, IBufferWrapper& buffer) + : SendCounterPacket(profilingconnection, buffer) {} bool CreateDeviceRecordTest(const DevicePtr& device, @@ -340,3 +467,7 @@ public: return CreateCategoryRecord(category, counters, categoryRecord, errorMessage); } }; + +} // namespace profiling + +} // namespace armnn |