aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2019-10-07 13:05:13 +0100
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-10-08 10:43:50 +0100
commit5d737fb3b06c17ff6b65fb307343ca1c0c680401 (patch)
tree91c5548fd4e84ff086e57020a017982124b06c50
parentc2728f95086c54aa842e4c1dae8f3b5c290a72fa (diff)
downloadarmnn-5d737fb3b06c17ff6b65fb307343ca1c0c680401.tar.gz
IVGCVSW-3937 Update the Send thread to send out the Metadata packet
* The Send thread now automatically sends out Stream Metadata packets when the Profiling Service is in WaitingForAck state * Added a reference to the profiling state in the SendCounterPacket class * Moving the RuntimeException thrown in the Send thread to the main thread for rethrowing * The Stop method now rethrows the exception occurred in the send thread * The Stop method does not rethrow when destructing the object * Added unit tests Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com> Change-Id: Ice7080bff63199eac84fc4fa1d37fb1a6fcdff89
-rw-r--r--src/profiling/SendCounterPacket.cpp91
-rw-r--r--src/profiling/SendCounterPacket.hpp29
-rw-r--r--src/profiling/test/ProfilingTests.cpp16
-rw-r--r--src/profiling/test/SendCounterPacketTests.cpp309
-rw-r--r--src/profiling/test/SendCounterPacketTests.hpp39
5 files changed, 390 insertions, 94 deletions
diff --git a/src/profiling/SendCounterPacket.cpp b/src/profiling/SendCounterPacket.cpp
index dc5a950bea..b9f2b187b7 100644
--- a/src/profiling/SendCounterPacket.cpp
+++ b/src/profiling/SendCounterPacket.cpp
@@ -920,7 +920,7 @@ void SendCounterPacket::Start(IProfilingConnection& profilingConnection)
m_SendThread = std::thread(&SendCounterPacket::Send, this, std::ref(profilingConnection));
}
-void SendCounterPacket::Stop()
+void SendCounterPacket::Stop(bool rethrowSendThreadExceptions)
{
// Signal the send thread to stop
m_KeepRunning.store(false);
@@ -934,6 +934,30 @@ void SendCounterPacket::Stop()
// Wait for the send thread to complete operations
m_SendThread.join();
}
+
+ // Check if the send thread exception has to be rethrown
+ if (!rethrowSendThreadExceptions)
+ {
+ // No need to rethrow the send thread exception, return immediately
+ return;
+ }
+
+ // Exception handling lock scope - Begin
+ {
+ // Lock the mutex to handle any exception coming from the send thread
+ std::unique_lock<std::mutex> lock(m_WaitMutex);
+
+ // Check if there's an exception to rethrow
+ if (m_SendThreadException)
+ {
+ // Rethrow the send thread exception
+ std::rethrow_exception(m_SendThreadException);
+
+ // Nullify the exception as it has been rethrown
+ m_SendThreadException = nullptr;
+ }
+ }
+ // Exception handling lock scope - End
}
void SendCounterPacket::Send(IProfilingConnection& profilingConnection)
@@ -946,20 +970,67 @@ void SendCounterPacket::Send(IProfilingConnection& profilingConnection)
// Lock the mutex to wait on it
std::unique_lock<std::mutex> lock(m_WaitMutex);
- if (m_Timeout < 0)
+ // Check the current state of the profiling service
+ ProfilingState currentState = m_StateMachine.GetCurrentState();
+ switch (currentState)
{
- // Wait indefinitely until notified that something to read has become available in the buffer
+ case ProfilingState::Uninitialised:
+ case ProfilingState::NotConnected:
+
+ // The send thread cannot be running when the profiling service is uninitialized or not connected,
+ // stop the thread immediately
+ m_KeepRunning.store(false);
+ m_IsRunning.store(false);
+
+ // An exception should be thrown here, save it to be rethrown later from the main thread so that
+ // it can be caught by the consumer
+ m_SendThreadException =
+ std::make_exception_ptr(RuntimeException("The send thread should not be running with the "
+ "profiling service not yet initialized or connected"));
+
+ return;
+ case ProfilingState::WaitingForAck:
+
+ // Send out a StreamMetadata packet and wait for the profiling connection to be acknowledged.
+ // When a ConnectionAcknowledged packet is received, the profiling service state will be automatically
+ // updated by the command handler
+
+ // Prepare a StreamMetadata packet and write it to the Counter Stream buffer
+ SendStreamMetaDataPacket();
+
+ // Flush the buffer manually to send the packet
+ FlushBuffer(profilingConnection);
+
+ // Wait indefinitely until notified otherwise (it could that the profiling state has changed due to the
+ // connection being acknowledged, or that new data is ready to be sent, or that the send thread is
+ // being shut down, etc.)
m_WaitCondition.wait(lock);
- }
- else
- {
- // Wait until the thread is notified of something to read from the buffer,
- // or check anyway after the specified number of milliseconds
- m_WaitCondition.wait_for(lock, std::chrono::milliseconds(m_Timeout));
+
+ // Do not flush the buffer again
+ continue;
+ case ProfilingState::Active:
+ default:
+ // Normal working state for the send thread
+
+ // Check if the send thread is required to enforce a timeout wait policy
+ if (m_Timeout < 0)
+ {
+ // Wait indefinitely until notified that something to read has become available in the buffer
+ m_WaitCondition.wait(lock);
+ }
+ else
+ {
+ // Wait until the thread is notified of something to read from the buffer,
+ // or check anyway after the specified number of milliseconds
+ m_WaitCondition.wait_for(lock, std::chrono::milliseconds(m_Timeout));
+ }
+
+ break;
}
}
// Wait condition lock scope - End
+ // Send all the available packets in the buffer
FlushBuffer(profilingConnection);
}
@@ -1000,7 +1071,7 @@ void SendCounterPacket::FlushBuffer(IProfilingConnection& profilingConnection)
// Mark the packet buffer as read
m_BufferManager.MarkRead(packetBuffer);
- // Get next available readable buffer
+ // Get the next available readable buffer
packetBuffer = m_BufferManager.GetReadableBuffer();
}
}
diff --git a/src/profiling/SendCounterPacket.hpp b/src/profiling/SendCounterPacket.hpp
index ed76937cc3..9361efbc74 100644
--- a/src/profiling/SendCounterPacket.hpp
+++ b/src/profiling/SendCounterPacket.hpp
@@ -6,9 +6,10 @@
#pragma once
#include "IBufferManager.hpp"
-#include "ISendCounterPacket.hpp"
#include "ICounterDirectory.hpp"
+#include "ISendCounterPacket.hpp"
#include "IProfilingConnection.hpp"
+#include "ProfilingStateMachine.hpp"
#include "ProfilingUtils.hpp"
#include <atomic>
@@ -26,20 +27,25 @@ namespace profiling
class SendCounterPacket : public ISendCounterPacket
{
public:
- using CategoryRecord = std::vector<uint32_t>;
- using DeviceRecord = std::vector<uint32_t>;
- using CounterSetRecord = std::vector<uint32_t>;
- using EventRecord = std::vector<uint32_t>;
-
+ using CategoryRecord = std::vector<uint32_t>;
+ using DeviceRecord = std::vector<uint32_t>;
+ using CounterSetRecord = std::vector<uint32_t>;
+ using EventRecord = std::vector<uint32_t>;
using IndexValuePairsVector = std::vector<std::pair<uint16_t, uint32_t>>;
- SendCounterPacket(IBufferManager& buffer, int timeout = 1000)
- : m_BufferManager(buffer)
+ SendCounterPacket(ProfilingStateMachine& profilingStateMachine, IBufferManager& buffer, int timeout = 1000)
+ : m_StateMachine(profilingStateMachine)
+ , m_BufferManager(buffer)
, m_Timeout(timeout)
, m_IsRunning(false)
, m_KeepRunning(false)
+ , m_SendThreadException(nullptr)
{}
- ~SendCounterPacket() { Stop(); }
+ ~SendCounterPacket()
+ {
+ // Don't rethrow when destructing the object
+ Stop(false);
+ }
void SendStreamMetaDataPacket() override;
@@ -56,7 +62,7 @@ public:
static const unsigned int MAX_METADATA_PACKET_LENGTH = 4096;
void Start(IProfilingConnection& profilingConnection);
- void Stop();
+ void Stop(bool rethrowSendThreadExceptions = true);
bool IsRunning() { return m_IsRunning.load(); }
private:
@@ -76,6 +82,7 @@ private:
{
SetReadyToRead();
}
+
if (writerBuffer != nullptr)
{
// Cancel the operation
@@ -88,6 +95,7 @@ private:
void FlushBuffer(IProfilingConnection& profilingConnection);
+ ProfilingStateMachine& m_StateMachine;
IBufferManager& m_BufferManager;
int m_Timeout;
std::mutex m_WaitMutex;
@@ -95,6 +103,7 @@ private:
std::thread m_SendThread;
std::atomic<bool> m_IsRunning;
std::atomic<bool> m_KeepRunning;
+ std::exception_ptr m_SendThreadException;
protected:
// Helper methods, protected for testing
diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp
index 91568d111d..24ab779412 100644
--- a/src/profiling/test/ProfilingTests.cpp
+++ b/src/profiling/test/ProfilingTests.cpp
@@ -1767,6 +1767,8 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData)
{
using boost::numeric_cast;
+ ProfilingStateMachine profilingStateMachine;
+
class TestCaptureThread : public IPeriodicCounterCapture
{
void Start() override {}
@@ -1779,7 +1781,7 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData)
Holder holder;
TestCaptureThread captureThread;
MockBufferManager mockBuffer(512);
- SendCounterPacket sendCounterPacket(mockBuffer);
+ SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer);
uint32_t sizeOfUint32 = numeric_cast<uint32_t>(sizeof(uint32_t));
uint32_t sizeOfUint16 = numeric_cast<uint32_t>(sizeof(uint16_t));
@@ -2135,12 +2137,14 @@ BOOST_AUTO_TEST_CASE(CheckPeriodicCounterCaptureThread)
std::unordered_map<uint16_t, uint32_t> m_Data;
};
+ ProfilingStateMachine profilingStateMachine;
+
Holder data;
std::vector<uint16_t> captureIds1 = { 0, 1 };
std::vector<uint16_t> captureIds2;
MockBufferManager mockBuffer(512);
- SendCounterPacket sendCounterPacket(mockBuffer);
+ SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer);
std::vector<uint16_t> counterIds;
CaptureReader captureReader;
@@ -2201,6 +2205,8 @@ BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest0)
{
using boost::numeric_cast;
+ ProfilingStateMachine profilingStateMachine;
+
const uint32_t packetId = 0x30000;
const uint32_t version = 1;
@@ -2209,7 +2215,7 @@ BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest0)
Packet packetA(packetId, 0, packetData);
MockBufferManager mockBuffer(1024);
- SendCounterPacket sendCounterPacket(mockBuffer);
+ SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer);
CounterDirectory counterDirectory;
@@ -2234,6 +2240,8 @@ BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest1)
{
using boost::numeric_cast;
+ ProfilingStateMachine profilingStateMachine;
+
const uint32_t packetId = 0x30000;
const uint32_t version = 1;
@@ -2242,7 +2250,7 @@ BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest1)
Packet packetA(packetId, 0, packetData);
MockBufferManager mockBuffer(1024);
- SendCounterPacket sendCounterPacket(mockBuffer);
+ SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer);
CounterDirectory counterDirectory;
const Device* device = counterDirectory.RegisterDevice("deviceA", 1);
diff --git a/src/profiling/test/SendCounterPacketTests.cpp b/src/profiling/test/SendCounterPacketTests.cpp
index 16302bcd0a..1216420383 100644
--- a/src/profiling/test/SendCounterPacketTests.cpp
+++ b/src/profiling/test/SendCounterPacketTests.cpp
@@ -21,14 +21,71 @@
using namespace armnn::profiling;
+namespace
+{
+
+void SetNotConnectedProfilingState(ProfilingStateMachine& profilingStateMachine)
+{
+ ProfilingState currentState = profilingStateMachine.GetCurrentState();
+ switch (currentState)
+ {
+ case ProfilingState::WaitingForAck:
+ profilingStateMachine.TransitionToState(ProfilingState::Active);
+ case ProfilingState::Uninitialised:
+ case ProfilingState::Active:
+ profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
+ case ProfilingState::NotConnected:
+ return;
+ default:
+ BOOST_CHECK_MESSAGE(false, "Invalid profiling state");
+ }
+}
+
+void SetWaitingForAckProfilingState(ProfilingStateMachine& profilingStateMachine)
+{
+ ProfilingState currentState = profilingStateMachine.GetCurrentState();
+ switch (currentState)
+ {
+ case ProfilingState::Uninitialised:
+ case ProfilingState::Active:
+ profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
+ case ProfilingState::NotConnected:
+ profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
+ case ProfilingState::WaitingForAck:
+ return;
+ default:
+ BOOST_CHECK_MESSAGE(false, "Invalid profiling state");
+ }
+}
+
+void SetActiveProfilingState(ProfilingStateMachine& profilingStateMachine)
+{
+ ProfilingState currentState = profilingStateMachine.GetCurrentState();
+ switch (currentState)
+ {
+ case ProfilingState::Uninitialised:
+ profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
+ case ProfilingState::NotConnected:
+ profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
+ case ProfilingState::WaitingForAck:
+ profilingStateMachine.TransitionToState(ProfilingState::Active);
+ case ProfilingState::Active:
+ return;
+ default:
+ BOOST_CHECK_MESSAGE(false, "Invalid profiling state");
+ }
+}
+
+} // Anonymous namespace
+
BOOST_AUTO_TEST_SUITE(SendCounterPacketTests)
BOOST_AUTO_TEST_CASE(MockSendCounterPacketTest)
{
MockBufferManager mockBuffer(512);
- MockSendCounterPacket sendCounterPacket(mockBuffer);
+ MockSendCounterPacket mockSendCounterPacket(mockBuffer);
- sendCounterPacket.SendStreamMetaDataPacket();
+ mockSendCounterPacket.SendStreamMetaDataPacket();
auto packetBuffer = mockBuffer.GetReadableBuffer();
const char* buffer = reinterpret_cast<const char*>(packetBuffer->GetReadableData());
@@ -38,7 +95,7 @@ BOOST_AUTO_TEST_CASE(MockSendCounterPacketTest)
mockBuffer.MarkRead(packetBuffer);
CounterDirectory counterDirectory;
- sendCounterPacket.SendCounterDirectoryPacket(counterDirectory);
+ mockSendCounterPacket.SendCounterDirectoryPacket(counterDirectory);
packetBuffer = mockBuffer.GetReadableBuffer();
buffer = reinterpret_cast<const char*>(packetBuffer->GetReadableData());
@@ -50,7 +107,7 @@ BOOST_AUTO_TEST_CASE(MockSendCounterPacketTest)
uint64_t timestamp = 0;
std::vector<std::pair<uint16_t, uint32_t>> indexValuePairs;
- sendCounterPacket.SendPeriodicCounterCapturePacket(timestamp, indexValuePairs);
+ mockSendCounterPacket.SendPeriodicCounterCapturePacket(timestamp, indexValuePairs);
packetBuffer = mockBuffer.GetReadableBuffer();
buffer = reinterpret_cast<const char*>(packetBuffer->GetReadableData());
@@ -61,7 +118,7 @@ BOOST_AUTO_TEST_CASE(MockSendCounterPacketTest)
uint32_t capturePeriod = 0;
std::vector<uint16_t> selectedCounterIds;
- sendCounterPacket.SendPeriodicCounterSelectionPacket(capturePeriod, selectedCounterIds);
+ mockSendCounterPacket.SendPeriodicCounterSelectionPacket(capturePeriod, selectedCounterIds);
packetBuffer = mockBuffer.GetReadableBuffer();
buffer = reinterpret_cast<const char*>(packetBuffer->GetReadableData());
@@ -73,9 +130,11 @@ BOOST_AUTO_TEST_CASE(MockSendCounterPacketTest)
BOOST_AUTO_TEST_CASE(SendPeriodicCounterSelectionPacketTest)
{
+ ProfilingStateMachine profilingStateMachine;
+
// Error no space left in buffer
MockBufferManager mockBuffer1(10);
- SendCounterPacket sendPacket1(mockBuffer1);
+ SendCounterPacket sendPacket1(profilingStateMachine, mockBuffer1);
uint32_t capturePeriod = 1000;
std::vector<uint16_t> selectedCounterIds;
@@ -84,7 +143,7 @@ BOOST_AUTO_TEST_CASE(SendPeriodicCounterSelectionPacketTest)
// Packet without any counters
MockBufferManager mockBuffer2(512);
- SendCounterPacket sendPacket2(mockBuffer2);
+ SendCounterPacket sendPacket2(profilingStateMachine, mockBuffer2);
sendPacket2.SendPeriodicCounterSelectionPacket(capturePeriod, selectedCounterIds);
auto readBuffer2 = mockBuffer2.GetReadableBuffer();
@@ -100,7 +159,7 @@ BOOST_AUTO_TEST_CASE(SendPeriodicCounterSelectionPacketTest)
// Full packet message
MockBufferManager mockBuffer3(512);
- SendCounterPacket sendPacket3(mockBuffer3);
+ SendCounterPacket sendPacket3(profilingStateMachine, mockBuffer3);
selectedCounterIds.reserve(5);
selectedCounterIds.emplace_back(100);
@@ -134,9 +193,11 @@ BOOST_AUTO_TEST_CASE(SendPeriodicCounterSelectionPacketTest)
BOOST_AUTO_TEST_CASE(SendPeriodicCounterCapturePacketTest)
{
+ ProfilingStateMachine profilingStateMachine;
+
// Error no space left in buffer
MockBufferManager mockBuffer1(10);
- SendCounterPacket sendPacket1(mockBuffer1);
+ SendCounterPacket sendPacket1(profilingStateMachine, mockBuffer1);
auto captureTimestamp = std::chrono::steady_clock::now();
uint64_t time = static_cast<uint64_t >(captureTimestamp.time_since_epoch().count());
@@ -147,7 +208,7 @@ BOOST_AUTO_TEST_CASE(SendPeriodicCounterCapturePacketTest)
// Packet without any counters
MockBufferManager mockBuffer2(512);
- SendCounterPacket sendPacket2(mockBuffer2);
+ SendCounterPacket sendPacket2(profilingStateMachine, mockBuffer2);
sendPacket2.SendPeriodicCounterCapturePacket(time, indexValuePairs);
auto readBuffer2 = mockBuffer2.GetReadableBuffer();
@@ -164,7 +225,7 @@ BOOST_AUTO_TEST_CASE(SendPeriodicCounterCapturePacketTest)
// Full packet message
MockBufferManager mockBuffer3(512);
- SendCounterPacket sendPacket3(mockBuffer3);
+ SendCounterPacket sendPacket3(profilingStateMachine, mockBuffer3);
indexValuePairs.reserve(5);
indexValuePairs.emplace_back(std::make_pair<uint16_t, uint32_t >(0, 100));
@@ -213,9 +274,11 @@ BOOST_AUTO_TEST_CASE(SendStreamMetaDataPacketTest)
uint32_t sizeUint32 = numeric_cast<uint32_t>(sizeof(uint32_t));
+ ProfilingStateMachine profilingStateMachine;
+
// Error no space left in buffer
MockBufferManager mockBuffer1(10);
- SendCounterPacket sendPacket1(mockBuffer1);
+ SendCounterPacket sendPacket1(profilingStateMachine, mockBuffer1);
BOOST_CHECK_THROW(sendPacket1.SendStreamMetaDataPacket(), armnn::profiling::BufferExhaustion);
// Full metadata packet
@@ -234,7 +297,7 @@ BOOST_AUTO_TEST_CASE(SendStreamMetaDataPacketTest)
uint32_t packetEntries = 6;
MockBufferManager mockBuffer2(512);
- SendCounterPacket sendPacket2(mockBuffer2);
+ SendCounterPacket sendPacket2(profilingStateMachine, mockBuffer2);
sendPacket2.SendStreamMetaDataPacket();
auto readBuffer2 = mockBuffer2.GetReadableBuffer();
@@ -328,8 +391,10 @@ BOOST_AUTO_TEST_CASE(SendStreamMetaDataPacketTest)
BOOST_AUTO_TEST_CASE(CreateDeviceRecordTest)
{
+ ProfilingStateMachine profilingStateMachine;
+
MockBufferManager mockBuffer(0);
- SendCounterPacketTest sendCounterPacketTest(mockBuffer);
+ SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer);
// Create a device for testing
uint16_t deviceUid = 27;
@@ -360,8 +425,10 @@ BOOST_AUTO_TEST_CASE(CreateDeviceRecordTest)
BOOST_AUTO_TEST_CASE(CreateInvalidDeviceRecordTest)
{
+ ProfilingStateMachine profilingStateMachine;
+
MockBufferManager mockBuffer(0);
- SendCounterPacketTest sendCounterPacketTest(mockBuffer);
+ SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer);
// Create a device for testing
uint16_t deviceUid = 27;
@@ -381,8 +448,10 @@ BOOST_AUTO_TEST_CASE(CreateInvalidDeviceRecordTest)
BOOST_AUTO_TEST_CASE(CreateCounterSetRecordTest)
{
+ ProfilingStateMachine profilingStateMachine;
+
MockBufferManager mockBuffer(0);
- SendCounterPacketTest sendCounterPacketTest(mockBuffer);
+ SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer);
// Create a counter set for testing
uint16_t counterSetUid = 27;
@@ -413,8 +482,10 @@ BOOST_AUTO_TEST_CASE(CreateCounterSetRecordTest)
BOOST_AUTO_TEST_CASE(CreateInvalidCounterSetRecordTest)
{
+ ProfilingStateMachine profilingStateMachine;
+
MockBufferManager mockBuffer(0);
- SendCounterPacketTest sendCounterPacketTest(mockBuffer);
+ SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer);
// Create a counter set for testing
uint16_t counterSetUid = 27;
@@ -434,8 +505,10 @@ BOOST_AUTO_TEST_CASE(CreateInvalidCounterSetRecordTest)
BOOST_AUTO_TEST_CASE(CreateEventRecordTest)
{
+ ProfilingStateMachine profilingStateMachine;
+
MockBufferManager mockBuffer(0);
- SendCounterPacketTest sendCounterPacketTest(mockBuffer);
+ SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer);
// Create a counter for testing
uint16_t counterUid = 7256;
@@ -554,8 +627,10 @@ BOOST_AUTO_TEST_CASE(CreateEventRecordTest)
BOOST_AUTO_TEST_CASE(CreateEventRecordNoUnitsTest)
{
+ ProfilingStateMachine profilingStateMachine;
+
MockBufferManager mockBuffer(0);
- SendCounterPacketTest sendCounterPacketTest(mockBuffer);
+ SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer);
// Create a counter for testing
uint16_t counterUid = 44312;
@@ -657,8 +732,10 @@ BOOST_AUTO_TEST_CASE(CreateEventRecordNoUnitsTest)
BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest1)
{
+ ProfilingStateMachine profilingStateMachine;
+
MockBufferManager mockBuffer(0);
- SendCounterPacketTest sendCounterPacketTest(mockBuffer);
+ SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer);
// Create a counter for testing
uint16_t counterUid = 7256;
@@ -695,8 +772,10 @@ BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest1)
BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest2)
{
+ ProfilingStateMachine profilingStateMachine;
+
MockBufferManager mockBuffer(0);
- SendCounterPacketTest sendCounterPacketTest(mockBuffer);
+ SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer);
// Create a counter for testing
uint16_t counterUid = 7256;
@@ -733,8 +812,10 @@ BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest2)
BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest3)
{
+ ProfilingStateMachine profilingStateMachine;
+
MockBufferManager mockBuffer(0);
- SendCounterPacketTest sendCounterPacketTest(mockBuffer);
+ SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer);
// Create a counter for testing
uint16_t counterUid = 7256;
@@ -771,8 +852,10 @@ BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest3)
BOOST_AUTO_TEST_CASE(CreateCategoryRecordTest)
{
+ ProfilingStateMachine profilingStateMachine;
+
MockBufferManager mockBuffer(0);
- SendCounterPacketTest sendCounterPacketTest(mockBuffer);
+ SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer);
// Create a category for testing
const std::string categoryName = "some_category";
@@ -972,8 +1055,10 @@ BOOST_AUTO_TEST_CASE(CreateCategoryRecordTest)
BOOST_AUTO_TEST_CASE(CreateInvalidCategoryRecordTest1)
{
+ ProfilingStateMachine profilingStateMachine;
+
MockBufferManager mockBuffer(0);
- SendCounterPacketTest sendCounterPacketTest(mockBuffer);
+ SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer);
// Create a category for testing
const std::string categoryName = "some invalid category";
@@ -995,8 +1080,10 @@ BOOST_AUTO_TEST_CASE(CreateInvalidCategoryRecordTest1)
BOOST_AUTO_TEST_CASE(CreateInvalidCategoryRecordTest2)
{
+ ProfilingStateMachine profilingStateMachine;
+
MockBufferManager mockBuffer(0);
- SendCounterPacketTest sendCounterPacketTest(mockBuffer);
+ SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer);
// Create a category for testing
const std::string categoryName = "some_category";
@@ -1035,6 +1122,8 @@ BOOST_AUTO_TEST_CASE(CreateInvalidCategoryRecordTest2)
BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest1)
{
+ ProfilingStateMachine profilingStateMachine;
+
// The counter directory used for testing
CounterDirectory counterDirectory;
@@ -1054,13 +1143,15 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest1)
// Buffer with not enough space
MockBufferManager mockBuffer(10);
- SendCounterPacket sendCounterPacket(mockBuffer);
+ SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer);
BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory),
armnn::profiling::BufferExhaustion);
}
BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2)
{
+ ProfilingStateMachine profilingStateMachine;
+
// The counter directory used for testing
CounterDirectory counterDirectory;
@@ -1146,7 +1237,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2)
// Buffer with enough space
MockBufferManager mockBuffer(1024);
- SendCounterPacket sendCounterPacket(mockBuffer);
+ SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer);
BOOST_CHECK_NO_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory));
// Get the readable buffer
@@ -1535,6 +1626,8 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2)
BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest3)
{
+ ProfilingStateMachine profilingStateMachine;
+
// Using a mock counter directory that allows to register invalid objects
MockCounterDirectory counterDirectory;
@@ -1547,12 +1640,14 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest3)
// Buffer with enough space
MockBufferManager mockBuffer(1024);
- SendCounterPacket sendCounterPacket(mockBuffer);
+ SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer);
BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::RuntimeException);
}
BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest4)
{
+ ProfilingStateMachine profilingStateMachine;
+
// Using a mock counter directory that allows to register invalid objects
MockCounterDirectory counterDirectory;
@@ -1565,12 +1660,14 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest4)
// Buffer with enough space
MockBufferManager mockBuffer(1024);
- SendCounterPacket sendCounterPacket(mockBuffer);
+ SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer);
BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::RuntimeException);
}
BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest5)
{
+ ProfilingStateMachine profilingStateMachine;
+
// Using a mock counter directory that allows to register invalid objects
MockCounterDirectory counterDirectory;
@@ -1583,12 +1680,14 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest5)
// Buffer with enough space
MockBufferManager mockBuffer(1024);
- SendCounterPacket sendCounterPacket(mockBuffer);
+ SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer);
BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::RuntimeException);
}
BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest6)
{
+ ProfilingStateMachine profilingStateMachine;
+
// Using a mock counter directory that allows to register invalid objects
MockCounterDirectory counterDirectory;
@@ -1617,12 +1716,14 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest6)
// Buffer with enough space
MockBufferManager mockBuffer(1024);
- SendCounterPacket sendCounterPacket(mockBuffer);
+ SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer);
BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::RuntimeException);
}
BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest7)
{
+ ProfilingStateMachine profilingStateMachine;
+
// Using a mock counter directory that allows to register invalid objects
MockCounterDirectory counterDirectory;
@@ -1666,15 +1767,18 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest7)
// Buffer with enough space
MockBufferManager mockBuffer(1024);
- SendCounterPacket sendCounterPacket(mockBuffer);
+ SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer);
BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::RuntimeException);
}
BOOST_AUTO_TEST_CASE(SendThreadTest0)
{
+ ProfilingStateMachine profilingStateMachine;
+ SetActiveProfilingState(profilingStateMachine);
+
MockProfilingConnection mockProfilingConnection;
MockStreamCounterBuffer mockStreamCounterBuffer(0);
- SendCounterPacket sendCounterPacket(mockStreamCounterBuffer);
+ SendCounterPacket sendCounterPacket(profilingStateMachine, mockStreamCounterBuffer);
// Try to start the send thread many times, it must only start once
@@ -1694,11 +1798,14 @@ BOOST_AUTO_TEST_CASE(SendThreadTest0)
BOOST_AUTO_TEST_CASE(SendThreadTest1)
{
+ ProfilingStateMachine profilingStateMachine;
+ SetActiveProfilingState(profilingStateMachine);
+
unsigned int totalWrittenSize = 0;
MockProfilingConnection mockProfilingConnection;
MockStreamCounterBuffer mockStreamCounterBuffer(1024);
- SendCounterPacket sendCounterPacket(mockStreamCounterBuffer);
+ SendCounterPacket sendCounterPacket(profilingStateMachine, mockStreamCounterBuffer);
sendCounterPacket.Start(mockProfilingConnection);
// Interleaving writes and reads to/from the buffer with pauses to test that the send thread actually waits for
@@ -1802,11 +1909,14 @@ BOOST_AUTO_TEST_CASE(SendThreadTest1)
BOOST_AUTO_TEST_CASE(SendThreadTest2)
{
+ ProfilingStateMachine profilingStateMachine;
+ SetActiveProfilingState(profilingStateMachine);
+
unsigned int totalWrittenSize = 0;
MockProfilingConnection mockProfilingConnection;
MockStreamCounterBuffer mockStreamCounterBuffer(1024);
- SendCounterPacket sendCounterPacket(mockStreamCounterBuffer);
+ SendCounterPacket sendCounterPacket(profilingStateMachine, mockStreamCounterBuffer);
sendCounterPacket.Start(mockProfilingConnection);
// Adding many spurious "ready to read" signals throughout the test to check that the send thread is
@@ -1922,11 +2032,14 @@ BOOST_AUTO_TEST_CASE(SendThreadTest2)
BOOST_AUTO_TEST_CASE(SendThreadTest3)
{
+ ProfilingStateMachine profilingStateMachine;
+ SetActiveProfilingState(profilingStateMachine);
+
unsigned int totalWrittenSize = 0;
MockProfilingConnection mockProfilingConnection;
MockStreamCounterBuffer mockStreamCounterBuffer(1024);
- SendCounterPacket sendCounterPacket(mockStreamCounterBuffer);
+ SendCounterPacket sendCounterPacket(profilingStateMachine, mockStreamCounterBuffer);
sendCounterPacket.Start(mockProfilingConnection);
// Not using pauses or "grace periods" to stress test the send thread
@@ -2025,9 +2138,12 @@ BOOST_AUTO_TEST_CASE(SendThreadTest3)
BOOST_AUTO_TEST_CASE(SendThreadBufferTest)
{
+ ProfilingStateMachine profilingStateMachine;
+ SetActiveProfilingState(profilingStateMachine);
+
MockProfilingConnection mockProfilingConnection;
BufferManager bufferManager(1, 1024);
- SendCounterPacket sendCounterPacket(bufferManager, -1);
+ SendCounterPacket sendCounterPacket(profilingStateMachine, bufferManager, -1);
sendCounterPacket.Start(mockProfilingConnection);
// Interleaving writes and reads to/from the buffer with pauses to test that the send thread actually waits for
@@ -2152,9 +2268,12 @@ BOOST_AUTO_TEST_CASE(SendThreadBufferTest)
BOOST_AUTO_TEST_CASE(SendThreadBufferTest1)
{
- MockWriteProfilingConnection mockProfilingConnection;
+ ProfilingStateMachine profilingStateMachine;
+ SetActiveProfilingState(profilingStateMachine);
+
+ MockProfilingConnection mockProfilingConnection;
BufferManager bufferManager(3, 1024);
- SendCounterPacket sendCounterPacket(bufferManager, -1);
+ SendCounterPacket sendCounterPacket(profilingStateMachine, bufferManager, -1);
sendCounterPacket.Start(mockProfilingConnection);
// SendStreamMetaDataPacket
@@ -2203,8 +2322,7 @@ BOOST_AUTO_TEST_CASE(SendThreadBufferTest1)
BOOST_TEST(reservedBuffer.get());
// Check that data was actually written to the profiling connection in any order
- std::vector<uint32_t> writtenData = mockProfilingConnection.GetWrittenData();
- std::vector<uint32_t> expectedOutput{streamMetadataPacketsize, 32, 28};
+ const std::vector<uint32_t>& writtenData = mockProfilingConnection.GetWrittenData();
BOOST_TEST(writtenData.size() == 3);
bool foundStreamMetaDataPacket =
std::find(writtenData.begin(), writtenData.end(), streamMetadataPacketsize) != writtenData.end();
@@ -2215,4 +2333,113 @@ BOOST_AUTO_TEST_CASE(SendThreadBufferTest1)
BOOST_TEST(foundPeriodicCounterCapturePacket);
}
+BOOST_AUTO_TEST_CASE(SendThreadSendStreamMetadataPacket1)
+{
+ ProfilingStateMachine profilingStateMachine;
+
+ MockProfilingConnection mockProfilingConnection;
+ BufferManager bufferManager(3, 1024);
+ SendCounterPacket sendCounterPacket(profilingStateMachine, bufferManager);
+ sendCounterPacket.Start(mockProfilingConnection);
+
+ // The profiling state is set to "Uninitialized", so the send thread should throw an exception
+
+ // Wait a bit to make sure that the send thread is properly started
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+ BOOST_CHECK_THROW(sendCounterPacket.Stop(), armnn::RuntimeException);
+}
+
+BOOST_AUTO_TEST_CASE(SendThreadSendStreamMetadataPacket2)
+{
+ ProfilingStateMachine profilingStateMachine;
+ SetNotConnectedProfilingState(profilingStateMachine);
+
+ MockProfilingConnection mockProfilingConnection;
+ BufferManager bufferManager(3, 1024);
+ SendCounterPacket sendCounterPacket(profilingStateMachine, bufferManager);
+ sendCounterPacket.Start(mockProfilingConnection);
+
+ // The profiling state is set to "NotConnected", so the send thread should throw an exception
+
+ // Wait a bit to make sure that the send thread is properly started
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+ BOOST_CHECK_THROW(sendCounterPacket.Stop(), armnn::RuntimeException);
+}
+
+BOOST_AUTO_TEST_CASE(SendThreadSendStreamMetadataPacket3)
+{
+ ProfilingStateMachine profilingStateMachine;
+ SetWaitingForAckProfilingState(profilingStateMachine);
+
+ // Calculate the size of a Stream Metadata packet
+ std::string processName = GetProcessName().substr(0, 60);
+ unsigned int processNameSize = processName.empty() ? 0 : boost::numeric_cast<unsigned int>(processName.size()) + 1;
+ unsigned int streamMetadataPacketsize = 118 + processNameSize;
+
+ MockProfilingConnection mockProfilingConnection;
+ BufferManager bufferManager(3, 1024);
+ SendCounterPacket sendCounterPacket(profilingStateMachine, bufferManager);
+ sendCounterPacket.Start(mockProfilingConnection);
+
+ // The profiling state is set to "WaitingForAck", so the send thread should send a Stream Metadata packet
+
+ // Wait for a bit to make sure that we get the packet
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+ BOOST_CHECK_NO_THROW(sendCounterPacket.Stop());
+
+ // Check that the buffer contains one Stream Metadata packet
+ const std::vector<uint32_t>& writtenData = mockProfilingConnection.GetWrittenData();
+ BOOST_TEST(writtenData.size() == 1);
+ BOOST_TEST(writtenData[0] == streamMetadataPacketsize);
+}
+
+BOOST_AUTO_TEST_CASE(SendThreadSendStreamMetadataPacket4)
+{
+ ProfilingStateMachine profilingStateMachine;
+ SetWaitingForAckProfilingState(profilingStateMachine);
+
+ // Calculate the size of a Stream Metadata packet
+ std::string processName = GetProcessName().substr(0, 60);
+ unsigned int processNameSize = processName.empty() ? 0 : boost::numeric_cast<unsigned int>(processName.size()) + 1;
+ unsigned int streamMetadataPacketsize = 118 + processNameSize;
+
+ MockProfilingConnection mockProfilingConnection;
+ BufferManager bufferManager(3, 1024);
+ SendCounterPacket sendCounterPacket(profilingStateMachine, bufferManager);
+ sendCounterPacket.Start(mockProfilingConnection);
+
+ // The profiling state is set to "WaitingForAck", so the send thread should send a Stream Metadata packet
+
+ // Wait for a bit to make sure that we get the packet
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+ // Check that the profiling state is still "WaitingForAck"
+ BOOST_TEST((profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck));
+
+ // Check that the buffer contains one Stream Metadata packet
+ const std::vector<uint32_t>& writtenData = mockProfilingConnection.GetWrittenData();
+ BOOST_TEST(writtenData.size() == 1);
+ BOOST_TEST(writtenData[0] == streamMetadataPacketsize);
+
+ mockProfilingConnection.Clear();
+
+ // Try triggering a new buffer read
+ sendCounterPacket.SetReadyToRead();
+
+ // Wait for a bit to make sure that we get the packet
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+ // Check that the profiling state is still "WaitingForAck"
+ BOOST_TEST((profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck));
+
+ // Check that the buffer contains one Stream Metadata packet
+ BOOST_TEST(writtenData.size() == 1);
+ BOOST_TEST(writtenData[0] == streamMetadataPacketsize);
+
+ BOOST_CHECK_NO_THROW(sendCounterPacket.Stop());
+}
+
BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/profiling/test/SendCounterPacketTests.hpp b/src/profiling/test/SendCounterPacketTests.hpp
index 0323f62d80..cae02b064d 100644
--- a/src/profiling/test/SendCounterPacketTests.hpp
+++ b/src/profiling/test/SendCounterPacketTests.hpp
@@ -19,7 +19,6 @@ namespace armnn
namespace profiling
{
-
class MockProfilingConnection : public IProfilingConnection
{
public:
@@ -33,38 +32,20 @@ public:
bool WritePacket(const unsigned char* buffer, uint32_t length) override
{
- return buffer != nullptr && length > 0;
- }
-
- Packet ReadPacket(uint32_t timeout) override { return Packet(); }
-
-private:
- bool m_IsOpen;
-};
-
-class MockWriteProfilingConnection : public IProfilingConnection
-{
-public:
- MockWriteProfilingConnection()
- : m_IsOpen(true)
- {}
-
- bool IsOpen() override { return m_IsOpen; }
-
- void Close() override { m_IsOpen = false; }
+ if (buffer == nullptr || length == 0)
+ {
+ return false;
+ }
- bool WritePacket(const unsigned char* buffer, uint32_t length) override
- {
m_WrittenData.push_back(length);
- return buffer != nullptr && length > 0;
+ return true;
}
Packet ReadPacket(uint32_t timeout) override { return Packet(); }
- std::vector<uint32_t> GetWrittenData()
- {
- return m_WrittenData;
- }
+ const std::vector<uint32_t>& GetWrittenData() const { return m_WrittenData; }
+
+ void Clear() { m_WrittenData.clear(); }
private:
bool m_IsOpen;
@@ -497,8 +478,8 @@ private:
class SendCounterPacketTest : public SendCounterPacket
{
public:
- SendCounterPacketTest(IBufferManager& buffer)
- : SendCounterPacket(buffer)
+ SendCounterPacketTest(ProfilingStateMachine& profilingStateMachine, IBufferManager& buffer)
+ : SendCounterPacket(profilingStateMachine, buffer)
{}
bool CreateDeviceRecordTest(const DevicePtr& device,