aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJim Flynn <jim.flynn@arm.com>2020-02-14 10:18:08 +0000
committerJim Flynn <jim.flynn@arm.com>2020-02-16 13:37:52 +0000
commit640635528c5236458a05f970391894adf0a515d0 (patch)
tree0559eb6e0b2b34f6a314f2cc79acb8b3e189467a
parenta856501ac5c2d0cca70068993e5c7cc714872890 (diff)
downloadarmnn-640635528c5236458a05f970391894adf0a515d0.tar.gz
IVGCVSW-4320 Implement ReportCounters in BackendProfiling
Change-Id: Idd545079f6331bb4241709fa1534635f3fdf610b Signed-off-by: Jim Flynn <jim.flynn@arm.com>
-rw-r--r--CMakeLists.txt1
-rw-r--r--src/profiling/IProfilingService.hpp33
-rw-r--r--src/profiling/ProfilingService.cpp2
-rw-r--r--src/profiling/ProfilingService.hpp16
-rw-r--r--src/profiling/backends/BackendProfiling.cpp18
-rw-r--r--src/profiling/backends/BackendProfiling.hpp9
-rw-r--r--src/profiling/test/ProfilingMocks.hpp652
-rw-r--r--src/profiling/test/ProfilingTests.hpp2
-rw-r--r--src/profiling/test/SendCounterPacketTests.cpp1
-rw-r--r--src/profiling/test/SendCounterPacketTests.hpp616
-rw-r--r--src/profiling/test/SendTimelinePacketTests.cpp2
-rw-r--r--src/profiling/test/TimelineUtilityMethodsTests.cpp2
-rw-r--r--tests/profiling/gatordmock/tests/GatordMockTests.cpp2
13 files changed, 725 insertions, 631 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d534c0a6d5..a9b1e64e4c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -649,6 +649,7 @@ if(BUILD_UNIT_TESTS)
src/profiling/test/FileOnlyProfilingDecoratorTests.cpp
src/profiling/test/ProfilingConnectionDumpToFileDecoratorTests.cpp
src/profiling/test/ProfilingGuidTest.cpp
+ src/profiling/test/ProfilingMocks.hpp
src/profiling/test/ProfilingTests.cpp
src/profiling/test/ProfilingTests.hpp
src/profiling/test/ProfilingTestUtils.cpp
diff --git a/src/profiling/IProfilingService.hpp b/src/profiling/IProfilingService.hpp
new file mode 100644
index 0000000000..7f3ff70062
--- /dev/null
+++ b/src/profiling/IProfilingService.hpp
@@ -0,0 +1,33 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "CounterIdMap.hpp"
+#include "Holder.hpp"
+#include "ISendCounterPacket.hpp"
+#include "ProfilingGuidGenerator.hpp"
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+class IProfilingService : public IProfilingGuidGenerator
+{
+public:
+ virtual ~IProfilingService() {};
+ virtual std::unique_ptr<ISendTimelinePacket> GetSendTimelinePacket() const = 0;
+ virtual const ICounterMappings& GetCounterMappings() const = 0;
+ virtual ISendCounterPacket& GetSendCounterPacket() = 0;
+ virtual bool IsProfilingEnabled() const = 0;
+ virtual CaptureData GetCaptureData() = 0;
+};
+
+} // namespace profiling
+
+} // namespace armnn
+
diff --git a/src/profiling/ProfilingService.cpp b/src/profiling/ProfilingService.cpp
index b07465f077..0c33ca0e7d 100644
--- a/src/profiling/ProfilingService.cpp
+++ b/src/profiling/ProfilingService.cpp
@@ -30,7 +30,7 @@ void ProfilingService::ResetExternalProfilingOptions(const ExternalProfilingOpti
}
}
-bool ProfilingService::IsProfilingEnabled()
+bool ProfilingService::IsProfilingEnabled() const
{
return m_Options.m_EnableProfiling;
}
diff --git a/src/profiling/ProfilingService.hpp b/src/profiling/ProfilingService.hpp
index 2584c76020..4629c126bf 100644
--- a/src/profiling/ProfilingService.hpp
+++ b/src/profiling/ProfilingService.hpp
@@ -12,6 +12,7 @@
#include "CounterIdMap.hpp"
#include "ICounterRegistry.hpp"
#include "ICounterValues.hpp"
+#include "IProfilingService.hpp"
#include "PeriodicCounterCapture.hpp"
#include "PeriodicCounterSelectionCommandHandler.hpp"
#include "PerJobCounterSelectionCommandHandler.hpp"
@@ -38,7 +39,7 @@ static const uint16_t UNREGISTERED_BACKENDS = 3;
static const uint16_t INFERENCES_RUN = 4;
static const uint16_t MAX_ARMNN_COUNTER = INFERENCES_RUN;
-class ProfilingService : public IReadWriteCounterValues, public IProfilingGuidGenerator
+class ProfilingService : public IReadWriteCounterValues, public IProfilingService
{
public:
using ExternalProfilingOptions = IRuntime::CreationOptions::ExternalProfilingOptions;
@@ -77,13 +78,13 @@ public:
uint32_t GetCounterValue(uint16_t counterUid) const override;
uint16_t GetCounterCount() const override;
// counter global/backend mapping functions
- const ICounterMappings& GetCounterMappings() const;
+ const ICounterMappings& GetCounterMappings() const override;
IRegisterCounterMapping& GetCounterMappingRegistry();
// Getters for the profiling service state
- bool IsProfilingEnabled();
+ bool IsProfilingEnabled() const override;
- CaptureData GetCaptureData();
+ CaptureData GetCaptureData() override;
void SetCaptureData(uint32_t capturePeriod,
const std::vector<uint16_t>& counterIds,
const std::set<BackendId>& activeBackends);
@@ -100,7 +101,12 @@ public:
/// Create a ProfilingStaticGuid based on a hash of the string
ProfilingStaticGuid GenerateStaticId(const std::string& str) override;
- std::unique_ptr<ISendTimelinePacket> GetSendTimelinePacket() const;
+ std::unique_ptr<ISendTimelinePacket> GetSendTimelinePacket() const override;
+
+ ISendCounterPacket& GetSendCounterPacket() override
+ {
+ return m_SendCounterPacket;
+ }
/// Check if the profiling is enabled
bool IsEnabled() { return m_Options.m_EnableProfiling; }
diff --git a/src/profiling/backends/BackendProfiling.cpp b/src/profiling/backends/BackendProfiling.cpp
index 884fb3f2ff..0926879f67 100644
--- a/src/profiling/backends/BackendProfiling.cpp
+++ b/src/profiling/backends/BackendProfiling.cpp
@@ -29,6 +29,24 @@ IProfilingGuidGenerator& BackendProfiling::GetProfilingGuidGenerator()
return m_ProfilingService;
}
+void BackendProfiling::ReportCounters(const std::vector<Timestamp>& timestamps)
+{
+ for (const auto timestampInfo : timestamps)
+ {
+ std::vector<CounterValue> backendCounterValues = timestampInfo.counterValues;
+ for_each(backendCounterValues.begin(), backendCounterValues.end(), [&](CounterValue& backendCounterValue)
+ {
+ // translate the counterId to globalCounterId
+ backendCounterValue.counterId = m_ProfilingService.GetCounterMappings().GetGlobalId(
+ backendCounterValue.counterId, m_BackendId);
+ });
+
+ // Send Periodic Counter Capture Packet for the Timestamp
+ m_ProfilingService.GetSendCounterPacket().SendPeriodicCounterCapturePacket(
+ timestampInfo.timestamp, backendCounterValues);
+ }
+}
+
CounterStatus BackendProfiling::GetCounterStatus(uint16_t backendCounterId)
{
uint16_t globalCounterId = m_ProfilingService.GetCounterMappings().GetGlobalId(backendCounterId, m_BackendId);
diff --git a/src/profiling/backends/BackendProfiling.hpp b/src/profiling/backends/BackendProfiling.hpp
index e0e0f58e7d..c0f3eea978 100644
--- a/src/profiling/backends/BackendProfiling.hpp
+++ b/src/profiling/backends/BackendProfiling.hpp
@@ -5,7 +5,7 @@
#pragma once
-#include "ProfilingService.hpp"
+#include "IProfilingService.hpp"
#include <armnn/backends/profiling/IBackendProfiling.hpp>
namespace armnn
@@ -18,7 +18,7 @@ class BackendProfiling : public IBackendProfiling
{
public:
BackendProfiling(const IRuntime::CreationOptions& options,
- ProfilingService& profilingService,
+ IProfilingService& profilingService,
const BackendId& backendId)
: m_Options(options),
m_ProfilingService(profilingService),
@@ -34,8 +34,7 @@ public:
IProfilingGuidGenerator& GetProfilingGuidGenerator() override;
- void ReportCounters(const std::vector<Timestamp>&) override
- {}
+ void ReportCounters(const std::vector<Timestamp>&) override;
CounterStatus GetCounterStatus(uint16_t backendCounterId) override;
@@ -45,7 +44,7 @@ public:
private:
IRuntime::CreationOptions m_Options;
- ProfilingService& m_ProfilingService;
+ IProfilingService& m_ProfilingService;
BackendId m_BackendId;
};
} // namespace profiling
diff --git a/src/profiling/test/ProfilingMocks.hpp b/src/profiling/test/ProfilingMocks.hpp
new file mode 100644
index 0000000000..9d1321345a
--- /dev/null
+++ b/src/profiling/test/ProfilingMocks.hpp
@@ -0,0 +1,652 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <IProfilingConnectionFactory.hpp>
+#include <IProfilingService.hpp>
+#include <ProfilingGuidGenerator.hpp>
+#include <ProfilingUtils.hpp>
+#include <SendCounterPacket.hpp>
+#include <SendThread.hpp>
+
+#include <armnn/Exceptions.hpp>
+#include <armnn/Optional.hpp>
+#include <armnn/Conversion.hpp>
+
+#include <boost/assert.hpp>
+#include <boost/core/ignore_unused.hpp>
+#include <boost/numeric/conversion/cast.hpp>
+
+#include <atomic>
+#include <condition_variable>
+#include <mutex>
+#include <thread>
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+class MockProfilingConnection : public IProfilingConnection
+{
+public:
+ MockProfilingConnection()
+ : m_IsOpen(true)
+ , m_WrittenData()
+ , m_Packet()
+ {}
+
+ enum class PacketType
+ {
+ StreamMetaData,
+ ConnectionAcknowledge,
+ CounterDirectory,
+ ReqCounterDirectory,
+ PeriodicCounterSelection,
+ PerJobCounterSelection,
+ TimelineMessageDirectory,
+ PeriodicCounterCapture,
+ Unknown
+ };
+
+ bool IsOpen() const override
+ {
+ std::lock_guard<std::mutex> lock(m_Mutex);
+
+ return m_IsOpen;
+ }
+
+ void Close() override
+ {
+ std::lock_guard<std::mutex> lock(m_Mutex);
+
+ m_IsOpen = false;
+ }
+
+ bool WritePacket(const unsigned char* buffer, uint32_t length) override
+ {
+ if (buffer == nullptr || length == 0)
+ {
+ return false;
+ }
+
+ uint32_t header = ReadUint32(buffer, 0);
+
+ uint32_t packetFamily = (header >> 26);
+ uint32_t packetId = ((header >> 16) & 1023);
+
+ PacketType packetType;
+
+ switch (packetFamily)
+ {
+ case 0:
+ packetType = packetId < 6 ? PacketType(packetId) : PacketType::Unknown;
+ break;
+ case 1:
+ packetType = packetId == 0 ? PacketType::TimelineMessageDirectory : PacketType::Unknown;
+ break;
+ case 3:
+ packetType = packetId == 0 ? PacketType::PeriodicCounterCapture : PacketType::Unknown;
+ break;
+ default:
+ packetType = PacketType::Unknown;
+ }
+
+ std::lock_guard<std::mutex> lock(m_Mutex);
+
+ m_WrittenData.push_back({ packetType, length });
+ return true;
+ }
+
+ long CheckForPacket(const std::pair<PacketType, uint32_t> packetInfo)
+ {
+ std::lock_guard<std::mutex> lock(m_Mutex);
+
+ if(packetInfo.second != 0)
+ {
+ return std::count(m_WrittenData.begin(), m_WrittenData.end(), packetInfo);
+ }
+ else
+ {
+ return std::count_if(m_WrittenData.begin(), m_WrittenData.end(),
+ [&packetInfo](const std::pair<PacketType, uint32_t> pair) { return packetInfo.first == pair.first; });
+ }
+ }
+
+ bool WritePacket(Packet&& packet)
+ {
+ std::lock_guard<std::mutex> lock(m_Mutex);
+
+ m_Packet = std::move(packet);
+ return true;
+ }
+
+ Packet ReadPacket(uint32_t timeout) override
+ {
+ boost::ignore_unused(timeout);
+
+ // Simulate a delay in the reading process. The default timeout is way too long.
+ std::this_thread::sleep_for(std::chrono::milliseconds(5));
+ std::lock_guard<std::mutex> lock(m_Mutex);
+ return std::move(m_Packet);
+ }
+
+ unsigned long GetWrittenDataSize()
+ {
+ std::lock_guard<std::mutex> lock(m_Mutex);
+
+ return m_WrittenData.size();
+ }
+
+ void Clear()
+ {
+ std::lock_guard<std::mutex> lock(m_Mutex);
+
+ m_WrittenData.clear();
+ }
+
+private:
+ bool m_IsOpen;
+ std::vector<std::pair<PacketType, uint32_t>> m_WrittenData;
+ Packet m_Packet;
+ mutable std::mutex m_Mutex;
+};
+
+class MockProfilingConnectionFactory : public IProfilingConnectionFactory
+{
+public:
+ IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override
+ {
+ boost::ignore_unused(options);
+ return std::make_unique<MockProfilingConnection>();
+ }
+};
+
+class MockPacketBuffer : public IPacketBuffer
+{
+public:
+ MockPacketBuffer(unsigned int maxSize)
+ : m_MaxSize(maxSize)
+ , m_Size(0)
+ , m_Data(std::make_unique<unsigned char[]>(m_MaxSize))
+ {}
+
+ ~MockPacketBuffer() {}
+
+ const unsigned char* GetReadableData() const override { return m_Data.get(); }
+
+ unsigned int GetSize() const override { return m_Size; }
+
+ void MarkRead() override { m_Size = 0; }
+
+ void Commit(unsigned int size) override { m_Size = size; }
+
+ void Release() override { m_Size = 0; }
+
+ unsigned char* GetWritableData() override { return m_Data.get(); }
+
+private:
+ unsigned int m_MaxSize;
+ unsigned int m_Size;
+ std::unique_ptr<unsigned char[]> m_Data;
+};
+
+class MockBufferManager : public IBufferManager
+{
+public:
+ MockBufferManager(unsigned int size)
+ : m_BufferSize(size),
+ m_Buffer(std::make_unique<MockPacketBuffer>(size)) {}
+
+ ~MockBufferManager() {}
+
+ IPacketBufferPtr Reserve(unsigned int requestedSize, unsigned int& reservedSize) override
+ {
+ if (requestedSize > m_BufferSize)
+ {
+ reservedSize = m_BufferSize;
+ }
+ else
+ {
+ reservedSize = requestedSize;
+ }
+
+ return std::move(m_Buffer);
+ }
+
+ void Commit(IPacketBufferPtr& packetBuffer, unsigned int size, bool notifyConsumer = true) override
+ {
+ packetBuffer->Commit(size);
+ m_Buffer = std::move(packetBuffer);
+
+ if (notifyConsumer)
+ {
+ FlushReadList();
+ }
+ }
+
+ IPacketBufferPtr GetReadableBuffer() override
+ {
+ return std::move(m_Buffer);
+ }
+
+ void Release(IPacketBufferPtr& packetBuffer) override
+ {
+ packetBuffer->Release();
+ m_Buffer = std::move(packetBuffer);
+ }
+
+ void MarkRead(IPacketBufferPtr& packetBuffer) override
+ {
+ packetBuffer->MarkRead();
+ m_Buffer = std::move(packetBuffer);
+ }
+
+ void SetConsumer(IConsumer* consumer) override
+ {
+ if (consumer != nullptr)
+ {
+ m_Consumer = consumer;
+ }
+ }
+
+ void FlushReadList() override
+ {
+ // notify consumer that packet is ready to read
+ if (m_Consumer != nullptr)
+ {
+ m_Consumer->SetReadyToRead();
+ }
+ }
+
+private:
+ unsigned int m_BufferSize;
+ IPacketBufferPtr m_Buffer;
+ IConsumer* m_Consumer = nullptr;
+};
+
+class MockStreamCounterBuffer : public IBufferManager
+{
+public:
+ MockStreamCounterBuffer(unsigned int maxBufferSize = 4096)
+ : m_MaxBufferSize(maxBufferSize)
+ , m_BufferList()
+ , m_CommittedSize(0)
+ , m_ReadableSize(0)
+ , m_ReadSize(0)
+ {}
+ ~MockStreamCounterBuffer() {}
+
+ IPacketBufferPtr Reserve(unsigned int requestedSize, unsigned int& reservedSize) override
+ {
+ std::lock_guard<std::mutex> lock(m_Mutex);
+
+ reservedSize = 0;
+ if (requestedSize > m_MaxBufferSize)
+ {
+ throw armnn::InvalidArgumentException("The maximum buffer size that can be requested is [" +
+ std::to_string(m_MaxBufferSize) + "] bytes");
+ }
+ reservedSize = requestedSize;
+ return std::make_unique<MockPacketBuffer>(requestedSize);
+ }
+
+ void Commit(IPacketBufferPtr& packetBuffer, unsigned int size, bool notifyConsumer = true) override
+ {
+ std::lock_guard<std::mutex> lock(m_Mutex);
+
+ packetBuffer->Commit(size);
+ m_BufferList.push_back(std::move(packetBuffer));
+ m_CommittedSize += size;
+
+ if (notifyConsumer)
+ {
+ FlushReadList();
+ }
+ }
+
+ void Release(IPacketBufferPtr& packetBuffer) override
+ {
+ std::lock_guard<std::mutex> lock(m_Mutex);
+
+ packetBuffer->Release();
+ }
+
+ IPacketBufferPtr GetReadableBuffer() override
+ {
+ std::lock_guard<std::mutex> lock(m_Mutex);
+
+ if (m_BufferList.empty())
+ {
+ return nullptr;
+ }
+ IPacketBufferPtr buffer = std::move(m_BufferList.back());
+ m_BufferList.pop_back();
+ m_ReadableSize += buffer->GetSize();
+ return buffer;
+ }
+
+ void MarkRead(IPacketBufferPtr& packetBuffer) override
+ {
+ std::lock_guard<std::mutex> lock(m_Mutex);
+
+ m_ReadSize += packetBuffer->GetSize();
+ packetBuffer->MarkRead();
+ }
+
+ void SetConsumer(IConsumer* consumer) override
+ {
+ if (consumer != nullptr)
+ {
+ m_Consumer = consumer;
+ }
+ }
+
+ void FlushReadList() override
+ {
+ // notify consumer that packet is ready to read
+ if (m_Consumer != nullptr)
+ {
+ m_Consumer->SetReadyToRead();
+ }
+ }
+
+ unsigned int GetCommittedSize() const { return m_CommittedSize; }
+ unsigned int GetReadableSize() const { return m_ReadableSize; }
+ unsigned int GetReadSize() const { return m_ReadSize; }
+
+private:
+ // The maximum buffer size when creating a new buffer
+ unsigned int m_MaxBufferSize;
+
+ // A list of buffers
+ std::vector<IPacketBufferPtr> m_BufferList;
+
+ // The mutex to synchronize this mock's methods
+ std::mutex m_Mutex;
+
+ // The total size of the buffers that has been committed for reading
+ unsigned int m_CommittedSize;
+
+ // The total size of the buffers that can be read
+ unsigned int m_ReadableSize;
+
+ // The total size of the buffers that has already been read
+ unsigned int m_ReadSize;
+
+ // Consumer thread to notify packet is ready to read
+ IConsumer* m_Consumer = nullptr;
+};
+
+class MockSendCounterPacket : public ISendCounterPacket
+{
+public:
+ MockSendCounterPacket(IBufferManager& sendBuffer) : m_BufferManager(sendBuffer) {}
+
+ void SendStreamMetaDataPacket() override
+ {
+ std::string message("SendStreamMetaDataPacket");
+ unsigned int reserved = 0;
+ IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
+ memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
+ m_BufferManager.Commit(buffer, reserved, false);
+ }
+
+ void SendCounterDirectoryPacket(const ICounterDirectory& counterDirectory) override
+ {
+ boost::ignore_unused(counterDirectory);
+
+ std::string message("SendCounterDirectoryPacket");
+ unsigned int reserved = 0;
+ IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
+ memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
+ m_BufferManager.Commit(buffer, reserved);
+ }
+
+ void SendPeriodicCounterCapturePacket(uint64_t timestamp,
+ const std::vector<CounterValue>& values) override
+ {
+ boost::ignore_unused(timestamp, values);
+
+ std::string message("SendPeriodicCounterCapturePacket");
+ unsigned int reserved = 0;
+ IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
+ memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
+ m_BufferManager.Commit(buffer, reserved);
+ }
+
+ void SendPeriodicCounterSelectionPacket(uint32_t capturePeriod,
+ const std::vector<uint16_t>& selectedCounterIds) override
+ {
+ boost::ignore_unused(capturePeriod, selectedCounterIds);
+
+ std::string message("SendPeriodicCounterSelectionPacket");
+ unsigned int reserved = 0;
+ IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
+ memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
+ m_BufferManager.Commit(buffer, reserved);
+ }
+
+private:
+ IBufferManager& m_BufferManager;
+};
+
+class MockCounterDirectory : public ICounterDirectory
+{
+public:
+ MockCounterDirectory() = default;
+ ~MockCounterDirectory() = default;
+
+ // Register profiling objects
+ const Category* RegisterCategory(const std::string& categoryName,
+ const armnn::Optional<uint16_t>& deviceUid = armnn::EmptyOptional(),
+ const armnn::Optional<uint16_t>& counterSetUid = armnn::EmptyOptional())
+ {
+ // Get the device UID
+ uint16_t deviceUidValue = deviceUid.has_value() ? deviceUid.value() : 0;
+
+ // Get the counter set UID
+ uint16_t counterSetUidValue = counterSetUid.has_value() ? counterSetUid.value() : 0;
+
+ // Create the category
+ CategoryPtr category = std::make_unique<Category>(categoryName, deviceUidValue, counterSetUidValue);
+ BOOST_ASSERT(category);
+
+ // Get the raw category pointer
+ const Category* categoryPtr = category.get();
+ BOOST_ASSERT(categoryPtr);
+
+ // Register the category
+ m_Categories.insert(std::move(category));
+
+ return categoryPtr;
+ }
+
+ const Device* RegisterDevice(const std::string& deviceName,
+ uint16_t cores = 0,
+ const armnn::Optional<std::string>& parentCategoryName = armnn::EmptyOptional())
+ {
+ // Get the device UID
+ uint16_t deviceUid = GetNextUid();
+
+ // Create the device
+ DevicePtr device = std::make_unique<Device>(deviceUid, deviceName, cores);
+ BOOST_ASSERT(device);
+
+ // Get the raw device pointer
+ const Device* devicePtr = device.get();
+ BOOST_ASSERT(devicePtr);
+
+ // Register the device
+ m_Devices.insert(std::make_pair(deviceUid, std::move(device)));
+
+ // Connect the counter set to the parent category, if required
+ if (parentCategoryName.has_value())
+ {
+ // Set the counter set UID in the parent category
+ Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName.value()));
+ BOOST_ASSERT(parentCategory);
+ parentCategory->m_DeviceUid = deviceUid;
+ }
+
+ return devicePtr;
+ }
+
+ const CounterSet* RegisterCounterSet(
+ const std::string& counterSetName,
+ uint16_t count = 0,
+ const armnn::Optional<std::string>& parentCategoryName = armnn::EmptyOptional())
+ {
+ // Get the counter set UID
+ uint16_t counterSetUid = GetNextUid();
+
+ // Create the counter set
+ CounterSetPtr counterSet = std::make_unique<CounterSet>(counterSetUid, counterSetName, count);
+ BOOST_ASSERT(counterSet);
+
+ // Get the raw counter set pointer
+ const CounterSet* counterSetPtr = counterSet.get();
+ BOOST_ASSERT(counterSetPtr);
+
+ // Register the counter set
+ m_CounterSets.insert(std::make_pair(counterSetUid, std::move(counterSet)));
+
+ // Connect the counter set to the parent category, if required
+ if (parentCategoryName.has_value())
+ {
+ // Set the counter set UID in the parent category
+ Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName.value()));
+ BOOST_ASSERT(parentCategory);
+ parentCategory->m_CounterSetUid = counterSetUid;
+ }
+
+ return counterSetPtr;
+ }
+
+ const Counter* RegisterCounter(const BackendId& backendId,
+ const uint16_t uid,
+ const std::string& parentCategoryName,
+ uint16_t counterClass,
+ uint16_t interpolation,
+ double multiplier,
+ const std::string& name,
+ const std::string& description,
+ const armnn::Optional<std::string>& units = armnn::EmptyOptional(),
+ const armnn::Optional<uint16_t>& numberOfCores = armnn::EmptyOptional(),
+ const armnn::Optional<uint16_t>& deviceUid = armnn::EmptyOptional(),
+ const armnn::Optional<uint16_t>& counterSetUid = armnn::EmptyOptional())
+ {
+ boost::ignore_unused(backendId);
+
+ // Get the number of cores from the argument only
+ uint16_t deviceCores = numberOfCores.has_value() ? numberOfCores.value() : 0;
+
+ // Get the device UID
+ uint16_t deviceUidValue = deviceUid.has_value() ? deviceUid.value() : 0;
+
+ // Get the counter set UID
+ uint16_t counterSetUidValue = counterSetUid.has_value() ? counterSetUid.value() : 0;
+
+ // Get the counter UIDs and calculate the max counter UID
+ std::vector<uint16_t> counterUids = GetNextCounterUids(uid, deviceCores);
+ BOOST_ASSERT(!counterUids.empty());
+ uint16_t maxCounterUid = deviceCores <= 1 ? counterUids.front() : counterUids.back();
+
+ // Get the counter units
+ const std::string unitsValue = units.has_value() ? units.value() : "";
+
+ // Create the counter
+ CounterPtr counter = std::make_shared<Counter>(armnn::profiling::BACKEND_ID,
+ counterUids.front(),
+ maxCounterUid,
+ counterClass,
+ interpolation,
+ multiplier,
+ name,
+ description,
+ unitsValue,
+ deviceUidValue,
+ counterSetUidValue);
+ BOOST_ASSERT(counter);
+
+ // Get the raw counter pointer
+ const Counter* counterPtr = counter.get();
+ BOOST_ASSERT(counterPtr);
+
+ // Process multiple counters if necessary
+ for (uint16_t counterUid : counterUids)
+ {
+ // Connect the counter to the parent category
+ Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName));
+ BOOST_ASSERT(parentCategory);
+ parentCategory->m_Counters.push_back(counterUid);
+
+ // Register the counter
+ m_Counters.insert(std::make_pair(counterUid, counter));
+ }
+
+ return counterPtr;
+ }
+
+ // Getters for counts
+ uint16_t GetCategoryCount() const override { return boost::numeric_cast<uint16_t>(m_Categories.size()); }
+ uint16_t GetDeviceCount() const override { return boost::numeric_cast<uint16_t>(m_Devices.size()); }
+ uint16_t GetCounterSetCount() const override { return boost::numeric_cast<uint16_t>(m_CounterSets.size()); }
+ uint16_t GetCounterCount() const override { return boost::numeric_cast<uint16_t>(m_Counters.size()); }
+
+ // Getters for collections
+ const Categories& GetCategories() const override { return m_Categories; }
+ const Devices& GetDevices() const override { return m_Devices; }
+ const CounterSets& GetCounterSets() const override { return m_CounterSets; }
+ const Counters& GetCounters() const override { return m_Counters; }
+
+ // Getters for profiling objects
+ const Category* GetCategory(const std::string& name) const override
+ {
+ auto it = std::find_if(m_Categories.begin(), m_Categories.end(), [&name](const CategoryPtr& category)
+ {
+ BOOST_ASSERT(category);
+
+ return category->m_Name == name;
+ });
+
+ if (it == m_Categories.end())
+ {
+ return nullptr;
+ }
+
+ return it->get();
+ }
+
+ const Device* GetDevice(uint16_t uid) const override
+ {
+ boost::ignore_unused(uid);
+ return nullptr; // Not used by the unit tests
+ }
+
+ const CounterSet* GetCounterSet(uint16_t uid) const override
+ {
+ boost::ignore_unused(uid);
+ return nullptr; // Not used by the unit tests
+ }
+
+ const Counter* GetCounter(uint16_t uid) const override
+ {
+ boost::ignore_unused(uid);
+ return nullptr; // Not used by the unit tests
+ }
+
+private:
+ Categories m_Categories;
+ Devices m_Devices;
+ CounterSets m_CounterSets;
+ Counters m_Counters;
+};
+
+} // namespace profiling
+
+} // namespace armnn
diff --git a/src/profiling/test/ProfilingTests.hpp b/src/profiling/test/ProfilingTests.hpp
index d6300e6429..8b4bc84bd1 100644
--- a/src/profiling/test/ProfilingTests.hpp
+++ b/src/profiling/test/ProfilingTests.hpp
@@ -5,7 +5,7 @@
#pragma once
-#include "SendCounterPacketTests.hpp"
+#include "ProfilingMocks.hpp"
#include <armnn/Logging.hpp>
diff --git a/src/profiling/test/SendCounterPacketTests.cpp b/src/profiling/test/SendCounterPacketTests.cpp
index b87583ce5e..48a0bb64de 100644
--- a/src/profiling/test/SendCounterPacketTests.cpp
+++ b/src/profiling/test/SendCounterPacketTests.cpp
@@ -3,6 +3,7 @@
// SPDX-License-Identifier: MIT
//
+#include "ProfilingMocks.hpp"
#include "SendCounterPacketTests.hpp"
#include <BufferManager.hpp>
diff --git a/src/profiling/test/SendCounterPacketTests.hpp b/src/profiling/test/SendCounterPacketTests.hpp
index 4118989b1b..8b46ed17d6 100644
--- a/src/profiling/test/SendCounterPacketTests.hpp
+++ b/src/profiling/test/SendCounterPacketTests.hpp
@@ -29,622 +29,6 @@ namespace armnn
namespace profiling
{
-class MockProfilingConnection : public IProfilingConnection
-{
-public:
- MockProfilingConnection()
- : m_IsOpen(true)
- , m_WrittenData()
- , m_Packet()
- {}
-
- enum class PacketType
- {
- StreamMetaData,
- ConnectionAcknowledge,
- CounterDirectory,
- ReqCounterDirectory,
- PeriodicCounterSelection,
- PerJobCounterSelection,
- TimelineMessageDirectory,
- PeriodicCounterCapture,
- Unknown
- };
-
- bool IsOpen() const override
- {
- std::lock_guard<std::mutex> lock(m_Mutex);
-
- return m_IsOpen;
- }
-
- void Close() override
- {
- std::lock_guard<std::mutex> lock(m_Mutex);
-
- m_IsOpen = false;
- }
-
- bool WritePacket(const unsigned char* buffer, uint32_t length) override
- {
- if (buffer == nullptr || length == 0)
- {
- return false;
- }
-
- uint32_t header = ReadUint32(buffer, 0);
-
- uint32_t packetFamily = (header >> 26);
- uint32_t packetId = ((header >> 16) & 1023);
-
- PacketType packetType;
-
- switch (packetFamily)
- {
- case 0:
- packetType = packetId < 6 ? PacketType(packetId) : PacketType::Unknown;
- break;
- case 1:
- packetType = packetId == 0 ? PacketType::TimelineMessageDirectory : PacketType::Unknown;
- break;
- case 3:
- packetType = packetId == 0 ? PacketType::PeriodicCounterCapture : PacketType::Unknown;
- break;
- default:
- packetType = PacketType::Unknown;
- }
-
- std::lock_guard<std::mutex> lock(m_Mutex);
-
- m_WrittenData.push_back({ packetType, length });
- return true;
- }
-
- long CheckForPacket(const std::pair<PacketType, uint32_t> packetInfo)
- {
- std::lock_guard<std::mutex> lock(m_Mutex);
-
- if(packetInfo.second != 0)
- {
- return std::count(m_WrittenData.begin(), m_WrittenData.end(), packetInfo);
- }
- else
- {
- return std::count_if(m_WrittenData.begin(), m_WrittenData.end(),
- [&packetInfo](const std::pair<PacketType, uint32_t> pair) { return packetInfo.first == pair.first; });
- }
- }
-
- bool WritePacket(Packet&& packet)
- {
- std::lock_guard<std::mutex> lock(m_Mutex);
-
- m_Packet = std::move(packet);
- return true;
- }
-
- Packet ReadPacket(uint32_t timeout) override
- {
- boost::ignore_unused(timeout);
-
- // Simulate a delay in the reading process. The default timeout is way too long.
- std::this_thread::sleep_for(std::chrono::milliseconds(5));
- std::lock_guard<std::mutex> lock(m_Mutex);
- return std::move(m_Packet);
- }
-
- unsigned long GetWrittenDataSize()
- {
- std::lock_guard<std::mutex> lock(m_Mutex);
-
- return m_WrittenData.size();
- }
-
- void Clear()
- {
- std::lock_guard<std::mutex> lock(m_Mutex);
-
- m_WrittenData.clear();
- }
-
-private:
- bool m_IsOpen;
- std::vector<std::pair<PacketType, uint32_t>> m_WrittenData;
- Packet m_Packet;
- mutable std::mutex m_Mutex;
-};
-
-class MockProfilingConnectionFactory : public IProfilingConnectionFactory
-{
-public:
- IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override
- {
- boost::ignore_unused(options);
- return std::make_unique<MockProfilingConnection>();
- }
-};
-
-class MockPacketBuffer : public IPacketBuffer
-{
-public:
- MockPacketBuffer(unsigned int maxSize)
- : m_MaxSize(maxSize)
- , m_Size(0)
- , m_Data(std::make_unique<unsigned char[]>(m_MaxSize))
- {}
-
- ~MockPacketBuffer() {}
-
- const unsigned char* GetReadableData() const override { return m_Data.get(); }
-
- unsigned int GetSize() const override { return m_Size; }
-
- void MarkRead() override { m_Size = 0; }
-
- void Commit(unsigned int size) override { m_Size = size; }
-
- void Release() override { m_Size = 0; }
-
- unsigned char* GetWritableData() override { return m_Data.get(); }
-
-private:
- unsigned int m_MaxSize;
- unsigned int m_Size;
- std::unique_ptr<unsigned char[]> m_Data;
-};
-
-class MockBufferManager : public IBufferManager
-{
-public:
- MockBufferManager(unsigned int size)
- : m_BufferSize(size),
- m_Buffer(std::make_unique<MockPacketBuffer>(size)) {}
-
- ~MockBufferManager() {}
-
- IPacketBufferPtr Reserve(unsigned int requestedSize, unsigned int& reservedSize) override
- {
- if (requestedSize > m_BufferSize)
- {
- reservedSize = m_BufferSize;
- }
- else
- {
- reservedSize = requestedSize;
- }
-
- return std::move(m_Buffer);
- }
-
- void Commit(IPacketBufferPtr& packetBuffer, unsigned int size, bool notifyConsumer = true) override
- {
- packetBuffer->Commit(size);
- m_Buffer = std::move(packetBuffer);
-
- if (notifyConsumer)
- {
- FlushReadList();
- }
- }
-
- IPacketBufferPtr GetReadableBuffer() override
- {
- return std::move(m_Buffer);
- }
-
- void Release(IPacketBufferPtr& packetBuffer) override
- {
- packetBuffer->Release();
- m_Buffer = std::move(packetBuffer);
- }
-
- void MarkRead(IPacketBufferPtr& packetBuffer) override
- {
- packetBuffer->MarkRead();
- m_Buffer = std::move(packetBuffer);
- }
-
- void SetConsumer(IConsumer* consumer) override
- {
- if (consumer != nullptr)
- {
- m_Consumer = consumer;
- }
- }
-
- void FlushReadList() override
- {
- // notify consumer that packet is ready to read
- if (m_Consumer != nullptr)
- {
- m_Consumer->SetReadyToRead();
- }
- }
-
-private:
- unsigned int m_BufferSize;
- IPacketBufferPtr m_Buffer;
- IConsumer* m_Consumer = nullptr;
-};
-
-class MockStreamCounterBuffer : public IBufferManager
-{
-public:
- MockStreamCounterBuffer(unsigned int maxBufferSize = 4096)
- : m_MaxBufferSize(maxBufferSize)
- , m_BufferList()
- , m_CommittedSize(0)
- , m_ReadableSize(0)
- , m_ReadSize(0)
- {}
- ~MockStreamCounterBuffer() {}
-
- IPacketBufferPtr Reserve(unsigned int requestedSize, unsigned int& reservedSize) override
- {
- std::lock_guard<std::mutex> lock(m_Mutex);
-
- reservedSize = 0;
- if (requestedSize > m_MaxBufferSize)
- {
- throw armnn::InvalidArgumentException("The maximum buffer size that can be requested is [" +
- std::to_string(m_MaxBufferSize) + "] bytes");
- }
- reservedSize = requestedSize;
- return std::make_unique<MockPacketBuffer>(requestedSize);
- }
-
- void Commit(IPacketBufferPtr& packetBuffer, unsigned int size, bool notifyConsumer = true) override
- {
- std::lock_guard<std::mutex> lock(m_Mutex);
-
- packetBuffer->Commit(size);
- m_BufferList.push_back(std::move(packetBuffer));
- m_CommittedSize += size;
-
- if (notifyConsumer)
- {
- FlushReadList();
- }
- }
-
- void Release(IPacketBufferPtr& packetBuffer) override
- {
- std::lock_guard<std::mutex> lock(m_Mutex);
-
- packetBuffer->Release();
- }
-
- IPacketBufferPtr GetReadableBuffer() override
- {
- std::lock_guard<std::mutex> lock(m_Mutex);
-
- if (m_BufferList.empty())
- {
- return nullptr;
- }
- IPacketBufferPtr buffer = std::move(m_BufferList.back());
- m_BufferList.pop_back();
- m_ReadableSize += buffer->GetSize();
- return buffer;
- }
-
- void MarkRead(IPacketBufferPtr& packetBuffer) override
- {
- std::lock_guard<std::mutex> lock(m_Mutex);
-
- m_ReadSize += packetBuffer->GetSize();
- packetBuffer->MarkRead();
- }
-
- void SetConsumer(IConsumer* consumer) override
- {
- if (consumer != nullptr)
- {
- m_Consumer = consumer;
- }
- }
-
- void FlushReadList() override
- {
- // notify consumer that packet is ready to read
- if (m_Consumer != nullptr)
- {
- m_Consumer->SetReadyToRead();
- }
- }
-
- unsigned int GetCommittedSize() const { return m_CommittedSize; }
- unsigned int GetReadableSize() const { return m_ReadableSize; }
- unsigned int GetReadSize() const { return m_ReadSize; }
-
-private:
- // The maximum buffer size when creating a new buffer
- unsigned int m_MaxBufferSize;
-
- // A list of buffers
- std::vector<IPacketBufferPtr> m_BufferList;
-
- // The mutex to synchronize this mock's methods
- std::mutex m_Mutex;
-
- // The total size of the buffers that has been committed for reading
- unsigned int m_CommittedSize;
-
- // The total size of the buffers that can be read
- unsigned int m_ReadableSize;
-
- // The total size of the buffers that has already been read
- unsigned int m_ReadSize;
-
- // Consumer thread to notify packet is ready to read
- IConsumer* m_Consumer = nullptr;
-};
-
-class MockSendCounterPacket : public ISendCounterPacket
-{
-public:
- MockSendCounterPacket(IBufferManager& sendBuffer) : m_BufferManager(sendBuffer) {}
-
- void SendStreamMetaDataPacket() override
- {
- std::string message("SendStreamMetaDataPacket");
- unsigned int reserved = 0;
- IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
- memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
- m_BufferManager.Commit(buffer, reserved, false);
- }
-
- void SendCounterDirectoryPacket(const ICounterDirectory& counterDirectory) override
- {
- boost::ignore_unused(counterDirectory);
-
- std::string message("SendCounterDirectoryPacket");
- unsigned int reserved = 0;
- IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
- memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
- m_BufferManager.Commit(buffer, reserved);
- }
-
- void SendPeriodicCounterCapturePacket(uint64_t timestamp,
- const std::vector<CounterValue>& values) override
- {
- boost::ignore_unused(timestamp, values);
-
- std::string message("SendPeriodicCounterCapturePacket");
- unsigned int reserved = 0;
- IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
- memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
- m_BufferManager.Commit(buffer, reserved);
- }
-
- void SendPeriodicCounterSelectionPacket(uint32_t capturePeriod,
- const std::vector<uint16_t>& selectedCounterIds) override
- {
- boost::ignore_unused(capturePeriod, selectedCounterIds);
-
- std::string message("SendPeriodicCounterSelectionPacket");
- unsigned int reserved = 0;
- IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
- memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
- m_BufferManager.Commit(buffer, reserved);
- }
-
-private:
- IBufferManager& m_BufferManager;
-};
-
-class MockCounterDirectory : public ICounterDirectory
-{
-public:
- MockCounterDirectory() = default;
- ~MockCounterDirectory() = default;
-
- // Register profiling objects
- const Category* RegisterCategory(const std::string& categoryName,
- const armnn::Optional<uint16_t>& deviceUid = armnn::EmptyOptional(),
- const armnn::Optional<uint16_t>& counterSetUid = armnn::EmptyOptional())
- {
- // Get the device UID
- uint16_t deviceUidValue = deviceUid.has_value() ? deviceUid.value() : 0;
-
- // Get the counter set UID
- uint16_t counterSetUidValue = counterSetUid.has_value() ? counterSetUid.value() : 0;
-
- // Create the category
- CategoryPtr category = std::make_unique<Category>(categoryName, deviceUidValue, counterSetUidValue);
- BOOST_ASSERT(category);
-
- // Get the raw category pointer
- const Category* categoryPtr = category.get();
- BOOST_ASSERT(categoryPtr);
-
- // Register the category
- m_Categories.insert(std::move(category));
-
- return categoryPtr;
- }
-
- const Device* RegisterDevice(const std::string& deviceName,
- uint16_t cores = 0,
- const armnn::Optional<std::string>& parentCategoryName = armnn::EmptyOptional())
- {
- // Get the device UID
- uint16_t deviceUid = GetNextUid();
-
- // Create the device
- DevicePtr device = std::make_unique<Device>(deviceUid, deviceName, cores);
- BOOST_ASSERT(device);
-
- // Get the raw device pointer
- const Device* devicePtr = device.get();
- BOOST_ASSERT(devicePtr);
-
- // Register the device
- m_Devices.insert(std::make_pair(deviceUid, std::move(device)));
-
- // Connect the counter set to the parent category, if required
- if (parentCategoryName.has_value())
- {
- // Set the counter set UID in the parent category
- Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName.value()));
- BOOST_ASSERT(parentCategory);
- parentCategory->m_DeviceUid = deviceUid;
- }
-
- return devicePtr;
- }
-
- const CounterSet* RegisterCounterSet(
- const std::string& counterSetName,
- uint16_t count = 0,
- const armnn::Optional<std::string>& parentCategoryName = armnn::EmptyOptional())
- {
- // Get the counter set UID
- uint16_t counterSetUid = GetNextUid();
-
- // Create the counter set
- CounterSetPtr counterSet = std::make_unique<CounterSet>(counterSetUid, counterSetName, count);
- BOOST_ASSERT(counterSet);
-
- // Get the raw counter set pointer
- const CounterSet* counterSetPtr = counterSet.get();
- BOOST_ASSERT(counterSetPtr);
-
- // Register the counter set
- m_CounterSets.insert(std::make_pair(counterSetUid, std::move(counterSet)));
-
- // Connect the counter set to the parent category, if required
- if (parentCategoryName.has_value())
- {
- // Set the counter set UID in the parent category
- Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName.value()));
- BOOST_ASSERT(parentCategory);
- parentCategory->m_CounterSetUid = counterSetUid;
- }
-
- return counterSetPtr;
- }
-
- const Counter* RegisterCounter(const BackendId& backendId,
- const uint16_t uid,
- const std::string& parentCategoryName,
- uint16_t counterClass,
- uint16_t interpolation,
- double multiplier,
- const std::string& name,
- const std::string& description,
- const armnn::Optional<std::string>& units = armnn::EmptyOptional(),
- const armnn::Optional<uint16_t>& numberOfCores = armnn::EmptyOptional(),
- const armnn::Optional<uint16_t>& deviceUid = armnn::EmptyOptional(),
- const armnn::Optional<uint16_t>& counterSetUid = armnn::EmptyOptional())
- {
- boost::ignore_unused(backendId);
-
- // Get the number of cores from the argument only
- uint16_t deviceCores = numberOfCores.has_value() ? numberOfCores.value() : 0;
-
- // Get the device UID
- uint16_t deviceUidValue = deviceUid.has_value() ? deviceUid.value() : 0;
-
- // Get the counter set UID
- uint16_t counterSetUidValue = counterSetUid.has_value() ? counterSetUid.value() : 0;
-
- // Get the counter UIDs and calculate the max counter UID
- std::vector<uint16_t> counterUids = GetNextCounterUids(uid, deviceCores);
- BOOST_ASSERT(!counterUids.empty());
- uint16_t maxCounterUid = deviceCores <= 1 ? counterUids.front() : counterUids.back();
-
- // Get the counter units
- const std::string unitsValue = units.has_value() ? units.value() : "";
-
- // Create the counter
- CounterPtr counter = std::make_shared<Counter>(armnn::profiling::BACKEND_ID,
- counterUids.front(),
- maxCounterUid,
- counterClass,
- interpolation,
- multiplier,
- name,
- description,
- unitsValue,
- deviceUidValue,
- counterSetUidValue);
- BOOST_ASSERT(counter);
-
- // Get the raw counter pointer
- const Counter* counterPtr = counter.get();
- BOOST_ASSERT(counterPtr);
-
- // Process multiple counters if necessary
- for (uint16_t counterUid : counterUids)
- {
- // Connect the counter to the parent category
- Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName));
- BOOST_ASSERT(parentCategory);
- parentCategory->m_Counters.push_back(counterUid);
-
- // Register the counter
- m_Counters.insert(std::make_pair(counterUid, counter));
- }
-
- return counterPtr;
- }
-
- // Getters for counts
- uint16_t GetCategoryCount() const override { return boost::numeric_cast<uint16_t>(m_Categories.size()); }
- uint16_t GetDeviceCount() const override { return boost::numeric_cast<uint16_t>(m_Devices.size()); }
- uint16_t GetCounterSetCount() const override { return boost::numeric_cast<uint16_t>(m_CounterSets.size()); }
- uint16_t GetCounterCount() const override { return boost::numeric_cast<uint16_t>(m_Counters.size()); }
-
- // Getters for collections
- const Categories& GetCategories() const override { return m_Categories; }
- const Devices& GetDevices() const override { return m_Devices; }
- const CounterSets& GetCounterSets() const override { return m_CounterSets; }
- const Counters& GetCounters() const override { return m_Counters; }
-
- // Getters for profiling objects
- const Category* GetCategory(const std::string& name) const override
- {
- auto it = std::find_if(m_Categories.begin(), m_Categories.end(), [&name](const CategoryPtr& category)
- {
- BOOST_ASSERT(category);
-
- return category->m_Name == name;
- });
-
- if (it == m_Categories.end())
- {
- return nullptr;
- }
-
- return it->get();
- }
-
- const Device* GetDevice(uint16_t uid) const override
- {
- boost::ignore_unused(uid);
- return nullptr; // Not used by the unit tests
- }
-
- const CounterSet* GetCounterSet(uint16_t uid) const override
- {
- boost::ignore_unused(uid);
- return nullptr; // Not used by the unit tests
- }
-
- const Counter* GetCounter(uint16_t uid) const override
- {
- boost::ignore_unused(uid);
- return nullptr; // Not used by the unit tests
- }
-
-private:
- Categories m_Categories;
- Devices m_Devices;
- CounterSets m_CounterSets;
- Counters m_Counters;
-};
-
class SendCounterPacketTest : public SendCounterPacket
{
public:
diff --git a/src/profiling/test/SendTimelinePacketTests.cpp b/src/profiling/test/SendTimelinePacketTests.cpp
index 8071eece7d..af15c57117 100644
--- a/src/profiling/test/SendTimelinePacketTests.cpp
+++ b/src/profiling/test/SendTimelinePacketTests.cpp
@@ -3,7 +3,7 @@
// SPDX-License-Identifier: MIT
//
-#include "SendCounterPacketTests.hpp"
+#include "ProfilingMocks.hpp"
#include <BufferManager.hpp>
#include <ProfilingService.hpp>
diff --git a/src/profiling/test/TimelineUtilityMethodsTests.cpp b/src/profiling/test/TimelineUtilityMethodsTests.cpp
index abacdb5288..efceff2859 100644
--- a/src/profiling/test/TimelineUtilityMethodsTests.cpp
+++ b/src/profiling/test/TimelineUtilityMethodsTests.cpp
@@ -3,7 +3,7 @@
// SPDX-License-Identifier: MIT
//
-#include "SendCounterPacketTests.hpp"
+#include "ProfilingMocks.hpp"
#include "ProfilingTestUtils.hpp"
#include <SendTimelinePacket.hpp>
diff --git a/tests/profiling/gatordmock/tests/GatordMockTests.cpp b/tests/profiling/gatordmock/tests/GatordMockTests.cpp
index 017a95c99f..02adffb2cc 100644
--- a/tests/profiling/gatordmock/tests/GatordMockTests.cpp
+++ b/tests/profiling/gatordmock/tests/GatordMockTests.cpp
@@ -11,7 +11,7 @@
#include <StreamMetadataCommandHandler.hpp>
#include <TimelineDirectoryCaptureCommandHandler.hpp>
-#include <test/SendCounterPacketTests.hpp>
+#include <test/ProfilingMocks.hpp>
#include <boost/cast.hpp>
#include <boost/test/test_tools.hpp>