diff options
Diffstat (limited to 'src/profiling/test')
-rw-r--r-- | src/profiling/test/ProfilingTests.cpp | 167 | ||||
-rw-r--r-- | src/profiling/test/SendCounterPacketTests.cpp | 1347 | ||||
-rw-r--r-- | src/profiling/test/SendCounterPacketTests.hpp | 243 |
3 files changed, 1753 insertions, 4 deletions
diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index 51dbb07a58..e97068fbb4 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -1647,7 +1647,6 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData) BOOST_TEST(((headerWord0 >> 16) & 0x3FF) == 4); // packet id BOOST_TEST(headerWord1 == 4); // data lenght BOOST_TEST(period == 11); // capture period - } BOOST_AUTO_TEST_CASE(CheckSocketProfilingConnection) @@ -1656,4 +1655,170 @@ BOOST_AUTO_TEST_CASE(CheckSocketProfilingConnection) BOOST_CHECK_THROW(new SocketProfilingConnection(), armnn::Exception); } +BOOST_AUTO_TEST_CASE(SwTraceIsValidCharTest) +{ + // Only ASCII 7-bit encoding supported + for (unsigned char c = 0; c < 128; c++) + { + BOOST_CHECK(SwTraceCharPolicy::IsValidChar(c)); + } + + // Not ASCII + for (unsigned char c = 255; c >= 128; c++) + { + BOOST_CHECK(!SwTraceCharPolicy::IsValidChar(c)); + } +} + +BOOST_AUTO_TEST_CASE(SwTraceIsValidNameCharTest) +{ + // Only alpha-numeric and underscore ASCII 7-bit encoding supported + const unsigned char validChars[] = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"; + for (unsigned char i = 0; i < sizeof(validChars) / sizeof(validChars[0]) - 1; i++) + { + BOOST_CHECK(SwTraceNameCharPolicy::IsValidChar(validChars[i])); + } + + // Non alpha-numeric chars + for (unsigned char c = 0; c < 48; c++) + { + BOOST_CHECK(!SwTraceNameCharPolicy::IsValidChar(c)); + } + for (unsigned char c = 58; c < 65; c++) + { + BOOST_CHECK(!SwTraceNameCharPolicy::IsValidChar(c)); + } + for (unsigned char c = 91; c < 95; c++) + { + BOOST_CHECK(!SwTraceNameCharPolicy::IsValidChar(c)); + } + for (unsigned char c = 96; c < 97; c++) + { + BOOST_CHECK(!SwTraceNameCharPolicy::IsValidChar(c)); + } + for (unsigned char c = 123; c < 128; c++) + { + BOOST_CHECK(!SwTraceNameCharPolicy::IsValidChar(c)); + } + + // Not ASCII + for (unsigned char c = 255; c >= 128; c++) + { + BOOST_CHECK(!SwTraceNameCharPolicy::IsValidChar(c)); + } +} + +BOOST_AUTO_TEST_CASE(IsValidSwTraceStringTest) +{ + // Valid SWTrace strings + BOOST_CHECK(IsValidSwTraceString<SwTraceCharPolicy>("")); + BOOST_CHECK(IsValidSwTraceString<SwTraceCharPolicy>("_")); + BOOST_CHECK(IsValidSwTraceString<SwTraceCharPolicy>("0123")); + BOOST_CHECK(IsValidSwTraceString<SwTraceCharPolicy>("valid_string")); + BOOST_CHECK(IsValidSwTraceString<SwTraceCharPolicy>("VALID_string_456")); + BOOST_CHECK(IsValidSwTraceString<SwTraceCharPolicy>(" ")); + BOOST_CHECK(IsValidSwTraceString<SwTraceCharPolicy>("valid string")); + BOOST_CHECK(IsValidSwTraceString<SwTraceCharPolicy>("!$%")); + BOOST_CHECK(IsValidSwTraceString<SwTraceCharPolicy>("valid|\\~string#123")); + + // Invalid SWTrace strings + BOOST_CHECK(!IsValidSwTraceString<SwTraceCharPolicy>("€£")); + BOOST_CHECK(!IsValidSwTraceString<SwTraceCharPolicy>("invalid‡string")); + BOOST_CHECK(!IsValidSwTraceString<SwTraceCharPolicy>("12Ž34")); +} + +BOOST_AUTO_TEST_CASE(IsValidSwTraceNameStringTest) +{ + // Valid SWTrace name strings + BOOST_CHECK(IsValidSwTraceString<SwTraceNameCharPolicy>("")); + BOOST_CHECK(IsValidSwTraceString<SwTraceNameCharPolicy>("_")); + BOOST_CHECK(IsValidSwTraceString<SwTraceNameCharPolicy>("0123")); + BOOST_CHECK(IsValidSwTraceString<SwTraceNameCharPolicy>("valid_string")); + BOOST_CHECK(IsValidSwTraceString<SwTraceNameCharPolicy>("VALID_string_456")); + + // Invalid SWTrace name strings + BOOST_CHECK(!IsValidSwTraceString<SwTraceNameCharPolicy>(" ")); + BOOST_CHECK(!IsValidSwTraceString<SwTraceNameCharPolicy>("invalid string")); + BOOST_CHECK(!IsValidSwTraceString<SwTraceNameCharPolicy>("!$%")); + BOOST_CHECK(!IsValidSwTraceString<SwTraceNameCharPolicy>("invalid|\\~string#123")); + BOOST_CHECK(!IsValidSwTraceString<SwTraceNameCharPolicy>("€£")); + BOOST_CHECK(!IsValidSwTraceString<SwTraceNameCharPolicy>("invalid‡string")); + BOOST_CHECK(!IsValidSwTraceString<SwTraceNameCharPolicy>("12Ž34")); +} + +template <typename SwTracePolicy> +void StringToSwTraceStringTestHelper(const std::string& testString, std::vector<uint32_t> buffer, size_t expectedSize) +{ + // Convert the test string to a SWTrace string + BOOST_CHECK(StringToSwTraceString<SwTracePolicy>(testString, buffer)); + + // The buffer must contain at least the length of the string + BOOST_CHECK(!buffer.empty()); + + // The buffer must be of the expected size (in words) + BOOST_CHECK(buffer.size() == expectedSize); + + // The first word of the byte must be the length of the string including the null-terminator + BOOST_CHECK(buffer[0] == testString.size() + 1); + + // The contents of the buffer must match the test string + BOOST_CHECK(std::memcmp(testString.data(), buffer.data() + 1, testString.size()) == 0); + + // The buffer must include the null-terminator at the end of the string + size_t nullTerminatorIndex = sizeof(uint32_t) + testString.size(); + BOOST_CHECK(reinterpret_cast<unsigned char*>(buffer.data())[nullTerminatorIndex] == '\0'); +} + +BOOST_AUTO_TEST_CASE(StringToSwTraceStringTest) +{ + std::vector<uint32_t> buffer; + + // Valid SWTrace strings (expected size in words) + StringToSwTraceStringTestHelper<SwTraceCharPolicy>("", buffer, 2); + StringToSwTraceStringTestHelper<SwTraceCharPolicy>("_", buffer, 2); + StringToSwTraceStringTestHelper<SwTraceCharPolicy>("0123", buffer, 3); + StringToSwTraceStringTestHelper<SwTraceCharPolicy>("valid_string", buffer, 5); + StringToSwTraceStringTestHelper<SwTraceCharPolicy>("VALID_string_456", buffer, 6); + StringToSwTraceStringTestHelper<SwTraceCharPolicy>(" ", buffer, 2); + StringToSwTraceStringTestHelper<SwTraceCharPolicy>("valid string", buffer, 5); + StringToSwTraceStringTestHelper<SwTraceCharPolicy>("!$%", buffer, 2); + StringToSwTraceStringTestHelper<SwTraceCharPolicy>("valid|\\~string#123", buffer, 6); + + // Invalid SWTrace strings + BOOST_CHECK(!StringToSwTraceString<SwTraceCharPolicy>("€£", buffer)); + BOOST_CHECK(buffer.empty()); + BOOST_CHECK(!StringToSwTraceString<SwTraceCharPolicy>("invalid‡string", buffer)); + BOOST_CHECK(buffer.empty()); + BOOST_CHECK(!StringToSwTraceString<SwTraceCharPolicy>("12Ž34", buffer)); + BOOST_CHECK(buffer.empty()); +} + +BOOST_AUTO_TEST_CASE(StringToSwTraceNameStringTest) +{ + std::vector<uint32_t> buffer; + + // Valid SWTrace namestrings (expected size in words) + StringToSwTraceStringTestHelper<SwTraceNameCharPolicy>("", buffer, 2); + StringToSwTraceStringTestHelper<SwTraceNameCharPolicy>("_", buffer, 2); + StringToSwTraceStringTestHelper<SwTraceNameCharPolicy>("0123", buffer, 3); + StringToSwTraceStringTestHelper<SwTraceNameCharPolicy>("valid_string", buffer, 5); + StringToSwTraceStringTestHelper<SwTraceNameCharPolicy>("VALID_string_456", buffer, 6); + + // Invalid SWTrace namestrings + BOOST_CHECK(!StringToSwTraceString<SwTraceNameCharPolicy>(" ", buffer)); + BOOST_CHECK(buffer.empty()); + BOOST_CHECK(!StringToSwTraceString<SwTraceNameCharPolicy>("invalid string", buffer)); + BOOST_CHECK(buffer.empty()); + BOOST_CHECK(!StringToSwTraceString<SwTraceNameCharPolicy>("!$%", buffer)); + BOOST_CHECK(buffer.empty()); + BOOST_CHECK(!StringToSwTraceString<SwTraceNameCharPolicy>("invalid|\\~string#123", buffer)); + BOOST_CHECK(buffer.empty()); + BOOST_CHECK(!StringToSwTraceString<SwTraceNameCharPolicy>("€£", buffer)); + BOOST_CHECK(buffer.empty()); + BOOST_CHECK(!StringToSwTraceString<SwTraceNameCharPolicy>("invalid‡string", buffer)); + BOOST_CHECK(buffer.empty()); + BOOST_CHECK(!StringToSwTraceString<SwTraceNameCharPolicy>("12Ž34", buffer)); + BOOST_CHECK(buffer.empty()); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/src/profiling/test/SendCounterPacketTests.cpp b/src/profiling/test/SendCounterPacketTests.cpp index c060f168cd..f32c3684c1 100644 --- a/src/profiling/test/SendCounterPacketTests.cpp +++ b/src/profiling/test/SendCounterPacketTests.cpp @@ -5,11 +5,12 @@ #include "SendCounterPacketTests.hpp" -#include <ProfilingUtils.hpp> #include <EncodeVersion.hpp> +#include <ProfilingUtils.hpp> #include <SendCounterPacket.hpp> #include <armnn/Exceptions.hpp> +#include <armnn/Conversion.hpp> #include <boost/test/unit_test.hpp> #include <boost/numeric/conversion/cast.hpp> @@ -17,6 +18,8 @@ #include <chrono> #include <iostream> +using namespace armnn::profiling; + BOOST_AUTO_TEST_SUITE(SendCounterPacketTests) BOOST_AUTO_TEST_CASE(MockSendCounterPacketTest) @@ -48,7 +51,6 @@ BOOST_AUTO_TEST_CASE(MockSendCounterPacketTest) sendCounterPacket.SendPeriodicCounterSelectionPacket(capturePeriod, selectedCounterIds); BOOST_TEST(strcmp(buffer, "SendPeriodicCounterSelectionPacket") == 0); - } BOOST_AUTO_TEST_CASE(SendPeriodicCounterSelectionPacketTest) @@ -309,5 +311,1346 @@ BOOST_AUTO_TEST_CASE(SendStreamMetaDataPacketTest) BOOST_TEST(offset == totalLength); } +BOOST_AUTO_TEST_CASE(CreateDeviceRecordTest) +{ + MockBuffer mockBuffer(0); + SendCounterPacketTest sendCounterPacketTest(mockBuffer); + + // Create a device for testing + uint16_t deviceUid = 27; + const std::string deviceName = "some_device"; + uint16_t deviceCores = 3; + const DevicePtr device = std::make_unique<Device>(deviceUid, deviceName, deviceCores); + + // Create a device record + SendCounterPacket::DeviceRecord deviceRecord; + std::string errorMessage; + bool result = sendCounterPacketTest.CreateDeviceRecordTest(device, deviceRecord, errorMessage); + + BOOST_CHECK(result); + BOOST_CHECK(errorMessage.empty()); + BOOST_CHECK(deviceRecord.size() == 6); // Size in words: header [2] + device name [4] + + uint16_t deviceRecordWord0[] + { + static_cast<uint16_t>(deviceRecord[0] >> 16), + static_cast<uint16_t>(deviceRecord[0]) + }; + BOOST_CHECK(deviceRecordWord0[0] == deviceUid); // uid + BOOST_CHECK(deviceRecordWord0[1] == deviceCores); // cores + BOOST_CHECK(deviceRecord[1] == 0); // name_offset + BOOST_CHECK(deviceRecord[2] == deviceName.size() + 1); // The length of the SWTrace string (name) + BOOST_CHECK(std::memcmp(deviceRecord.data() + 3, deviceName.data(), deviceName.size()) == 0); // name +} + +BOOST_AUTO_TEST_CASE(CreateInvalidDeviceRecordTest) +{ + MockBuffer mockBuffer(0); + SendCounterPacketTest sendCounterPacketTest(mockBuffer); + + // Create a device for testing + uint16_t deviceUid = 27; + const std::string deviceName = "some€£invalid‡device"; + uint16_t deviceCores = 3; + const DevicePtr device = std::make_unique<Device>(deviceUid, deviceName, deviceCores); + + // Create a device record + SendCounterPacket::DeviceRecord deviceRecord; + std::string errorMessage; + bool result = sendCounterPacketTest.CreateDeviceRecordTest(device, deviceRecord, errorMessage); + + BOOST_CHECK(!result); + BOOST_CHECK(!errorMessage.empty()); + BOOST_CHECK(deviceRecord.empty()); +} + +BOOST_AUTO_TEST_CASE(CreateCounterSetRecordTest) +{ + MockBuffer mockBuffer(0); + SendCounterPacketTest sendCounterPacketTest(mockBuffer); + + // Create a counter set for testing + uint16_t counterSetUid = 27; + const std::string counterSetName = "some_counter_set"; + uint16_t counterSetCount = 3421; + const CounterSetPtr counterSet = std::make_unique<CounterSet>(counterSetUid, counterSetName, counterSetCount); + + // Create a counter set record + SendCounterPacket::CounterSetRecord counterSetRecord; + std::string errorMessage; + bool result = sendCounterPacketTest.CreateCounterSetRecordTest(counterSet, counterSetRecord, errorMessage); + + BOOST_CHECK(result); + BOOST_CHECK(errorMessage.empty()); + BOOST_CHECK(counterSetRecord.size() == 8); // Size in words: header [2] + counter set name [6] + + uint16_t counterSetRecordWord0[] + { + static_cast<uint16_t>(counterSetRecord[0] >> 16), + static_cast<uint16_t>(counterSetRecord[0]) + }; + BOOST_CHECK(counterSetRecordWord0[0] == counterSetUid); // uid + BOOST_CHECK(counterSetRecordWord0[1] == counterSetCount); // cores + BOOST_CHECK(counterSetRecord[1] == 0); // name_offset + BOOST_CHECK(counterSetRecord[2] == counterSetName.size() + 1); // The length of the SWTrace string (name) + BOOST_CHECK(std::memcmp(counterSetRecord.data() + 3, counterSetName.data(), counterSetName.size()) == 0); // name +} + +BOOST_AUTO_TEST_CASE(CreateInvalidCounterSetRecordTest) +{ + MockBuffer mockBuffer(0); + SendCounterPacketTest sendCounterPacketTest(mockBuffer); + + // Create a counter set for testing + uint16_t counterSetUid = 27; + const std::string counterSetName = "some invalid_counter€£set"; + uint16_t counterSetCount = 3421; + const CounterSetPtr counterSet = std::make_unique<CounterSet>(counterSetUid, counterSetName, counterSetCount); + + // Create a counter set record + SendCounterPacket::CounterSetRecord counterSetRecord; + std::string errorMessage; + bool result = sendCounterPacketTest.CreateCounterSetRecordTest(counterSet, counterSetRecord, errorMessage); + + BOOST_CHECK(!result); + BOOST_CHECK(!errorMessage.empty()); + BOOST_CHECK(counterSetRecord.empty()); +} + +BOOST_AUTO_TEST_CASE(CreateEventRecordTest) +{ + MockBuffer mockBuffer(0); + SendCounterPacketTest sendCounterPacketTest(mockBuffer); + + // Create a counter for testing + uint16_t counterUid = 7256; + uint16_t maxCounterUid = 132; + uint16_t deviceUid = 132; + uint16_t counterSetUid = 4497; + uint16_t counterClass = 1; + uint16_t counterInterpolation = 1; + double counterMultiplier = 1234.567f; + const std::string counterName = "some_valid_counter"; + const std::string counterDescription = "a_counter_for_testing"; + const std::string counterUnits = "Mrads2"; + const CounterPtr counter = std::make_unique<Counter>(counterUid, + maxCounterUid, + counterClass, + counterInterpolation, + counterMultiplier, + counterName, + counterDescription, + counterUnits, + deviceUid, + counterSetUid); + BOOST_ASSERT(counter); + + // Create an event record + SendCounterPacket::EventRecord eventRecord; + std::string errorMessage; + bool result = sendCounterPacketTest.CreateEventRecordTest(counter, eventRecord, errorMessage); + + BOOST_CHECK(result); + BOOST_CHECK(errorMessage.empty()); + BOOST_CHECK(eventRecord.size() == 24); // Size in words: header [8] + counter name [6] + description [7] + units [3] + + uint16_t eventRecordWord0[] + { + static_cast<uint16_t>(eventRecord[0] >> 16), + static_cast<uint16_t>(eventRecord[0]) + }; + uint16_t eventRecordWord1[] + { + static_cast<uint16_t>(eventRecord[1] >> 16), + static_cast<uint16_t>(eventRecord[1]) + }; + uint16_t eventRecordWord2[] + { + static_cast<uint16_t>(eventRecord[2] >> 16), + static_cast<uint16_t>(eventRecord[2]) + }; + uint32_t eventRecordWord34[] + { + eventRecord[3], + eventRecord[4] + }; + BOOST_CHECK(eventRecordWord0[0] == maxCounterUid); // max_counter_uid + BOOST_CHECK(eventRecordWord0[1] == counterUid); // counter_uid + BOOST_CHECK(eventRecordWord1[0] == deviceUid); // device + BOOST_CHECK(eventRecordWord1[1] == counterSetUid); // counter_set + BOOST_CHECK(eventRecordWord2[0] == counterClass); // class + BOOST_CHECK(eventRecordWord2[1] == counterInterpolation); // interpolation + BOOST_CHECK(std::memcmp(eventRecordWord34, &counterMultiplier, sizeof(counterMultiplier)) == 0); // multiplier + + ARMNN_NO_CONVERSION_WARN_BEGIN + uint32_t counterNameOffset = 0; // The name is the first item in pool + uint32_t counterDescriptionOffset = counterNameOffset + // Counter name offset + 4u + // Counter name length (uint32_t) + counterName.size() + // 18u + 1u + // Null-terminator + 1u; // Rounding to the next word + size_t counterUnitsOffset = counterDescriptionOffset + // Counter description offset + 4u + // Counter description length (uint32_t) + counterDescription.size() + // 21u + 1u + // Null-terminator + 2u; // Rounding to the next word + ARMNN_NO_CONVERSION_WARN_END + + BOOST_CHECK(eventRecord[5] == counterNameOffset); // name_offset + BOOST_CHECK(eventRecord[6] == counterDescriptionOffset); // description_offset + BOOST_CHECK(eventRecord[7] == counterUnitsOffset); // units_offset + + auto eventRecordPool = reinterpret_cast<unsigned char*>(eventRecord.data() + 8u); // The start of the pool + size_t uint32_t_size = sizeof(uint32_t); + + // The length of the SWTrace string (name) + BOOST_CHECK(eventRecordPool[counterNameOffset] == counterName.size() + 1); + // The counter name + BOOST_CHECK(std::memcmp(eventRecordPool + + counterNameOffset + // Offset + uint32_t_size /* The length of the name */, + counterName.data(), + counterName.size()) == 0); // name + // The null-terminator at the end of the name + BOOST_CHECK(eventRecordPool[counterNameOffset + uint32_t_size + counterName.size()] == '\0'); + + // The length of the SWTrace string (description) + BOOST_CHECK(eventRecordPool[counterDescriptionOffset] == counterDescription.size() + 1); + // The counter description + BOOST_CHECK(std::memcmp(eventRecordPool + + counterDescriptionOffset + // Offset + uint32_t_size /* The length of the description */, + counterDescription.data(), + counterDescription.size()) == 0); // description + // The null-terminator at the end of the description + BOOST_CHECK(eventRecordPool[counterDescriptionOffset + uint32_t_size + counterDescription.size()] == '\0'); + + // The length of the SWTrace namestring (units) + BOOST_CHECK(eventRecordPool[counterUnitsOffset] == counterUnits.size() + 1); + // The counter units + BOOST_CHECK(std::memcmp(eventRecordPool + + counterUnitsOffset + // Offset + uint32_t_size /* The length of the units */, + counterUnits.data(), + counterUnits.size()) == 0); // units + // The null-terminator at the end of the units + BOOST_CHECK(eventRecordPool[counterUnitsOffset + uint32_t_size + counterUnits.size()] == '\0'); +} + +BOOST_AUTO_TEST_CASE(CreateEventRecordNoUnitsTest) +{ + MockBuffer mockBuffer(0); + SendCounterPacketTest sendCounterPacketTest(mockBuffer); + + // Create a counter for testing + uint16_t counterUid = 44312; + uint16_t maxCounterUid = 345; + uint16_t deviceUid = 101; + uint16_t counterSetUid = 34035; + uint16_t counterClass = 0; + uint16_t counterInterpolation = 1; + double counterMultiplier = 4435.0023f; + const std::string counterName = "some_valid_counter"; + const std::string counterDescription = "a_counter_for_testing"; + const CounterPtr counter = std::make_unique<Counter>(counterUid, + maxCounterUid, + counterClass, + counterInterpolation, + counterMultiplier, + counterName, + counterDescription, + "", + deviceUid, + counterSetUid); + BOOST_ASSERT(counter); + + // Create an event record + SendCounterPacket::EventRecord eventRecord; + std::string errorMessage; + bool result = sendCounterPacketTest.CreateEventRecordTest(counter, eventRecord, errorMessage); + + BOOST_CHECK(result); + BOOST_CHECK(errorMessage.empty()); + BOOST_CHECK(eventRecord.size() == 21); // Size in words: header [8] + counter name [6] + description [7] + + uint16_t eventRecordWord0[] + { + static_cast<uint16_t>(eventRecord[0] >> 16), + static_cast<uint16_t>(eventRecord[0]) + }; + uint16_t eventRecordWord1[] + { + static_cast<uint16_t>(eventRecord[1] >> 16), + static_cast<uint16_t>(eventRecord[1]) + }; + uint16_t eventRecordWord2[] + { + static_cast<uint16_t>(eventRecord[2] >> 16), + static_cast<uint16_t>(eventRecord[2]) + }; + uint32_t eventRecordWord34[] + { + eventRecord[3], + eventRecord[4] + }; + BOOST_CHECK(eventRecordWord0[0] == maxCounterUid); // max_counter_uid + BOOST_CHECK(eventRecordWord0[1] == counterUid); // counter_uid + BOOST_CHECK(eventRecordWord1[0] == deviceUid); // device + BOOST_CHECK(eventRecordWord1[1] == counterSetUid); // counter_set + BOOST_CHECK(eventRecordWord2[0] == counterClass); // class + BOOST_CHECK(eventRecordWord2[1] == counterInterpolation); // interpolation + BOOST_CHECK(std::memcmp(eventRecordWord34, &counterMultiplier, sizeof(counterMultiplier)) == 0); // multiplier + + ARMNN_NO_CONVERSION_WARN_BEGIN + uint32_t counterNameOffset = 0; // The name is the first item in pool + uint32_t counterDescriptionOffset = counterNameOffset + // Counter name offset + 4u + // Counter name length (uint32_t) + counterName.size() + // 18u + 1u + // Null-terminator + 1u; // Rounding to the next word + ARMNN_NO_CONVERSION_WARN_END + + BOOST_CHECK(eventRecord[5] == counterNameOffset); // name_offset + BOOST_CHECK(eventRecord[6] == counterDescriptionOffset); // description_offset + BOOST_CHECK(eventRecord[7] == 0); // units_offset + + auto eventRecordPool = reinterpret_cast<unsigned char*>(eventRecord.data() + 8u); // The start of the pool + size_t uint32_t_size = sizeof(uint32_t); + + // The length of the SWTrace string (name) + BOOST_CHECK(eventRecordPool[counterNameOffset] == counterName.size() + 1); + // The counter name + BOOST_CHECK(std::memcmp(eventRecordPool + + counterNameOffset + // Offset + uint32_t_size, // The length of the name + counterName.data(), + counterName.size()) == 0); // name + // The null-terminator at the end of the name + BOOST_CHECK(eventRecordPool[counterNameOffset + uint32_t_size + counterName.size()] == '\0'); + + // The length of the SWTrace string (description) + BOOST_CHECK(eventRecordPool[counterDescriptionOffset] == counterDescription.size() + 1); + // The counter description + BOOST_CHECK(std::memcmp(eventRecordPool + + counterDescriptionOffset + // Offset + uint32_t_size, // The length of the description + counterDescription.data(), + counterDescription.size()) == 0); // description + // The null-terminator at the end of the description + BOOST_CHECK(eventRecordPool[counterDescriptionOffset + uint32_t_size + counterDescription.size()] == '\0'); +} + +BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest1) +{ + MockBuffer mockBuffer(0); + SendCounterPacketTest sendCounterPacketTest(mockBuffer); + + // Create a counter for testing + uint16_t counterUid = 7256; + uint16_t maxCounterUid = 132; + uint16_t deviceUid = 132; + uint16_t counterSetUid = 4497; + uint16_t counterClass = 1; + uint16_t counterInterpolation = 1; + double counterMultiplier = 1234.567f; + const std::string counterName = "some_invalid_counter £££"; // Invalid name + const std::string counterDescription = "a_counter_for_testing"; + const std::string counterUnits = "Mrads2"; + const CounterPtr counter = std::make_unique<Counter>(counterUid, + maxCounterUid, + counterClass, + counterInterpolation, + counterMultiplier, + counterName, + counterDescription, + counterUnits, + deviceUid, + counterSetUid); + BOOST_ASSERT(counter); + + // Create an event record + SendCounterPacket::EventRecord eventRecord; + std::string errorMessage; + bool result = sendCounterPacketTest.CreateEventRecordTest(counter, eventRecord, errorMessage); + + BOOST_CHECK(!result); + BOOST_CHECK(!errorMessage.empty()); + BOOST_CHECK(eventRecord.empty()); +} + +BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest2) +{ + MockBuffer mockBuffer(0); + SendCounterPacketTest sendCounterPacketTest(mockBuffer); + + // Create a counter for testing + uint16_t counterUid = 7256; + uint16_t maxCounterUid = 132; + uint16_t deviceUid = 132; + uint16_t counterSetUid = 4497; + uint16_t counterClass = 1; + uint16_t counterInterpolation = 1; + double counterMultiplier = 1234.567f; + const std::string counterName = "some_invalid_counter"; + const std::string counterDescription = "an invalid d€scription"; // Invalid description + const std::string counterUnits = "Mrads2"; + const CounterPtr counter = std::make_unique<Counter>(counterUid, + maxCounterUid, + counterClass, + counterInterpolation, + counterMultiplier, + counterName, + counterDescription, + counterUnits, + deviceUid, + counterSetUid); + BOOST_ASSERT(counter); + + // Create an event record + SendCounterPacket::EventRecord eventRecord; + std::string errorMessage; + bool result = sendCounterPacketTest.CreateEventRecordTest(counter, eventRecord, errorMessage); + + BOOST_CHECK(!result); + BOOST_CHECK(!errorMessage.empty()); + BOOST_CHECK(eventRecord.empty()); +} + +BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest3) +{ + MockBuffer mockBuffer(0); + SendCounterPacketTest sendCounterPacketTest(mockBuffer); + + // Create a counter for testing + uint16_t counterUid = 7256; + uint16_t maxCounterUid = 132; + uint16_t deviceUid = 132; + uint16_t counterSetUid = 4497; + uint16_t counterClass = 1; + uint16_t counterInterpolation = 1; + double counterMultiplier = 1234.567f; + const std::string counterName = "some_invalid_counter"; + const std::string counterDescription = "a valid description"; + const std::string counterUnits = "Mrad s2"; // Invalid units + const CounterPtr counter = std::make_unique<Counter>(counterUid, + maxCounterUid, + counterClass, + counterInterpolation, + counterMultiplier, + counterName, + counterDescription, + counterUnits, + deviceUid, + counterSetUid); + BOOST_ASSERT(counter); + + // Create an event record + SendCounterPacket::EventRecord eventRecord; + std::string errorMessage; + bool result = sendCounterPacketTest.CreateEventRecordTest(counter, eventRecord, errorMessage); + + BOOST_CHECK(!result); + BOOST_CHECK(!errorMessage.empty()); + BOOST_CHECK(eventRecord.empty()); +} + +BOOST_AUTO_TEST_CASE(CreateCategoryRecordTest) +{ + MockBuffer mockBuffer(0); + SendCounterPacketTest sendCounterPacketTest(mockBuffer); + + // Create a category for testing + const std::string categoryName = "some_category"; + uint16_t deviceUid = 1302; + uint16_t counterSetUid = 20734; + const CategoryPtr category = std::make_unique<Category>(categoryName, deviceUid, counterSetUid); + BOOST_ASSERT(category); + category->m_Counters = { 11u, 23u, 5670u }; + + // Create a collection of counters + Counters counters; + counters.insert(std::make_pair<uint16_t, CounterPtr>(11, + CounterPtr(new Counter(11, + 1234, + 0, + 1, + 534.0003f, + "counter1", + "the first counter", + "millipi2", + 0, + 0)))); + counters.insert(std::make_pair<uint16_t, CounterPtr>(23, + CounterPtr(new Counter(23, + 344, + 1, + 1, + 534.0003f, + "this is counter 2", + "the second counter", + "", + 0, + 0)))); + counters.insert(std::make_pair<uint16_t, CounterPtr>(5670, + CounterPtr(new Counter(5670, + 31, + 0, + 0, + 534.0003f, + "and this is number 3", + "the third counter", + "blah_per_second", + 0, + 0)))); + Counter* counter1 = counters.find(11)->second.get(); + Counter* counter2 = counters.find(23)->second.get(); + Counter* counter3 = counters.find(5670)->second.get(); + BOOST_ASSERT(counter1); + BOOST_ASSERT(counter2); + BOOST_ASSERT(counter3); + uint16_t categoryEventCount = boost::numeric_cast<uint16_t>(counters.size()); + + // Create a category record + SendCounterPacket::CategoryRecord categoryRecord; + std::string errorMessage; + bool result = sendCounterPacketTest.CreateCategoryRecordTest(category, counters, categoryRecord, errorMessage); + + BOOST_CHECK(result); + BOOST_CHECK(errorMessage.empty()); + BOOST_CHECK(categoryRecord.size() == 80); // Size in words: header [4] + event pointer table [3] + + // category name [5] + event records [68 = 22 + 20 + 26] + + uint16_t categoryRecordWord0[] + { + static_cast<uint16_t>(categoryRecord[0] >> 16), + static_cast<uint16_t>(categoryRecord[0]) + }; + uint16_t categoryRecordWord1[] + { + static_cast<uint16_t>(categoryRecord[1] >> 16), + static_cast<uint16_t>(categoryRecord[1]) + }; + BOOST_CHECK(categoryRecordWord0[0] == deviceUid); // device + BOOST_CHECK(categoryRecordWord0[1] == counterSetUid); // counter_set + BOOST_CHECK(categoryRecordWord1[0] == categoryEventCount); // event_count + BOOST_CHECK(categoryRecordWord1[1] == 0); // reserved + + size_t uint32_t_size = sizeof(uint32_t); + + ARMNN_NO_CONVERSION_WARN_BEGIN + uint32_t eventPointerTableOffset = 0; // The event pointer table is the first item in pool + uint32_t categoryNameOffset = eventPointerTableOffset + // Event pointer table offset + categoryEventCount * uint32_t_size; // The size of the event pointer table + ARMNN_NO_CONVERSION_WARN_END + + BOOST_CHECK(categoryRecord[2] == eventPointerTableOffset); // event_pointer_table_offset + BOOST_CHECK(categoryRecord[3] == categoryNameOffset); // name_offset + + auto categoryRecordPool = reinterpret_cast<unsigned char*>(categoryRecord.data() + 4u); // The start of the pool + + // The event pointer table + uint32_t eventRecord0Offset = categoryRecordPool[eventPointerTableOffset + 0 * uint32_t_size]; + uint32_t eventRecord1Offset = categoryRecordPool[eventPointerTableOffset + 1 * uint32_t_size]; + uint32_t eventRecord2Offset = categoryRecordPool[eventPointerTableOffset + 2 * uint32_t_size]; + BOOST_CHECK(eventRecord0Offset == 32); + BOOST_CHECK(eventRecord1Offset == 120); + BOOST_CHECK(eventRecord2Offset == 200); + + // The length of the SWTrace namestring (name) + BOOST_CHECK(categoryRecordPool[categoryNameOffset] == categoryName.size() + 1); + // The category name + BOOST_CHECK(std::memcmp(categoryRecordPool + + categoryNameOffset + // Offset + uint32_t_size, // The length of the name + categoryName.data(), + categoryName.size()) == 0); // name + // The null-terminator at the end of the name + BOOST_CHECK(categoryRecordPool[categoryNameOffset + uint32_t_size + categoryName.size()] == '\0'); + + // For brevity, checking only the UIDs, max counter UIDs and names of the counters in the event records, + // as the event records already have a number of unit tests dedicated to them + + // Counter1 UID and max counter UID + uint16_t eventRecord0Word0[2] = { 0u, 0u }; + std::memcpy(eventRecord0Word0, categoryRecordPool + eventRecord0Offset, sizeof(eventRecord0Word0)); + BOOST_CHECK(eventRecord0Word0[0] == counter1->m_Uid); + BOOST_CHECK(eventRecord0Word0[1] == counter1->m_MaxCounterUid); + + // Counter1 name + uint32_t counter1NameOffset = 0; + std::memcpy(&counter1NameOffset, categoryRecordPool + eventRecord0Offset + 5u * uint32_t_size, uint32_t_size); + BOOST_CHECK(counter1NameOffset == 0); + // The length of the SWTrace string (name) + BOOST_CHECK(categoryRecordPool[eventRecord0Offset + // Offset to the event record + 8u * uint32_t_size + // Offset to the event record pool + counter1NameOffset // Offset to the name of the counter + ] == counter1->m_Name.size() + 1); // The length of the name including the + // null-terminator + // The counter1 name + BOOST_CHECK(std::memcmp(categoryRecordPool + // The beginning of the category pool + eventRecord0Offset + // Offset to the event record + 8u * uint32_t_size + // Offset to the event record pool + counter1NameOffset + // Offset to the name of the counter + uint32_t_size, // The length of the name + counter1->m_Name.data(), + counter1->m_Name.size()) == 0); // name + // The null-terminator at the end of the counter1 name + BOOST_CHECK(categoryRecordPool[eventRecord0Offset + // Offset to the event record + 8u * uint32_t_size + // Offset to the event record pool + counter1NameOffset + // Offset to the name of the counter + uint32_t_size + // The length of the name + counter1->m_Name.size() // The name of the counter + ] == '\0'); + + // Counter2 name + uint32_t counter2NameOffset = 0; + std::memcpy(&counter2NameOffset, categoryRecordPool + eventRecord1Offset + 5u * uint32_t_size, uint32_t_size); + BOOST_CHECK(counter2NameOffset == 0); + // The length of the SWTrace string (name) + BOOST_CHECK(categoryRecordPool[eventRecord1Offset + // Offset to the event record + 8u * uint32_t_size + // Offset to the event record pool + counter2NameOffset // Offset to the name of the counter + ] == counter2->m_Name.size() + 1); // The length of the name including the + // null-terminator + // The counter2 name + BOOST_CHECK(std::memcmp(categoryRecordPool + // The beginning of the category pool + eventRecord1Offset + // Offset to the event record + 8u * uint32_t_size + // Offset to the event record pool + counter2NameOffset + // Offset to the name of the counter + uint32_t_size, // The length of the name + counter2->m_Name.data(), + counter2->m_Name.size()) == 0); // name + // The null-terminator at the end of the counter2 name + BOOST_CHECK(categoryRecordPool[eventRecord1Offset + // Offset to the event record + 8u * uint32_t_size + // Offset to the event record pool + counter2NameOffset + // Offset to the name of the counter + uint32_t_size + // The length of the name + counter2->m_Name.size() // The name of the counter + ] == '\0'); + + // Counter3 name + uint32_t counter3NameOffset = 0; + std::memcpy(&counter3NameOffset, categoryRecordPool + eventRecord2Offset + 5u * uint32_t_size, uint32_t_size); + BOOST_CHECK(counter3NameOffset == 0); + // The length of the SWTrace string (name) + BOOST_CHECK(categoryRecordPool[eventRecord2Offset + // Offset to the event record + 8u * uint32_t_size + // Offset to the event record pool + counter3NameOffset // Offset to the name of the counter + ] == counter3->m_Name.size() + 1); // The length of the name including the + // null-terminator + // The counter3 name + BOOST_CHECK(std::memcmp(categoryRecordPool + // The beginning of the category pool + eventRecord2Offset + // Offset to the event record + 8u * uint32_t_size + // Offset to the event record pool + counter3NameOffset + // Offset to the name of the counter + uint32_t_size, // The length of the name + counter3->m_Name.data(), + counter3->m_Name.size()) == 0); // name + // The null-terminator at the end of the counter3 name + BOOST_CHECK(categoryRecordPool[eventRecord2Offset + // Offset to the event record + 8u * uint32_t_size + // Offset to the event record pool + counter3NameOffset + // Offset to the name of the counter + uint32_t_size + // The length of the name + counter3->m_Name.size() // The name of the counter + ] == '\0'); +} + +BOOST_AUTO_TEST_CASE(CreateInvalidCategoryRecordTest1) +{ + MockBuffer mockBuffer(0); + SendCounterPacketTest sendCounterPacketTest(mockBuffer); + + // Create a category for testing + const std::string categoryName = "some invalid category"; + uint16_t deviceUid = 1302; + uint16_t counterSetUid = 20734; + const CategoryPtr category = std::make_unique<Category>(categoryName, deviceUid, counterSetUid); + BOOST_CHECK(category); + + // Create a category record + Counters counters; + SendCounterPacket::CategoryRecord categoryRecord; + std::string errorMessage; + bool result = sendCounterPacketTest.CreateCategoryRecordTest(category, counters, categoryRecord, errorMessage); + + BOOST_CHECK(!result); + BOOST_CHECK(!errorMessage.empty()); + BOOST_CHECK(categoryRecord.empty()); +} + +BOOST_AUTO_TEST_CASE(CreateInvalidCategoryRecordTest2) +{ + MockBuffer mockBuffer(0); + SendCounterPacketTest sendCounterPacketTest(mockBuffer); + + // Create a category for testing + const std::string categoryName = "some_category"; + uint16_t deviceUid = 1302; + uint16_t counterSetUid = 20734; + const CategoryPtr category = std::make_unique<Category>(categoryName, deviceUid, counterSetUid); + BOOST_CHECK(category); + category->m_Counters = { 11u, 23u, 5670u }; + + // Create a collection of counters + Counters counters; + counters.insert(std::make_pair<uint16_t, CounterPtr>(11, + CounterPtr(new Counter(11, + 1234, + 0, + 1, + 534.0003f, + "count€r1", // Invalid name + "the first counter", + "millipi2", + 0, + 0)))); + + Counter* counter1 = counters.find(11)->second.get(); + BOOST_CHECK(counter1); + + // Create a category record + SendCounterPacket::CategoryRecord categoryRecord; + std::string errorMessage; + bool result = sendCounterPacketTest.CreateCategoryRecordTest(category, counters, categoryRecord, errorMessage); + + BOOST_CHECK(!result); + BOOST_CHECK(!errorMessage.empty()); + BOOST_CHECK(categoryRecord.empty()); +} + +BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest1) +{ + // The counter directory used for testing + CounterDirectory counterDirectory; + + // Register a device + const std::string device1Name = "device1"; + const Device* device1 = nullptr; + BOOST_CHECK_NO_THROW(device1 = counterDirectory.RegisterDevice(device1Name, 3)); + BOOST_CHECK(counterDirectory.GetDeviceCount() == 1); + BOOST_CHECK(device1); + + // Register a device + const std::string device2Name = "device2"; + const Device* device2 = nullptr; + BOOST_CHECK_NO_THROW(device2 = counterDirectory.RegisterDevice(device2Name)); + BOOST_CHECK(counterDirectory.GetDeviceCount() == 2); + BOOST_CHECK(device2); + + // Buffer with not enough space + MockBuffer mockBuffer(10); + SendCounterPacket sendCounterPacket(mockBuffer); + BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), + armnn::profiling::BufferExhaustion); +} + +BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) +{ + // The counter directory used for testing + CounterDirectory counterDirectory; + + // Register a device + const std::string device1Name = "device1"; + const Device* device1 = nullptr; + BOOST_CHECK_NO_THROW(device1 = counterDirectory.RegisterDevice(device1Name, 3)); + BOOST_CHECK(counterDirectory.GetDeviceCount() == 1); + BOOST_CHECK(device1); + + // Register a device + const std::string device2Name = "device2"; + const Device* device2 = nullptr; + BOOST_CHECK_NO_THROW(device2 = counterDirectory.RegisterDevice(device2Name)); + BOOST_CHECK(counterDirectory.GetDeviceCount() == 2); + BOOST_CHECK(device2); + + // Register a counter set + const std::string counterSet1Name = "counterset1"; + const CounterSet* counterSet1 = nullptr; + BOOST_CHECK_NO_THROW(counterSet1 = counterDirectory.RegisterCounterSet(counterSet1Name)); + BOOST_CHECK(counterDirectory.GetCounterSetCount() == 1); + BOOST_CHECK(counterSet1); + + // Register a category associated to "device1" and "counterset1" + const std::string category1Name = "category1"; + const Category* category1 = nullptr; + BOOST_CHECK_NO_THROW(category1 = counterDirectory.RegisterCategory(category1Name, + device1->m_Uid, + counterSet1->m_Uid)); + BOOST_CHECK(counterDirectory.GetCategoryCount() == 1); + BOOST_CHECK(category1); + + // Register a category not associated to "device2" but no counter set + const std::string category2Name = "category2"; + const Category* category2 = nullptr; + BOOST_CHECK_NO_THROW(category2 = counterDirectory.RegisterCategory(category2Name, + device2->m_Uid)); + BOOST_CHECK(counterDirectory.GetCategoryCount() == 2); + BOOST_CHECK(category2); + + // Register a counter associated to "category1" + const Counter* counter1 = nullptr; + BOOST_CHECK_NO_THROW(counter1 = counterDirectory.RegisterCounter(category1Name, + 0, + 1, + 123.45f, + "counter1", + "counter1description", + std::string("counter1units"))); + BOOST_CHECK(counterDirectory.GetCounterCount() == 3); + BOOST_CHECK(counter1); + + // Register a counter associated to "category1" + const Counter* counter2 = nullptr; + BOOST_CHECK_NO_THROW(counter2 = counterDirectory.RegisterCounter(category1Name, + 1, + 0, + 330.1245656765f, + "counter2", + "counter2description", + std::string("counter2units"), + armnn::EmptyOptional(), + device2->m_Uid, + 0)); + BOOST_CHECK(counterDirectory.GetCounterCount() == 4); + BOOST_CHECK(counter2); + + // Register a counter associated to "category2" + const Counter* counter3 = nullptr; + BOOST_CHECK_NO_THROW(counter3 = counterDirectory.RegisterCounter(category2Name, + 1, + 1, + 0.0000045399f, + "counter3", + "counter3description", + armnn::EmptyOptional(), + 5, + device2->m_Uid, + counterSet1->m_Uid)); + BOOST_CHECK(counterDirectory.GetCounterCount() == 9); + BOOST_CHECK(counter3); + + // Buffer with enough space + MockBuffer mockBuffer(1024); + SendCounterPacket sendCounterPacket(mockBuffer); + BOOST_CHECK_NO_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory)); + + // Get the read buffer + unsigned int sizeRead = 0; + const unsigned char* readBuffer = mockBuffer.GetReadBuffer(sizeRead); + + // Check the packet header + uint32_t packetHeaderWord0 = ReadUint32(readBuffer, 0); + uint32_t packetHeaderWord1 = ReadUint32(readBuffer, 4); + BOOST_TEST(((packetHeaderWord0 >> 26) & 0x3F) == 0); // packet_family + BOOST_TEST(((packetHeaderWord0 >> 16) & 0x3FF) == 2); // packet_id + BOOST_TEST(packetHeaderWord1 == 944); // data_length + + // Check the body header + uint32_t bodyHeaderWord0 = ReadUint32(readBuffer, 8); + uint32_t bodyHeaderWord1 = ReadUint32(readBuffer, 12); + uint32_t bodyHeaderWord2 = ReadUint32(readBuffer, 16); + uint32_t bodyHeaderWord3 = ReadUint32(readBuffer, 20); + uint32_t bodyHeaderWord4 = ReadUint32(readBuffer, 24); + uint32_t bodyHeaderWord5 = ReadUint32(readBuffer, 28); + uint16_t deviceRecordCount = static_cast<uint16_t>(bodyHeaderWord0 >> 16); + uint16_t counterSetRecordCount = static_cast<uint16_t>(bodyHeaderWord2 >> 16); + uint16_t categoryRecordCount = static_cast<uint16_t>(bodyHeaderWord4 >> 16); + BOOST_TEST(deviceRecordCount == 2); // device_records_count + BOOST_TEST(bodyHeaderWord1 == 0); // device_records_pointer_table_offset + BOOST_TEST(counterSetRecordCount == 1); // counter_set_count + BOOST_TEST(bodyHeaderWord3 == 8); // counter_set_pointer_table_offset + BOOST_TEST(categoryRecordCount == 2); // categories_count + BOOST_TEST(bodyHeaderWord5 == 12); // categories_pointer_table_offset + + // Check the device records pointer table + uint32_t deviceRecordOffset0 = ReadUint32(readBuffer, 32); + uint32_t deviceRecordOffset1 = ReadUint32(readBuffer, 36); + BOOST_TEST(deviceRecordOffset0 == 0); // Device record offset for "device1" + BOOST_TEST(deviceRecordOffset1 == 20); // Device record offset for "device2" + + // Check the counter set pointer table + uint32_t counterSetRecordOffset0 = ReadUint32(readBuffer, 40); + BOOST_TEST(counterSetRecordOffset0 == 40); // Counter set record offset for "counterset1" + + // Check the category pointer table + uint32_t categoryRecordOffset0 = ReadUint32(readBuffer, 44); + uint32_t categoryRecordOffset1 = ReadUint32(readBuffer, 48); + BOOST_TEST(categoryRecordOffset0 == 64); // Category record offset for "category1" + BOOST_TEST(categoryRecordOffset1 == 476); // Category record offset for "category2" + + // Get the device record pool offset + uint32_t uint32_t_size = sizeof(uint32_t); + uint32_t packetBodyPoolOffset = 2u * uint32_t_size + // packet_header + 6u * uint32_t_size + // body_header + deviceRecordCount * uint32_t_size + // Size of device_records_pointer_table + counterSetRecordCount * uint32_t_size + // Size of counter_set_pointer_table + categoryRecordCount * uint32_t_size; // Size of categories_pointer_table + + // Device record structure/collection used for testing + struct DeviceRecord + { + uint16_t uid; + uint16_t cores; + uint32_t name_offset; + uint32_t name_length; + std::string name; + }; + std::vector<DeviceRecord> deviceRecords; + uint32_t deviceRecordsPointerTableOffset = 2u * uint32_t_size + // packet_header + 6u * uint32_t_size + // body_header + bodyHeaderWord1; // device_records_pointer_table_offset + for (uint32_t i = 0; i < deviceRecordCount; i++) + { + // Get the device record offset + uint32_t deviceRecordOffset = ReadUint32(readBuffer, deviceRecordsPointerTableOffset + i * uint32_t_size); + + // Collect the data for the device record + uint32_t deviceRecordWord0 = ReadUint32(readBuffer, + packetBodyPoolOffset + deviceRecordOffset + 0 * uint32_t_size); + uint32_t deviceRecordWord1 = ReadUint32(readBuffer, + packetBodyPoolOffset + deviceRecordOffset + 1 * uint32_t_size); + DeviceRecord deviceRecord; + deviceRecord.uid = static_cast<uint16_t>(deviceRecordWord0 >> 16); // uid + deviceRecord.cores = static_cast<uint16_t>(deviceRecordWord0); // cores + deviceRecord.name_offset = deviceRecordWord1; // name_offset + + uint32_t deviceRecordPoolOffset = packetBodyPoolOffset + // Packet body offset + deviceRecordOffset + // Device record offset + 2 * uint32_t_size + // Device record header + deviceRecord.name_offset; // Device name offset + uint32_t deviceRecordNameLength = ReadUint32(readBuffer, deviceRecordPoolOffset); + deviceRecord.name_length = deviceRecordNameLength; // name_length + unsigned char deviceRecordNameNullTerminator = // name null-terminator + ReadUint8(readBuffer, deviceRecordPoolOffset + uint32_t_size + deviceRecordNameLength - 1); + BOOST_CHECK(deviceRecordNameNullTerminator == '\0'); + std::vector<unsigned char> deviceRecordNameBuffer(deviceRecord.name_length - 1); + std::memcpy(deviceRecordNameBuffer.data(), + readBuffer + deviceRecordPoolOffset + uint32_t_size, deviceRecordNameBuffer.size()); + deviceRecord.name.assign(deviceRecordNameBuffer.begin(), deviceRecordNameBuffer.end()); // name + + deviceRecords.push_back(deviceRecord); + } + + // Check that the device records are correct + BOOST_CHECK(deviceRecords.size() == 2); + for (const DeviceRecord& deviceRecord : deviceRecords) + { + const Device* device = counterDirectory.GetDevice(deviceRecord.uid); + BOOST_CHECK(device); + BOOST_CHECK(device->m_Uid == deviceRecord.uid); + BOOST_CHECK(device->m_Cores == deviceRecord.cores); + BOOST_CHECK(device->m_Name == deviceRecord.name); + } + + // Counter set record structure/collection used for testing + struct CounterSetRecord + { + uint16_t uid; + uint16_t count; + uint32_t name_offset; + uint32_t name_length; + std::string name; + }; + std::vector<CounterSetRecord> counterSetRecords; + uint32_t counterSetRecordsPointerTableOffset = 2u * uint32_t_size + // packet_header + 6u * uint32_t_size + // body_header + bodyHeaderWord3; // counter_set_pointer_table_offset + for (uint32_t i = 0; i < counterSetRecordCount; i++) + { + // Get the counter set record offset + uint32_t counterSetRecordOffset = ReadUint32(readBuffer, + counterSetRecordsPointerTableOffset + i * uint32_t_size); + + // Collect the data for the counter set record + uint32_t counterSetRecordWord0 = ReadUint32(readBuffer, + packetBodyPoolOffset + counterSetRecordOffset + 0 * uint32_t_size); + uint32_t counterSetRecordWord1 = ReadUint32(readBuffer, + packetBodyPoolOffset + counterSetRecordOffset + 1 * uint32_t_size); + CounterSetRecord counterSetRecord; + counterSetRecord.uid = static_cast<uint16_t>(counterSetRecordWord0 >> 16); // uid + counterSetRecord.count = static_cast<uint16_t>(counterSetRecordWord0); // count + counterSetRecord.name_offset = counterSetRecordWord1; // name_offset + + uint32_t counterSetRecordPoolOffset = packetBodyPoolOffset + // Packet body offset + counterSetRecordOffset + // Counter set record offset + 2 * uint32_t_size + // Counter set record header + counterSetRecord.name_offset; // Counter set name offset + uint32_t counterSetRecordNameLength = ReadUint32(readBuffer, counterSetRecordPoolOffset); + counterSetRecord.name_length = counterSetRecordNameLength; // name_length + unsigned char counterSetRecordNameNullTerminator = // name null-terminator + ReadUint8(readBuffer, counterSetRecordPoolOffset + uint32_t_size + counterSetRecordNameLength - 1); + BOOST_CHECK(counterSetRecordNameNullTerminator == '\0'); + std::vector<unsigned char> counterSetRecordNameBuffer(counterSetRecord.name_length - 1); + std::memcpy(counterSetRecordNameBuffer.data(), + readBuffer + counterSetRecordPoolOffset + uint32_t_size, counterSetRecordNameBuffer.size()); + counterSetRecord.name.assign(counterSetRecordNameBuffer.begin(), counterSetRecordNameBuffer.end()); // name + + counterSetRecords.push_back(counterSetRecord); + } + + // Check that the counter set records are correct + BOOST_CHECK(counterSetRecords.size() == 1); + for (const CounterSetRecord& counterSetRecord : counterSetRecords) + { + const CounterSet* counterSet = counterDirectory.GetCounterSet(counterSetRecord.uid); + BOOST_CHECK(counterSet); + BOOST_CHECK(counterSet->m_Uid == counterSetRecord.uid); + BOOST_CHECK(counterSet->m_Count == counterSetRecord.count); + BOOST_CHECK(counterSet->m_Name == counterSetRecord.name); + } + + // Event record structure/collection used for testing + struct EventRecord + { + uint16_t counter_uid; + uint16_t max_counter_uid; + uint16_t device; + uint16_t counter_set; + uint16_t counter_class; + uint16_t interpolation; + double multiplier; + uint32_t name_offset; + uint32_t name_length; + std::string name; + uint32_t description_offset; + uint32_t description_length; + std::string description; + uint32_t units_offset; + uint32_t units_length; + std::string units; + }; + // Category record structure/collection used for testing + struct CategoryRecord + { + uint16_t device; + uint16_t counter_set; + uint16_t event_count; + uint32_t event_pointer_table_offset; + uint32_t name_offset; + uint32_t name_length; + std::string name; + std::vector<uint32_t> event_pointer_table; + std::vector<EventRecord> event_records; + }; + std::vector<CategoryRecord> categoryRecords; + uint32_t categoryRecordsPointerTableOffset = 2u * uint32_t_size + // packet_header + 6u * uint32_t_size + // body_header + bodyHeaderWord5; // categories_pointer_table_offset + for (uint32_t i = 0; i < categoryRecordCount; i++) + { + // Get the category record offset + uint32_t categoryRecordOffset = ReadUint32(readBuffer, categoryRecordsPointerTableOffset + i * uint32_t_size); + + // Collect the data for the category record + uint32_t categoryRecordWord0 = ReadUint32(readBuffer, + packetBodyPoolOffset + categoryRecordOffset + 0 * uint32_t_size); + uint32_t categoryRecordWord1 = ReadUint32(readBuffer, + packetBodyPoolOffset + categoryRecordOffset + 1 * uint32_t_size); + uint32_t categoryRecordWord2 = ReadUint32(readBuffer, + packetBodyPoolOffset + categoryRecordOffset + 2 * uint32_t_size); + uint32_t categoryRecordWord3 = ReadUint32(readBuffer, + packetBodyPoolOffset + categoryRecordOffset + 3 * uint32_t_size); + CategoryRecord categoryRecord; + categoryRecord.device = static_cast<uint16_t>(categoryRecordWord0 >> 16); // device + categoryRecord.counter_set = static_cast<uint16_t>(categoryRecordWord0); // counter_set + categoryRecord.event_count = static_cast<uint16_t>(categoryRecordWord1 >> 16); // event_count + categoryRecord.event_pointer_table_offset = categoryRecordWord2; // event_pointer_table_offset + categoryRecord.name_offset = categoryRecordWord3; // name_offset + + uint32_t categoryRecordPoolOffset = packetBodyPoolOffset + // Packet body offset + categoryRecordOffset + // Category record offset + 4 * uint32_t_size; // Category record header + + uint32_t categoryRecordNameLength = ReadUint32(readBuffer, + categoryRecordPoolOffset + categoryRecord.name_offset); + categoryRecord.name_length = categoryRecordNameLength; // name_length + unsigned char categoryRecordNameNullTerminator = + ReadUint8(readBuffer, + categoryRecordPoolOffset + + categoryRecord.name_offset + + uint32_t_size + + categoryRecordNameLength - 1); // name null-terminator + BOOST_CHECK(categoryRecordNameNullTerminator == '\0'); + std::vector<unsigned char> categoryRecordNameBuffer(categoryRecord.name_length - 1); + std::memcpy(categoryRecordNameBuffer.data(), + readBuffer + + categoryRecordPoolOffset + + categoryRecord.name_offset + + uint32_t_size, + categoryRecordNameBuffer.size()); + categoryRecord.name.assign(categoryRecordNameBuffer.begin(), categoryRecordNameBuffer.end()); // name + + categoryRecord.event_pointer_table.resize(categoryRecord.event_count); + for (uint32_t eventIndex = 0; eventIndex < categoryRecord.event_count; eventIndex++) + { + uint32_t eventRecordOffset = ReadUint32(readBuffer, + categoryRecordPoolOffset + + categoryRecord.event_pointer_table_offset + + eventIndex * uint32_t_size); + categoryRecord.event_pointer_table[eventIndex] = eventRecordOffset; + + // Collect the data for the event record + uint32_t eventRecordWord0 = ReadUint32(readBuffer, + categoryRecordPoolOffset + eventRecordOffset + 0 * uint32_t_size); + uint32_t eventRecordWord1 = ReadUint32(readBuffer, + categoryRecordPoolOffset + eventRecordOffset + 1 * uint32_t_size); + uint32_t eventRecordWord2 = ReadUint32(readBuffer, + categoryRecordPoolOffset + eventRecordOffset + 2 * uint32_t_size); + uint64_t eventRecordWord34 = ReadUint64(readBuffer, + categoryRecordPoolOffset + eventRecordOffset + 3 * uint32_t_size); + uint32_t eventRecordWord5 = ReadUint32(readBuffer, + categoryRecordPoolOffset + eventRecordOffset + 5 * uint32_t_size); + uint32_t eventRecordWord6 = ReadUint32(readBuffer, + categoryRecordPoolOffset + eventRecordOffset + 6 * uint32_t_size); + uint32_t eventRecordWord7 = ReadUint32(readBuffer, + categoryRecordPoolOffset + eventRecordOffset + 7 * uint32_t_size); + EventRecord eventRecord; + eventRecord.counter_uid = static_cast<uint16_t>(eventRecordWord0); // counter_uid + eventRecord.max_counter_uid = static_cast<uint16_t>(eventRecordWord0 >> 16); // max_counter_uid + eventRecord.device = static_cast<uint16_t>(eventRecordWord1 >> 16); // device + eventRecord.counter_set = static_cast<uint16_t>(eventRecordWord1); // counter_set + eventRecord.counter_class = static_cast<uint16_t>(eventRecordWord2 >> 16); // class + eventRecord.interpolation = static_cast<uint16_t>(eventRecordWord2); // interpolation + std::memcpy(&eventRecord.multiplier, &eventRecordWord34, sizeof(eventRecord.multiplier)); // multiplier + eventRecord.name_offset = static_cast<uint32_t>(eventRecordWord5); // name_offset + eventRecord.description_offset = static_cast<uint32_t>(eventRecordWord6); // description_offset + eventRecord.units_offset = static_cast<uint32_t>(eventRecordWord7); // units_offset + + uint32_t eventRecordPoolOffset = categoryRecordPoolOffset + // Category record pool offset + eventRecordOffset + // Event record offset + 8 * uint32_t_size; // Event record header + + uint32_t eventRecordNameLength = ReadUint32(readBuffer, + eventRecordPoolOffset + eventRecord.name_offset); + eventRecord.name_length = eventRecordNameLength; // name_length + unsigned char eventRecordNameNullTerminator = + ReadUint8(readBuffer, + eventRecordPoolOffset + + eventRecord.name_offset + + uint32_t_size + + eventRecordNameLength - 1); // name null-terminator + BOOST_CHECK(eventRecordNameNullTerminator == '\0'); + std::vector<unsigned char> eventRecordNameBuffer(eventRecord.name_length - 1); + std::memcpy(eventRecordNameBuffer.data(), + readBuffer + + eventRecordPoolOffset + + eventRecord.name_offset + + uint32_t_size, + eventRecordNameBuffer.size()); + eventRecord.name.assign(eventRecordNameBuffer.begin(), eventRecordNameBuffer.end()); // name + + uint32_t eventRecordDescriptionLength = ReadUint32(readBuffer, + eventRecordPoolOffset + eventRecord.description_offset); + eventRecord.description_length = eventRecordDescriptionLength; // description_length + unsigned char eventRecordDescriptionNullTerminator = + ReadUint8(readBuffer, + eventRecordPoolOffset + + eventRecord.description_offset + + uint32_t_size + + eventRecordDescriptionLength - 1); // description null-terminator + BOOST_CHECK(eventRecordDescriptionNullTerminator == '\0'); + std::vector<unsigned char> eventRecordDescriptionBuffer(eventRecord.description_length - 1); + std::memcpy(eventRecordDescriptionBuffer.data(), + readBuffer + + eventRecordPoolOffset + + eventRecord.description_offset + + uint32_t_size, + eventRecordDescriptionBuffer.size()); + eventRecord.description.assign(eventRecordDescriptionBuffer.begin(), + eventRecordDescriptionBuffer.end()); // description + + if (eventRecord.units_offset > 0) + { + uint32_t eventRecordUnitsLength = ReadUint32(readBuffer, + eventRecordPoolOffset + eventRecord.units_offset); + eventRecord.units_length = eventRecordUnitsLength; // units_length + unsigned char eventRecordUnitsNullTerminator = + ReadUint8(readBuffer, + eventRecordPoolOffset + + eventRecord.units_offset + + uint32_t_size + + eventRecordUnitsLength - 1); // units null-terminator + BOOST_CHECK(eventRecordUnitsNullTerminator == '\0'); + std::vector<unsigned char> eventRecordUnitsBuffer(eventRecord.units_length - 1); + std::memcpy(eventRecordUnitsBuffer.data(), + readBuffer + + eventRecordPoolOffset + + eventRecord.units_offset + + uint32_t_size, + eventRecordUnitsBuffer.size()); + eventRecord.units.assign(eventRecordUnitsBuffer.begin(), eventRecordUnitsBuffer.end()); // units + } + + categoryRecord.event_records.push_back(eventRecord); + } + + categoryRecords.push_back(categoryRecord); + } + + // Check that the category records are correct + BOOST_CHECK(categoryRecords.size() == 2); + for (const CategoryRecord& categoryRecord : categoryRecords) + { + const Category* category = counterDirectory.GetCategory(categoryRecord.name); + BOOST_CHECK(category); + BOOST_CHECK(category->m_Name == categoryRecord.name); + BOOST_CHECK(category->m_DeviceUid == categoryRecord.device); + BOOST_CHECK(category->m_CounterSetUid == categoryRecord.counter_set); + BOOST_CHECK(category->m_Counters.size() == categoryRecord.event_count); + + // Check that the event records are correct + for (const EventRecord& eventRecord : categoryRecord.event_records) + { + const Counter* counter = counterDirectory.GetCounter(eventRecord.counter_uid); + BOOST_CHECK(counter); + BOOST_CHECK(counter->m_MaxCounterUid == eventRecord.max_counter_uid); + BOOST_CHECK(counter->m_DeviceUid == eventRecord.device); + BOOST_CHECK(counter->m_CounterSetUid == eventRecord.counter_set); + BOOST_CHECK(counter->m_Class == eventRecord.counter_class); + BOOST_CHECK(counter->m_Interpolation == eventRecord.interpolation); + BOOST_CHECK(counter->m_Multiplier == eventRecord.multiplier); + BOOST_CHECK(counter->m_Name == eventRecord.name); + BOOST_CHECK(counter->m_Description == eventRecord.description); + BOOST_CHECK(counter->m_Units == eventRecord.units); + } + } +} + +BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest3) +{ + // Using a mock counter directory that allows to register invalid objects + MockCounterDirectory counterDirectory; + + // Register an invalid device + const std::string deviceName = "inv@lid dev!c€"; + const Device* device = nullptr; + BOOST_CHECK_NO_THROW(device = counterDirectory.RegisterDevice(deviceName, 3)); + BOOST_CHECK(counterDirectory.GetDeviceCount() == 1); + BOOST_CHECK(device); + + // Buffer with enough space + MockBuffer mockBuffer(1024); + SendCounterPacket sendCounterPacket(mockBuffer); + BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::RuntimeException); +} + +BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest4) +{ + // Using a mock counter directory that allows to register invalid objects + MockCounterDirectory counterDirectory; + + // Register an invalid counter set + const std::string counterSetName = "inv@lid count€rs€t"; + const CounterSet* counterSet = nullptr; + BOOST_CHECK_NO_THROW(counterSet = counterDirectory.RegisterCounterSet(counterSetName)); + BOOST_CHECK(counterDirectory.GetCounterSetCount() == 1); + BOOST_CHECK(counterSet); + + // Buffer with enough space + MockBuffer mockBuffer(1024); + SendCounterPacket sendCounterPacket(mockBuffer); + BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::RuntimeException); +} + +BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest5) +{ + // Using a mock counter directory that allows to register invalid objects + MockCounterDirectory counterDirectory; + + // Register an invalid category + const std::string categoryName = "c@t€gory"; + const Category* category = nullptr; + BOOST_CHECK_NO_THROW(category = counterDirectory.RegisterCategory(categoryName)); + BOOST_CHECK(counterDirectory.GetCategoryCount() == 1); + BOOST_CHECK(category); + + // Buffer with enough space + MockBuffer mockBuffer(1024); + SendCounterPacket sendCounterPacket(mockBuffer); + BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::RuntimeException); +} + +BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest6) +{ + // Using a mock counter directory that allows to register invalid objects + MockCounterDirectory counterDirectory; + + // Register an invalid device + const std::string deviceName = "inv@lid dev!c€"; + const Device* device = nullptr; + BOOST_CHECK_NO_THROW(device = counterDirectory.RegisterDevice(deviceName, 3)); + BOOST_CHECK(counterDirectory.GetDeviceCount() == 1); + BOOST_CHECK(device); + + // Register an invalid counter set + const std::string counterSetName = "inv@lid count€rs€t"; + const CounterSet* counterSet = nullptr; + BOOST_CHECK_NO_THROW(counterSet = counterDirectory.RegisterCounterSet(counterSetName)); + BOOST_CHECK(counterDirectory.GetCounterSetCount() == 1); + BOOST_CHECK(counterSet); + + // Register an invalid category associated to an invalid device and an invalid counter set + const std::string categoryName = "c@t€gory"; + const Category* category = nullptr; + BOOST_CHECK_NO_THROW(category = counterDirectory.RegisterCategory(categoryName, + device->m_Uid, + counterSet->m_Uid)); + BOOST_CHECK(counterDirectory.GetCategoryCount() == 1); + BOOST_CHECK(category); + + // Buffer with enough space + MockBuffer mockBuffer(1024); + SendCounterPacket sendCounterPacket(mockBuffer); + BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::RuntimeException); +} + +BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest7) +{ + // Using a mock counter directory that allows to register invalid objects + MockCounterDirectory counterDirectory; + + // Register an valid device + const std::string deviceName = "valid device"; + const Device* device = nullptr; + BOOST_CHECK_NO_THROW(device = counterDirectory.RegisterDevice(deviceName, 3)); + BOOST_CHECK(counterDirectory.GetDeviceCount() == 1); + BOOST_CHECK(device); + + // Register an valid counter set + const std::string counterSetName = "valid counterset"; + const CounterSet* counterSet = nullptr; + BOOST_CHECK_NO_THROW(counterSet = counterDirectory.RegisterCounterSet(counterSetName)); + BOOST_CHECK(counterDirectory.GetCounterSetCount() == 1); + BOOST_CHECK(counterSet); + + // Register an valid category associated to a valid device and a valid counter set + const std::string categoryName = "category"; + const Category* category = nullptr; + BOOST_CHECK_NO_THROW(category = counterDirectory.RegisterCategory(categoryName, + device->m_Uid, + counterSet->m_Uid)); + BOOST_CHECK(counterDirectory.GetCategoryCount() == 1); + BOOST_CHECK(category); + + // Register an invalid counter associated to a valid category + const Counter* counter = nullptr; + BOOST_CHECK_NO_THROW(counter = counterDirectory.RegisterCounter(categoryName, + 0, + 1, + 123.45f, + "counter", + "counter description", + std::string("invalid counter units"), + 5, + device->m_Uid, + counterSet->m_Uid)); + BOOST_CHECK(counterDirectory.GetCounterCount() == 5); + BOOST_CHECK(counter); + + // Buffer with enough space + MockBuffer mockBuffer(1024); + SendCounterPacket sendCounterPacket(mockBuffer); + BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::RuntimeException); +} BOOST_AUTO_TEST_SUITE_END() diff --git a/src/profiling/test/SendCounterPacketTests.hpp b/src/profiling/test/SendCounterPacketTests.hpp index a22d02bd63..6c7bb50362 100644 --- a/src/profiling/test/SendCounterPacketTests.hpp +++ b/src/profiling/test/SendCounterPacketTests.hpp @@ -66,7 +66,7 @@ public: memcpy(buffer, message.c_str(), static_cast<unsigned int>(message.size()) + 1); } - void SendCounterDirectoryPacket(const CounterDirectory& counterDirectory) override + void SendCounterDirectoryPacket(const ICounterDirectory& counterDirectory) override { std::string message("SendCounterDirectoryPacket"); unsigned int reserved = 0; @@ -99,3 +99,244 @@ public: private: IBufferWrapper& m_Buffer; }; + +class MockCounterDirectory : public ICounterDirectory +{ +public: + MockCounterDirectory() = default; + ~MockCounterDirectory() = default; + + // Register profiling objects + const Category* RegisterCategory(const std::string& categoryName, + const armnn::Optional<uint16_t>& deviceUid = armnn::EmptyOptional(), + const armnn::Optional<uint16_t>& counterSetUid = armnn::EmptyOptional()) + { + // Get the device UID + uint16_t deviceUidValue = deviceUid.has_value() ? deviceUid.value() : 0; + + // Get the counter set UID + uint16_t counterSetUidValue = counterSetUid.has_value() ? counterSetUid.value() : 0; + + // Create the category + CategoryPtr category = std::make_unique<Category>(categoryName, deviceUidValue, counterSetUidValue); + BOOST_ASSERT(category); + + // Get the raw category pointer + const Category* categoryPtr = category.get(); + BOOST_ASSERT(categoryPtr); + + // Register the category + m_Categories.insert(std::move(category)); + + return categoryPtr; + } + + const Device* RegisterDevice(const std::string& deviceName, + uint16_t cores = 0, + const armnn::Optional<std::string>& parentCategoryName = armnn::EmptyOptional()) + { + // Get the device UID + uint16_t deviceUid = GetNextUid(); + + // Create the device + DevicePtr device = std::make_unique<Device>(deviceUid, deviceName, cores); + BOOST_ASSERT(device); + + // Get the raw device pointer + const Device* devicePtr = device.get(); + BOOST_ASSERT(devicePtr); + + // Register the device + m_Devices.insert(std::make_pair(deviceUid, std::move(device))); + + // Connect the counter set to the parent category, if required + if (parentCategoryName.has_value()) + { + // Set the counter set UID in the parent category + Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName.value())); + BOOST_ASSERT(parentCategory); + parentCategory->m_DeviceUid = deviceUid; + } + + return devicePtr; + } + + const CounterSet* RegisterCounterSet( + const std::string& counterSetName, + uint16_t count = 0, + const armnn::Optional<std::string>& parentCategoryName = armnn::EmptyOptional()) + { + // Get the counter set UID + uint16_t counterSetUid = GetNextUid(); + + // Create the counter set + CounterSetPtr counterSet = std::make_unique<CounterSet>(counterSetUid, counterSetName, count); + BOOST_ASSERT(counterSet); + + // Get the raw counter set pointer + const CounterSet* counterSetPtr = counterSet.get(); + BOOST_ASSERT(counterSetPtr); + + // Register the counter set + m_CounterSets.insert(std::make_pair(counterSetUid, std::move(counterSet))); + + // Connect the counter set to the parent category, if required + if (parentCategoryName.has_value()) + { + // Set the counter set UID in the parent category + Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName.value())); + BOOST_ASSERT(parentCategory); + parentCategory->m_CounterSetUid = counterSetUid; + } + + return counterSetPtr; + } + + const Counter* RegisterCounter(const std::string& parentCategoryName, + uint16_t counterClass, + uint16_t interpolation, + double multiplier, + const std::string& name, + const std::string& description, + const armnn::Optional<std::string>& units = armnn::EmptyOptional(), + const armnn::Optional<uint16_t>& numberOfCores = armnn::EmptyOptional(), + const armnn::Optional<uint16_t>& deviceUid = armnn::EmptyOptional(), + const armnn::Optional<uint16_t>& counterSetUid = armnn::EmptyOptional()) + { + // Get the number of cores from the argument only + uint16_t deviceCores = numberOfCores.has_value() ? numberOfCores.value() : 0; + + // Get the device UID + uint16_t deviceUidValue = deviceUid.has_value() ? deviceUid.value() : 0; + + // Get the counter set UID + uint16_t counterSetUidValue = counterSetUid.has_value() ? counterSetUid.value() : 0; + + // Get the counter UIDs and calculate the max counter UID + std::vector<uint16_t> counterUids = GetNextCounterUids(deviceCores); + BOOST_ASSERT(!counterUids.empty()); + uint16_t maxCounterUid = deviceCores <= 1 ? counterUids.front() : counterUids.back(); + + // Get the counter units + const std::string unitsValue = units.has_value() ? units.value() : ""; + + // Create the counter + CounterPtr counter = std::make_shared<Counter>(counterUids.front(), + maxCounterUid, + counterClass, + interpolation, + multiplier, + name, + description, + unitsValue, + deviceUidValue, + counterSetUidValue); + BOOST_ASSERT(counter); + + // Get the raw counter pointer + const Counter* counterPtr = counter.get(); + BOOST_ASSERT(counterPtr); + + // Process multiple counters if necessary + for (uint16_t counterUid : counterUids) + { + // Connect the counter to the parent category + Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName)); + BOOST_ASSERT(parentCategory); + parentCategory->m_Counters.push_back(counterUid); + + // Register the counter + m_Counters.insert(std::make_pair(counterUid, counter)); + } + + return counterPtr; + } + + // Getters for counts + uint16_t GetCategoryCount() const override { return boost::numeric_cast<uint16_t>(m_Categories.size()); } + uint16_t GetDeviceCount() const override { return boost::numeric_cast<uint16_t>(m_Devices.size()); } + uint16_t GetCounterSetCount() const override { return boost::numeric_cast<uint16_t>(m_CounterSets.size()); } + uint16_t GetCounterCount() const override { return boost::numeric_cast<uint16_t>(m_Counters.size()); } + + // Getters for collections + const Categories& GetCategories() const override { return m_Categories; } + const Devices& GetDevices() const override { return m_Devices; } + const CounterSets& GetCounterSets() const override { return m_CounterSets; } + const Counters& GetCounters() const override { return m_Counters; } + + // Getters for profiling objects + const Category* GetCategory(const std::string& name) const override + { + auto it = std::find_if(m_Categories.begin(), m_Categories.end(), [&name](const CategoryPtr& category) + { + BOOST_ASSERT(category); + + return category->m_Name == name; + }); + + if (it == m_Categories.end()) + { + return nullptr; + } + + return it->get(); + } + + const Device* GetDevice(uint16_t uid) const override + { + return nullptr; // Not used by the unit tests + } + + const CounterSet* GetCounterSet(uint16_t uid) const override + { + return nullptr; // Not used by the unit tests + } + + const Counter* GetCounter(uint16_t uid) const override + { + return nullptr; // Not used by the unit tests + } + +private: + Categories m_Categories; + Devices m_Devices; + CounterSets m_CounterSets; + Counters m_Counters; +}; + +class SendCounterPacketTest : public SendCounterPacket +{ +public: + SendCounterPacketTest(IBufferWrapper& buffer) + : SendCounterPacket(buffer) + {} + + bool CreateDeviceRecordTest(const DevicePtr& device, + DeviceRecord& deviceRecord, + std::string& errorMessage) + { + return CreateDeviceRecord(device, deviceRecord, errorMessage); + } + + bool CreateCounterSetRecordTest(const CounterSetPtr& counterSet, + CounterSetRecord& counterSetRecord, + std::string& errorMessage) + { + return CreateCounterSetRecord(counterSet, counterSetRecord, errorMessage); + } + + bool CreateEventRecordTest(const CounterPtr& counter, + EventRecord& eventRecord, + std::string& errorMessage) + { + return CreateEventRecord(counter, eventRecord, errorMessage); + } + + bool CreateCategoryRecordTest(const CategoryPtr& category, + const Counters& counters, + CategoryRecord& categoryRecord, + std::string& errorMessage) + { + return CreateCategoryRecord(category, counters, categoryRecord, errorMessage); + } +}; |