ArmNN
 21.08
ProfilingTests.cpp File Reference
#include "ProfilingTests.hpp"
#include "ProfilingTestUtils.hpp"
#include <backends/BackendProfiling.hpp>
#include <common/include/EncodeVersion.hpp>
#include <common/include/PacketVersionResolver.hpp>
#include <common/include/SwTrace.hpp>
#include <CommandHandler.hpp>
#include <ConnectionAcknowledgedCommandHandler.hpp>
#include <CounterDirectory.hpp>
#include <CounterIdMap.hpp>
#include <Holder.hpp>
#include <ICounterValues.hpp>
#include <PeriodicCounterCapture.hpp>
#include <PeriodicCounterSelectionCommandHandler.hpp>
#include <ProfilingStateMachine.hpp>
#include <ProfilingUtils.hpp>
#include <RegisterBackendCounters.hpp>
#include <RequestCounterDirectoryCommandHandler.hpp>
#include <Runtime.hpp>
#include <SocketProfilingConnection.hpp>
#include <SendCounterPacket.hpp>
#include <SendThread.hpp>
#include <SendTimelinePacket.hpp>
#include <armnn/Conversion.hpp>
#include <armnn/Types.hpp>
#include <armnn/Utils.hpp>
#include <armnn/utility/IgnoreUnused.hpp>
#include <armnn/utility/NumericCast.hpp>
#include <common/include/CommandHandlerKey.hpp>
#include <common/include/CommandHandlerRegistry.hpp>
#include <common/include/SocketConnectionException.hpp>
#include <common/include/Packet.hpp>
#include <doctest/doctest.h>
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <limits>
#include <map>
#include <random>

Go to the source code of this file.

Typedefs

using PacketType = MockProfilingConnection::PacketType
 

Functions

 TEST_SUITE ("ExternalProfiling")
 

Typedef Documentation

◆ PacketType

Function Documentation

◆ TEST_SUITE()

TEST_SUITE ( "ExternalProfiling"  )

Definition at line 57 of file ProfilingTests.cpp.

References armnn::profiling::Active, ProfilingService::AddCounterValue(), armnn::profiling::ConstructHeader(), ProfilingServiceRuntimeHelper::ForceTransitionToState(), ProfilingService::GetAbsoluteCounterValue(), Holder::GetCaptureData(), CaptureData::GetCapturePeriod(), CounterDirectory::GetCategory(), CounterDirectory::GetCategoryCount(), TestFunctorA::GetCount(), CounterDirectory::GetCounter(), CounterDirectory::GetCounterCount(), ICounterDirectory::GetCounterCount(), ProfilingService::GetCounterDirectory(), CaptureData::GetCounterIds(), ICounterDirectory::GetCounters(), CounterDirectory::GetCounterSet(), CounterDirectory::GetCounterSetCount(), ProfilingStateMachine::GetCurrentState(), ProfilingService::GetCurrentState(), ProfilingService::GetDeltaCounterValue(), CounterDirectory::GetDevice(), CounterDirectory::GetDeviceCount(), armnn::profiling::GetNextCounterUids(), armnn::profiling::GetNextUid(), ProfilingServiceRuntimeHelper::GetProfilingBufferManager(), armnn::GetProfilingService(), BufferManager::GetReadableBuffer(), MockBufferManager::GetReadableBuffer(), armnn::IgnoreUnused(), ProfilingService::IncrementCounterValue(), CommandHandler::IsRunning(), armnn::LOWEST_CAPTURE_PERIOD, Counter::m_Class, Device::m_Cores, CounterSet::m_Count, Category::m_Counters, Counter::m_CounterSetUid, Counter::m_Description, Counter::m_DeviceUid, IRuntime::CreationOptions::ExternalProfilingOptions::m_EnableProfiling, Counter::m_Interpolation, Counter::m_MaxCounterUid, Counter::m_Multiplier, Category::m_Name, Device::m_Name, CounterSet::m_Name, Counter::m_Name, IRuntime::CreationOptions::m_ProfilingOptions, Device::m_Uid, CounterSet::m_Uid, Counter::m_Uid, Counter::m_Units, MockBufferManager::MarkRead(), armnn::profiling::NotConnected, armnn::numeric_cast(), TestProfilingConnectionTimeoutError::ReadCalledCount(), armnn::profiling::ReadUint16(), armnn::profiling::ReadUint32(), CounterDirectory::RegisterCategory(), CounterDirectory::RegisterCounter(), 
CounterDirectory::RegisterCounterSet(), CounterDirectory::RegisterDevice(), ProfilingService::ResetExternalProfilingOptions(), Holder::SetCaptureData(), CaptureData::SetCapturePeriod(), CaptureData::SetCounterIds(), ProfilingService::SetCounterValue(), CommandHandler::SetStopAfterTimeout(), CommandHandler::Start(), PeriodicCounterCapture::Start(), CommandHandler::Stop(), PeriodicCounterCapture::Stop(), ProfilingService::SubtractCounterValue(), ProfilingStateMachine::TransitionToState(), armnn::profiling::Uninitialised, ProfilingService::Update(), armnn::profiling::WaitingForAck, armnn::profiling::WriteUint16(), and armnn::profiling::WriteUint32().

58 {
59 TEST_CASE("CheckCommandHandlerKeyComparisons")
60 {
61  arm::pipe::CommandHandlerKey testKey1_0(1, 1, 1);
62  arm::pipe::CommandHandlerKey testKey1_1(1, 1, 1);
63  arm::pipe::CommandHandlerKey testKey1_2(1, 2, 1);
64 
65  arm::pipe::CommandHandlerKey testKey0(0, 1, 1);
66  arm::pipe::CommandHandlerKey testKey1(0, 1, 1);
67  arm::pipe::CommandHandlerKey testKey2(0, 1, 1);
68  arm::pipe::CommandHandlerKey testKey3(0, 0, 0);
69  arm::pipe::CommandHandlerKey testKey4(0, 2, 2);
70  arm::pipe::CommandHandlerKey testKey5(0, 0, 2);
71 
72  CHECK(testKey1_0 > testKey0);
73  CHECK(testKey1_0 == testKey1_1);
74  CHECK(testKey1_0 < testKey1_2);
75 
76  CHECK(testKey1 < testKey4);
77  CHECK(testKey1 > testKey3);
78  CHECK(testKey1 <= testKey4);
79  CHECK(testKey1 >= testKey3);
80  CHECK(testKey1 <= testKey2);
81  CHECK(testKey1 >= testKey2);
82  CHECK(testKey1 == testKey2);
83  CHECK(testKey1 == testKey1);
84 
85  CHECK(!(testKey1 == testKey5));
86  CHECK(!(testKey1 != testKey1));
87  CHECK(testKey1 != testKey5);
88 
89  CHECK((testKey1 == testKey2 && testKey2 == testKey1));
90  CHECK((testKey0 == testKey1 && testKey1 == testKey2 && testKey0 == testKey2));
91 
92  CHECK(testKey1.GetPacketId() == 1);
93  CHECK(testKey1.GetVersion() == 1);
94 
95  std::vector<arm::pipe::CommandHandlerKey> vect = {
96  arm::pipe::CommandHandlerKey(0, 0, 1), arm::pipe::CommandHandlerKey(0, 2, 0),
97  arm::pipe::CommandHandlerKey(0, 1, 0), arm::pipe::CommandHandlerKey(0, 2, 1),
98  arm::pipe::CommandHandlerKey(0, 1, 1), arm::pipe::CommandHandlerKey(0, 0, 1),
99  arm::pipe::CommandHandlerKey(0, 2, 0), arm::pipe::CommandHandlerKey(0, 0, 0) };
100 
101  std::sort(vect.begin(), vect.end());
102 
103  std::vector<arm::pipe::CommandHandlerKey> expectedVect = {
104  arm::pipe::CommandHandlerKey(0, 0, 0), arm::pipe::CommandHandlerKey(0, 0, 1),
105  arm::pipe::CommandHandlerKey(0, 0, 1), arm::pipe::CommandHandlerKey(0, 1, 0),
106  arm::pipe::CommandHandlerKey(0, 1, 1), arm::pipe::CommandHandlerKey(0, 2, 0),
107  arm::pipe::CommandHandlerKey(0, 2, 0), arm::pipe::CommandHandlerKey(0, 2, 1) };
108 
109  CHECK(vect == expectedVect);
110 }
111 
112 TEST_CASE("CheckPacketKeyComparisons")
113 {
114  arm::pipe::PacketKey key0(0, 0);
115  arm::pipe::PacketKey key1(0, 0);
116  arm::pipe::PacketKey key2(0, 1);
117  arm::pipe::PacketKey key3(0, 2);
118  arm::pipe::PacketKey key4(1, 0);
119  arm::pipe::PacketKey key5(1, 0);
120  arm::pipe::PacketKey key6(1, 1);
121 
122  CHECK(!(key0 < key1));
123  CHECK(!(key0 > key1));
124  CHECK(key0 <= key1);
125  CHECK(key0 >= key1);
126  CHECK(key0 == key1);
127  CHECK(key0 < key2);
128  CHECK(key2 < key3);
129  CHECK(key3 > key0);
130  CHECK(key4 == key5);
131  CHECK(key4 > key0);
132  CHECK(key5 < key6);
133  CHECK(key5 <= key6);
134  CHECK(key5 != key6);
135 }
136 
137 TEST_CASE("CheckCommandHandler")
138 {
139  arm::pipe::PacketVersionResolver packetVersionResolver;
140  ProfilingStateMachine profilingStateMachine;
141 
142  TestProfilingConnectionBase testProfilingConnectionBase;
143  TestProfilingConnectionTimeoutError testProfilingConnectionTimeOutError;
144  TestProfilingConnectionArmnnError testProfilingConnectionArmnnError;
145  CounterDirectory counterDirectory;
146  MockBufferManager mockBuffer(1024);
147  SendCounterPacket sendCounterPacket(mockBuffer);
148  SendThread sendThread(profilingStateMachine, mockBuffer, sendCounterPacket);
149  SendTimelinePacket sendTimelinePacket(mockBuffer);
150  MockProfilingServiceStatus mockProfilingServiceStatus;
151 
152  ConnectionAcknowledgedCommandHandler connectionAcknowledgedCommandHandler(0, 1, 4194304, counterDirectory,
153  sendCounterPacket, sendTimelinePacket,
154  profilingStateMachine,
155  mockProfilingServiceStatus);
156  arm::pipe::CommandHandlerRegistry commandHandlerRegistry;
157 
158  commandHandlerRegistry.RegisterFunctor(&connectionAcknowledgedCommandHandler);
159 
160  profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
161  profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
162 
163  CommandHandler commandHandler0(1, true, commandHandlerRegistry, packetVersionResolver);
164 
165  // This should start the command handler thread return the connection ack and put the profiling
166  // service into active state.
167  commandHandler0.Start(testProfilingConnectionBase);
168  // Try to start the send thread many times, it must only start once
169  commandHandler0.Start(testProfilingConnectionBase);
170 
171  // This could take up to 20mSec but we'll check often.
172  for (int i = 0; i < 10; i++)
173  {
174  if (profilingStateMachine.GetCurrentState() == ProfilingState::Active)
175  {
176  break;
177  }
178  std::this_thread::sleep_for(std::chrono::milliseconds(2));
179  }
180 
181  CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active);
182 
183  // Close the thread again.
184  commandHandler0.Stop();
185 
186  profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
187  profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
188 
189  // In this test we'll simulate a timeout without a connection ack packet being received.
190  // Stop after timeout is set so we expect the command handler to stop almost immediately.
191  CommandHandler commandHandler1(1, true, commandHandlerRegistry, packetVersionResolver);
192 
193  commandHandler1.Start(testProfilingConnectionTimeOutError);
194  // Wait until we know a timeout exception has been sent at least once.
195  for (int i = 0; i < 10; i++)
196  {
197  if (testProfilingConnectionTimeOutError.ReadCalledCount())
198  {
199  break;
200  }
201  std::this_thread::sleep_for(std::chrono::milliseconds(2));
202  }
203 
204  // The command handler loop should have stopped after the timeout.
205  // wait for the timeout exception to be processed and the loop to break.
206  uint32_t timeout = 50;
207  uint32_t timeSlept = 0;
208  while (commandHandler1.IsRunning())
209  {
210  if (timeSlept >= timeout)
211  {
212  FAIL("Timeout: The command handler loop did not stop after the timeout");
213  }
214  std::this_thread::sleep_for(std::chrono::milliseconds(1));
215  timeSlept ++;
216  }
217 
218  commandHandler1.Stop();
219  // The state machine should never have received the ack so will still be in WaitingForAck.
220  CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck);
221 
222  // Now try sending a bad connection acknowledged packet
223  TestProfilingConnectionBadAckPacket testProfilingConnectionBadAckPacket;
224  commandHandler1.Start(testProfilingConnectionBadAckPacket);
225  commandHandler1.Stop();
226  // This should also not change the state machine
227  CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck);
228 
229  // Disable stop after timeout and now commandHandler1 should persist after a timeout
230  commandHandler1.SetStopAfterTimeout(false);
231  // Restart the thread.
232  commandHandler1.Start(testProfilingConnectionTimeOutError);
233 
234  // Wait for at the three timeouts and the ack to be sent.
235  for (int i = 0; i < 10; i++)
236  {
237  if (testProfilingConnectionTimeOutError.ReadCalledCount() > 3)
238  {
239  break;
240  }
241  std::this_thread::sleep_for(std::chrono::milliseconds(2));
242  }
243  commandHandler1.Stop();
244 
245  // Even after the 3 exceptions the ack packet should have transitioned the command handler to active.
246  CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active);
247 
248  // A command handler that gets exceptions other than timeouts should keep going.
249  CommandHandler commandHandler2(1, false, commandHandlerRegistry, packetVersionResolver);
250 
251  commandHandler2.Start(testProfilingConnectionArmnnError);
252 
253  // Wait for two exceptions to be thrown.
254  for (int i = 0; i < 10; i++)
255  {
256  if (testProfilingConnectionTimeOutError.ReadCalledCount() >= 2)
257  {
258  break;
259  }
260  std::this_thread::sleep_for(std::chrono::milliseconds(2));
261  }
262 
263  CHECK(commandHandler2.IsRunning());
264  commandHandler2.Stop();
265 }
266 
267 TEST_CASE("CheckEncodeVersion")
268 {
269  arm::pipe::Version version1(12);
270 
271  CHECK(version1.GetMajor() == 0);
272  CHECK(version1.GetMinor() == 0);
273  CHECK(version1.GetPatch() == 12);
274 
275  arm::pipe::Version version2(4108);
276 
277  CHECK(version2.GetMajor() == 0);
278  CHECK(version2.GetMinor() == 1);
279  CHECK(version2.GetPatch() == 12);
280 
281  arm::pipe::Version version3(4198412);
282 
283  CHECK(version3.GetMajor() == 1);
284  CHECK(version3.GetMinor() == 1);
285  CHECK(version3.GetPatch() == 12);
286 
287  arm::pipe::Version version4(0);
288 
289  CHECK(version4.GetMajor() == 0);
290  CHECK(version4.GetMinor() == 0);
291  CHECK(version4.GetPatch() == 0);
292 
293  arm::pipe::Version version5(1, 0, 0);
294  CHECK(version5.GetEncodedValue() == 4194304);
295 }
296 
297 TEST_CASE("CheckPacketClass")
298 {
299  uint32_t length = 4;
300  std::unique_ptr<unsigned char[]> packetData0 = std::make_unique<unsigned char[]>(length);
301  std::unique_ptr<unsigned char[]> packetData1 = std::make_unique<unsigned char[]>(0);
302  std::unique_ptr<unsigned char[]> nullPacketData;
303 
304  arm::pipe::Packet packetTest0(472580096, length, packetData0);
305 
306  CHECK(packetTest0.GetHeader() == 472580096);
307  CHECK(packetTest0.GetPacketFamily() == 7);
308  CHECK(packetTest0.GetPacketId() == 43);
309  CHECK(packetTest0.GetLength() == length);
310  CHECK(packetTest0.GetPacketType() == 3);
311  CHECK(packetTest0.GetPacketClass() == 5);
312 
313  CHECK_THROWS_AS(arm::pipe::Packet packetTest1(472580096, 0, packetData1), arm::pipe::InvalidArgumentException);
314  CHECK_NOTHROW(arm::pipe::Packet packetTest2(472580096, 0, nullPacketData));
315 
316  arm::pipe::Packet packetTest3(472580096, 0, nullPacketData);
317  CHECK(packetTest3.GetLength() == 0);
318  CHECK(packetTest3.GetData() == nullptr);
319 
320  const unsigned char* packetTest0Data = packetTest0.GetData();
321  arm::pipe::Packet packetTest4(std::move(packetTest0));
322 
323  CHECK(packetTest0.GetData() == nullptr);
324  CHECK(packetTest4.GetData() == packetTest0Data);
325 
326  CHECK(packetTest4.GetHeader() == 472580096);
327  CHECK(packetTest4.GetPacketFamily() == 7);
328  CHECK(packetTest4.GetPacketId() == 43);
329  CHECK(packetTest4.GetLength() == length);
330  CHECK(packetTest4.GetPacketType() == 3);
331  CHECK(packetTest4.GetPacketClass() == 5);
332 }
333 
334 TEST_CASE("CheckCommandHandlerFunctor")
335 {
336  // Hard code the version as it will be the same during a single profiling session
337  uint32_t version = 1;
338 
339  TestFunctorA testFunctorA(7, 461, version);
340  TestFunctorB testFunctorB(8, 963, version);
341  TestFunctorC testFunctorC(5, 983, version);
342 
343  arm::pipe::CommandHandlerKey keyA(
344  testFunctorA.GetFamilyId(), testFunctorA.GetPacketId(), testFunctorA.GetVersion());
345  arm::pipe::CommandHandlerKey keyB(
346  testFunctorB.GetFamilyId(), testFunctorB.GetPacketId(), testFunctorB.GetVersion());
347  arm::pipe::CommandHandlerKey keyC(
348  testFunctorC.GetFamilyId(), testFunctorC.GetPacketId(), testFunctorC.GetVersion());
349 
350  // Create the unwrapped map to simulate the Command Handler Registry
351  std::map<arm::pipe::CommandHandlerKey, arm::pipe::CommandHandlerFunctor*> registry;
352 
353  registry.insert(std::make_pair(keyB, &testFunctorB));
354  registry.insert(std::make_pair(keyA, &testFunctorA));
355  registry.insert(std::make_pair(keyC, &testFunctorC));
356 
357  // Check the order of the map is correct
358  auto it = registry.begin();
359  CHECK(it->first == keyC); // familyId == 5
360  it++;
361  CHECK(it->first == keyA); // familyId == 7
362  it++;
363  CHECK(it->first == keyB); // familyId == 8
364 
365  std::unique_ptr<unsigned char[]> packetDataA;
366  std::unique_ptr<unsigned char[]> packetDataB;
367  std::unique_ptr<unsigned char[]> packetDataC;
368 
369  arm::pipe::Packet packetA(500000000, 0, packetDataA);
370  arm::pipe::Packet packetB(600000000, 0, packetDataB);
371  arm::pipe::Packet packetC(400000000, 0, packetDataC);
372 
373  // Check the correct operator of derived class is called
374  registry.at(arm::pipe::CommandHandlerKey(
375  packetA.GetPacketFamily(), packetA.GetPacketId(), version))->operator()(packetA);
376  CHECK(testFunctorA.GetCount() == 1);
377  CHECK(testFunctorB.GetCount() == 0);
378  CHECK(testFunctorC.GetCount() == 0);
379 
380  registry.at(arm::pipe::CommandHandlerKey(
381  packetB.GetPacketFamily(), packetB.GetPacketId(), version))->operator()(packetB);
382  CHECK(testFunctorA.GetCount() == 1);
383  CHECK(testFunctorB.GetCount() == 1);
384  CHECK(testFunctorC.GetCount() == 0);
385 
386  registry.at(arm::pipe::CommandHandlerKey(
387  packetC.GetPacketFamily(), packetC.GetPacketId(), version))->operator()(packetC);
388  CHECK(testFunctorA.GetCount() == 1);
389  CHECK(testFunctorB.GetCount() == 1);
390  CHECK(testFunctorC.GetCount() == 1);
391 }
392 
393 TEST_CASE("CheckCommandHandlerRegistry")
394 {
395  // Hard code the version as it will be the same during a single profiling session
396  uint32_t version = 1;
397 
398  TestFunctorA testFunctorA(7, 461, version);
399  TestFunctorB testFunctorB(8, 963, version);
400  TestFunctorC testFunctorC(5, 983, version);
401 
402  // Create the Command Handler Registry
403  arm::pipe::CommandHandlerRegistry registry;
404 
405  // Register multiple different derived classes
406  registry.RegisterFunctor(&testFunctorA);
407  registry.RegisterFunctor(&testFunctorB);
408  registry.RegisterFunctor(&testFunctorC);
409 
410  std::unique_ptr<unsigned char[]> packetDataA;
411  std::unique_ptr<unsigned char[]> packetDataB;
412  std::unique_ptr<unsigned char[]> packetDataC;
413 
414  arm::pipe::Packet packetA(500000000, 0, packetDataA);
415  arm::pipe::Packet packetB(600000000, 0, packetDataB);
416  arm::pipe::Packet packetC(400000000, 0, packetDataC);
417 
418  // Check the correct operator of derived class is called
419  registry.GetFunctor(packetA.GetPacketFamily(), packetA.GetPacketId(), version)->operator()(packetA);
420  CHECK(testFunctorA.GetCount() == 1);
421  CHECK(testFunctorB.GetCount() == 0);
422  CHECK(testFunctorC.GetCount() == 0);
423 
424  registry.GetFunctor(packetB.GetPacketFamily(), packetB.GetPacketId(), version)->operator()(packetB);
425  CHECK(testFunctorA.GetCount() == 1);
426  CHECK(testFunctorB.GetCount() == 1);
427  CHECK(testFunctorC.GetCount() == 0);
428 
429  registry.GetFunctor(packetC.GetPacketFamily(), packetC.GetPacketId(), version)->operator()(packetC);
430  CHECK(testFunctorA.GetCount() == 1);
431  CHECK(testFunctorB.GetCount() == 1);
432  CHECK(testFunctorC.GetCount() == 1);
433 
434  // Re-register an existing key with a new function
435  registry.RegisterFunctor(&testFunctorC, testFunctorA.GetFamilyId(), testFunctorA.GetPacketId(), version);
436  registry.GetFunctor(packetA.GetPacketFamily(), packetA.GetPacketId(), version)->operator()(packetC);
437  CHECK(testFunctorA.GetCount() == 1);
438  CHECK(testFunctorB.GetCount() == 1);
439  CHECK(testFunctorC.GetCount() == 2);
440 
441  // Check that non-existent key returns nullptr for its functor
442  CHECK_THROWS_AS(registry.GetFunctor(0, 0, 0), arm::pipe::ProfilingException);
443 }
444 
445 TEST_CASE("CheckPacketVersionResolver")
446 {
447  // Set up random number generator for generating packetId values
448  std::random_device device;
449  std::mt19937 generator(device());
450  std::uniform_int_distribution<uint32_t> distribution(std::numeric_limits<uint32_t>::min(),
451  std::numeric_limits<uint32_t>::max());
452 
453  // NOTE: Expected version is always 1.0.0, regardless of packetId
454  const arm::pipe::Version expectedVersion(1, 0, 0);
455 
456  arm::pipe::PacketVersionResolver packetVersionResolver;
457 
458  constexpr unsigned int numTests = 10u;
459 
460  for (unsigned int i = 0u; i < numTests; ++i)
461  {
462  const uint32_t familyId = distribution(generator);
463  const uint32_t packetId = distribution(generator);
464  arm::pipe::Version resolvedVersion = packetVersionResolver.ResolvePacketVersion(familyId, packetId);
465 
466  CHECK(resolvedVersion == expectedVersion);
467  }
468 }
469 
470 void ProfilingCurrentStateThreadImpl(ProfilingStateMachine& states)
471 {
472  ProfilingState newState = ProfilingState::NotConnected;
473  states.GetCurrentState();
474  states.TransitionToState(newState);
475 }
476 
477 TEST_CASE("CheckProfilingStateMachine")
478 {
479  ProfilingStateMachine profilingState1(ProfilingState::Uninitialised);
480  profilingState1.TransitionToState(ProfilingState::Uninitialised);
481  CHECK(profilingState1.GetCurrentState() == ProfilingState::Uninitialised);
482 
483  ProfilingStateMachine profilingState2(ProfilingState::Uninitialised);
484  profilingState2.TransitionToState(ProfilingState::NotConnected);
485  CHECK(profilingState2.GetCurrentState() == ProfilingState::NotConnected);
486 
487  ProfilingStateMachine profilingState3(ProfilingState::NotConnected);
488  profilingState3.TransitionToState(ProfilingState::NotConnected);
489  CHECK(profilingState3.GetCurrentState() == ProfilingState::NotConnected);
490 
491  ProfilingStateMachine profilingState4(ProfilingState::NotConnected);
492  profilingState4.TransitionToState(ProfilingState::WaitingForAck);
493  CHECK(profilingState4.GetCurrentState() == ProfilingState::WaitingForAck);
494 
495  ProfilingStateMachine profilingState5(ProfilingState::WaitingForAck);
496  profilingState5.TransitionToState(ProfilingState::WaitingForAck);
497  CHECK(profilingState5.GetCurrentState() == ProfilingState::WaitingForAck);
498 
499  ProfilingStateMachine profilingState6(ProfilingState::WaitingForAck);
500  profilingState6.TransitionToState(ProfilingState::Active);
501  CHECK(profilingState6.GetCurrentState() == ProfilingState::Active);
502 
503  ProfilingStateMachine profilingState7(ProfilingState::Active);
504  profilingState7.TransitionToState(ProfilingState::NotConnected);
505  CHECK(profilingState7.GetCurrentState() == ProfilingState::NotConnected);
506 
507  ProfilingStateMachine profilingState8(ProfilingState::Active);
508  profilingState8.TransitionToState(ProfilingState::Active);
509  CHECK(profilingState8.GetCurrentState() == ProfilingState::Active);
510 
511  ProfilingStateMachine profilingState9(ProfilingState::Uninitialised);
512  CHECK_THROWS_AS(profilingState9.TransitionToState(ProfilingState::WaitingForAck), armnn::Exception);
513 
514  ProfilingStateMachine profilingState10(ProfilingState::Uninitialised);
515  CHECK_THROWS_AS(profilingState10.TransitionToState(ProfilingState::Active), armnn::Exception);
516 
517  ProfilingStateMachine profilingState11(ProfilingState::NotConnected);
518  CHECK_THROWS_AS(profilingState11.TransitionToState(ProfilingState::Uninitialised), armnn::Exception);
519 
520  ProfilingStateMachine profilingState12(ProfilingState::NotConnected);
521  CHECK_THROWS_AS(profilingState12.TransitionToState(ProfilingState::Active), armnn::Exception);
522 
523  ProfilingStateMachine profilingState13(ProfilingState::WaitingForAck);
524  CHECK_THROWS_AS(profilingState13.TransitionToState(ProfilingState::Uninitialised), armnn::Exception);
525 
526  ProfilingStateMachine profilingState14(ProfilingState::WaitingForAck);
527  profilingState14.TransitionToState(ProfilingState::NotConnected);
528  CHECK(profilingState14.GetCurrentState() == ProfilingState::NotConnected);
529 
530  ProfilingStateMachine profilingState15(ProfilingState::Active);
531  CHECK_THROWS_AS(profilingState15.TransitionToState(ProfilingState::Uninitialised), armnn::Exception);
532 
534  CHECK_THROWS_AS(profilingState16.TransitionToState(ProfilingState::WaitingForAck), armnn::Exception);
535 
536  ProfilingStateMachine profilingState17(ProfilingState::Uninitialised);
537 
538  std::vector<std::thread> threads;
539  for (unsigned int i = 0; i < 5; ++i)
540  {
541  threads.push_back(std::thread(ProfilingCurrentStateThreadImpl, std::ref(profilingState17)));
542  }
543  std::for_each(threads.begin(), threads.end(), [](std::thread& theThread)
544  {
545  theThread.join();
546  });
547 
548  CHECK((profilingState17.GetCurrentState() == ProfilingState::NotConnected));
549 }
550 
551 void CaptureDataWriteThreadImpl(Holder& holder, uint32_t capturePeriod, const std::vector<uint16_t>& counterIds)
552 {
553  holder.SetCaptureData(capturePeriod, counterIds, {});
554 }
555 
556 void CaptureDataReadThreadImpl(const Holder& holder, CaptureData& captureData)
557 {
558  captureData = holder.GetCaptureData();
559 }
560 
561 TEST_CASE("CheckCaptureDataHolder")
562 {
563  std::map<uint32_t, std::vector<uint16_t>> periodIdMap;
564  std::vector<uint16_t> counterIds;
565  uint32_t numThreads = 10;
566  for (uint32_t i = 0; i < numThreads; ++i)
567  {
568  counterIds.emplace_back(i);
569  periodIdMap.insert(std::make_pair(i, counterIds));
570  }
571 
572  // Verify the read and write threads set the holder correctly
573  // and retrieve the expected values
574  Holder holder;
575  CHECK((holder.GetCaptureData()).GetCapturePeriod() == 0);
576  CHECK(((holder.GetCaptureData()).GetCounterIds()).empty());
577 
578  // Check Holder functions
579  std::thread thread1(CaptureDataWriteThreadImpl, std::ref(holder), 2, std::ref(periodIdMap[2]));
580  thread1.join();
581  CHECK((holder.GetCaptureData()).GetCapturePeriod() == 2);
582  CHECK((holder.GetCaptureData()).GetCounterIds() == periodIdMap[2]);
583  // NOTE: now that we have some initial values in the holder we don't have to worry
584  // in the multi-threaded section below about a read thread accessing the holder
585  // before any write thread has gotten to it so we read period = 0, counterIds empty
586  // instead of period = 0, counterIds = {0} as will the case when write thread 0
587  // has executed.
588 
589  CaptureData captureData;
590  std::thread thread2(CaptureDataReadThreadImpl, std::ref(holder), std::ref(captureData));
591  thread2.join();
592  CHECK(captureData.GetCapturePeriod() == 2);
593  CHECK(captureData.GetCounterIds() == periodIdMap[2]);
594 
595  std::map<uint32_t, CaptureData> captureDataIdMap;
596  for (uint32_t i = 0; i < numThreads; ++i)
597  {
598  CaptureData perThreadCaptureData;
599  captureDataIdMap.insert(std::make_pair(i, perThreadCaptureData));
600  }
601 
602  std::vector<std::thread> threadsVect;
603  std::vector<std::thread> readThreadsVect;
604  for (uint32_t i = 0; i < numThreads; ++i)
605  {
606  threadsVect.emplace_back(
607  std::thread(CaptureDataWriteThreadImpl, std::ref(holder), i, std::ref(periodIdMap[i])));
608 
609  // Verify that the CaptureData goes into the thread in a virgin state
610  CHECK(captureDataIdMap.at(i).GetCapturePeriod() == 0);
611  CHECK(captureDataIdMap.at(i).GetCounterIds().empty());
612  readThreadsVect.emplace_back(
613  std::thread(CaptureDataReadThreadImpl, std::ref(holder), std::ref(captureDataIdMap.at(i))));
614  }
615 
616  for (uint32_t i = 0; i < numThreads; ++i)
617  {
618  threadsVect[i].join();
619  readThreadsVect[i].join();
620  }
621 
622  // Look at the CaptureData that each read thread has filled
623  // the capture period it read should match the counter ids entry
624  for (uint32_t i = 0; i < numThreads; ++i)
625  {
626  CaptureData perThreadCaptureData = captureDataIdMap.at(i);
627  CHECK(perThreadCaptureData.GetCounterIds() == periodIdMap.at(perThreadCaptureData.GetCapturePeriod()));
628  }
629 }
630 
631 TEST_CASE("CaptureDataMethods")
632 {
633  // Check CaptureData setter and getter functions
634  std::vector<uint16_t> counterIds = { 42, 29, 13 };
635  CaptureData captureData;
636  CHECK(captureData.GetCapturePeriod() == 0);
637  CHECK((captureData.GetCounterIds()).empty());
638  captureData.SetCapturePeriod(150);
639  captureData.SetCounterIds(counterIds);
640  CHECK(captureData.GetCapturePeriod() == 150);
641  CHECK(captureData.GetCounterIds() == counterIds);
642 
643  // Check assignment operator
644  CaptureData secondCaptureData;
645 
646  secondCaptureData = captureData;
647  CHECK(secondCaptureData.GetCapturePeriod() == 150);
648  CHECK(secondCaptureData.GetCounterIds() == counterIds);
649 
650  // Check copy constructor
651  CaptureData copyConstructedCaptureData(captureData);
652 
653  CHECK(copyConstructedCaptureData.GetCapturePeriod() == 150);
654  CHECK(copyConstructedCaptureData.GetCounterIds() == counterIds);
655 }
656 
657 TEST_CASE("CheckProfilingServiceDisabled")
658 {
660  armnn::profiling::ProfilingService profilingService;
661  profilingService.ResetExternalProfilingOptions(options, true);
662  CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
663  profilingService.Update();
664  CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
665 }
666 
667 TEST_CASE("CheckProfilingServiceCounterDirectory")
668 {
670  armnn::profiling::ProfilingService profilingService;
671  profilingService.ResetExternalProfilingOptions(options, true);
672 
673  const ICounterDirectory& counterDirectory0 = profilingService.GetCounterDirectory();
674  CHECK(counterDirectory0.GetCounterCount() == 0);
675  profilingService.Update();
676  CHECK(counterDirectory0.GetCounterCount() == 0);
677 
678  options.m_EnableProfiling = true;
679  profilingService.ResetExternalProfilingOptions(options);
680 
681  const ICounterDirectory& counterDirectory1 = profilingService.GetCounterDirectory();
682  CHECK(counterDirectory1.GetCounterCount() == 0);
683  profilingService.Update();
684  CHECK(counterDirectory1.GetCounterCount() != 0);
685  // Reset the profiling service to stop any running thread
686  options.m_EnableProfiling = false;
687  profilingService.ResetExternalProfilingOptions(options, true);
688 }
689 
690 TEST_CASE("CheckProfilingServiceCounterValues")
691 {
693  options.m_EnableProfiling = true;
694  armnn::profiling::ProfilingService profilingService;
695  profilingService.ResetExternalProfilingOptions(options, true);
696 
697  profilingService.Update();
698  const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory();
699  const Counters& counters = counterDirectory.GetCounters();
700  CHECK(!counters.empty());
701 
702  std::vector<std::thread> writers;
703 
704  CHECK(!counters.empty());
705  uint16_t inferencesRun = armnn::profiling::INFERENCES_RUN;
706 
707  // Test GetAbsoluteCounterValue
708  for (int i = 0; i < 4; ++i)
709  {
710  // Increment and decrement the INFERENCES_RUN counter 250 times
711  writers.push_back(std::thread([&profilingService, inferencesRun]()
712  {
713  for (int i = 0; i < 250; ++i)
714  {
715  profilingService.IncrementCounterValue(inferencesRun);
716  }
717  }));
718  // Add 10 to the INFERENCES_RUN counter 200 times
719  writers.push_back(std::thread([&profilingService, inferencesRun]()
720  {
721  for (int i = 0; i < 200; ++i)
722  {
723  profilingService.AddCounterValue(inferencesRun, 10);
724  }
725  }));
726  // Subtract 5 from the INFERENCES_RUN counter 200 times
727  writers.push_back(std::thread([&profilingService, inferencesRun]()
728  {
729  for (int i = 0; i < 200; ++i)
730  {
731  profilingService.SubtractCounterValue(inferencesRun, 5);
732  }
733  }));
734  }
735  std::for_each(writers.begin(), writers.end(), mem_fn(&std::thread::join));
736 
737  uint32_t absoluteCounterValue = 0;
738 
739  CHECK_NOTHROW(absoluteCounterValue = profilingService.GetAbsoluteCounterValue(INFERENCES_RUN));
740  CHECK(absoluteCounterValue == 5000);
741 
742  // Test SetCounterValue
743  CHECK_NOTHROW(profilingService.SetCounterValue(INFERENCES_RUN, 0));
744  CHECK_NOTHROW(absoluteCounterValue = profilingService.GetAbsoluteCounterValue(INFERENCES_RUN));
745  CHECK(absoluteCounterValue == 0);
746 
747  // Test GetDeltaCounterValue
748  writers.clear();
749  uint32_t deltaCounterValue = 0;
750  //Start a reading thread to randomly read the INFERENCES_RUN counter value
751  std::thread reader([&profilingService, inferencesRun](uint32_t& deltaCounterValue)
752  {
753  for (int i = 0; i < 300; ++i)
754  {
755  deltaCounterValue += profilingService.GetDeltaCounterValue(inferencesRun);
756  }
757  }, std::ref(deltaCounterValue));
758 
759  for (int i = 0; i < 4; ++i)
760  {
761  // Increment and decrement the INFERENCES_RUN counter 250 times
762  writers.push_back(std::thread([&profilingService, inferencesRun]()
763  {
764  for (int i = 0; i < 250; ++i)
765  {
766  profilingService.IncrementCounterValue(inferencesRun);
767  }
768  }));
769  // Add 10 to the INFERENCES_RUN counter 200 times
770  writers.push_back(std::thread([&profilingService, inferencesRun]()
771  {
772  for (int i = 0; i < 200; ++i)
773  {
774  profilingService.AddCounterValue(inferencesRun, 10);
775  }
776  }));
777  // Subtract 5 from the INFERENCES_RUN counter 200 times
778  writers.push_back(std::thread([&profilingService, inferencesRun]()
779  {
780  for (int i = 0; i < 200; ++i)
781  {
782  profilingService.SubtractCounterValue(inferencesRun, 5);
783  }
784  }));
785  }
786 
787  std::for_each(writers.begin(), writers.end(), mem_fn(&std::thread::join));
788  reader.join();
789 
790  // Do one last read in case the reader stopped early
791  deltaCounterValue += profilingService.GetDeltaCounterValue(INFERENCES_RUN);
792  CHECK(deltaCounterValue == 5000);
793 
794  // Reset the profiling service to stop any running thread
795  options.m_EnableProfiling = false;
796  profilingService.ResetExternalProfilingOptions(options, true);
797 }
798 
799 TEST_CASE("CheckProfilingObjectUids")
800 {
801  uint16_t uid = 0;
802  CHECK_NOTHROW(uid = GetNextUid());
803  CHECK(uid >= 1);
804 
805  uint16_t nextUid = 0;
806  CHECK_NOTHROW(nextUid = GetNextUid());
807  CHECK(nextUid > uid);
808 
809  std::vector<uint16_t> counterUids;
810  CHECK_NOTHROW(counterUids = GetNextCounterUids(uid,0));
811  CHECK(counterUids.size() == 1);
812 
813  std::vector<uint16_t> nextCounterUids;
814  CHECK_NOTHROW(nextCounterUids = GetNextCounterUids(nextUid, 2));
815  CHECK(nextCounterUids.size() == 2);
816  CHECK(nextCounterUids[0] > counterUids[0]);
817 
818  std::vector<uint16_t> counterUidsMultiCore;
819  uint16_t thirdUid = nextCounterUids[0];
820  uint16_t numberOfCores = 13;
821  CHECK_NOTHROW(counterUidsMultiCore = GetNextCounterUids(thirdUid, numberOfCores));
822  CHECK(counterUidsMultiCore.size() == numberOfCores);
823  CHECK(counterUidsMultiCore.front() >= nextCounterUids[0]);
824  for (size_t i = 1; i < numberOfCores; i++)
825  {
826  CHECK(counterUidsMultiCore[i] == counterUidsMultiCore[i - 1] + 1);
827  }
828  CHECK(counterUidsMultiCore.back() == counterUidsMultiCore.front() + numberOfCores - 1);
829 }
830 
831 TEST_CASE("CheckCounterDirectoryRegisterCategory")
832 {
833  CounterDirectory counterDirectory;
834  CHECK(counterDirectory.GetCategoryCount() == 0);
835  CHECK(counterDirectory.GetDeviceCount() == 0);
836  CHECK(counterDirectory.GetCounterSetCount() == 0);
837  CHECK(counterDirectory.GetCounterCount() == 0);
838 
839  // Register a category with an invalid name
840  const Category* noCategory = nullptr;
841  CHECK_THROWS_AS(noCategory = counterDirectory.RegisterCategory(""), armnn::InvalidArgumentException);
842  CHECK(counterDirectory.GetCategoryCount() == 0);
843  CHECK(!noCategory);
844 
845  // Register a category with an invalid name
846  CHECK_THROWS_AS(noCategory = counterDirectory.RegisterCategory("invalid category"),
848  CHECK(counterDirectory.GetCategoryCount() == 0);
849  CHECK(!noCategory);
850 
851  // Register a new category
852  const std::string categoryName = "some_category";
853  const Category* category = nullptr;
854  CHECK_NOTHROW(category = counterDirectory.RegisterCategory(categoryName));
855  CHECK(counterDirectory.GetCategoryCount() == 1);
856  CHECK(category);
857  CHECK(category->m_Name == categoryName);
858  CHECK(category->m_Counters.empty());
859 
860  // Get the registered category
861  const Category* registeredCategory = counterDirectory.GetCategory(categoryName);
862  CHECK(counterDirectory.GetCategoryCount() == 1);
863  CHECK(registeredCategory);
864  CHECK(registeredCategory == category);
865 
866  // Try to get a category not registered
867  const Category* notRegisteredCategory = counterDirectory.GetCategory("not_registered_category");
868  CHECK(counterDirectory.GetCategoryCount() == 1);
869  CHECK(!notRegisteredCategory);
870 
871  // Register a category already registered
872  const Category* anotherCategory = nullptr;
873  CHECK_THROWS_AS(anotherCategory = counterDirectory.RegisterCategory(categoryName),
875  CHECK(counterDirectory.GetCategoryCount() == 1);
876  CHECK(!anotherCategory);
877 
878  // Register a device for testing
879  const std::string deviceName = "some_device";
880  const Device* device = nullptr;
881  CHECK_NOTHROW(device = counterDirectory.RegisterDevice(deviceName));
882  CHECK(counterDirectory.GetDeviceCount() == 1);
883  CHECK(device);
884  CHECK(device->m_Uid >= 1);
885  CHECK(device->m_Name == deviceName);
886  CHECK(device->m_Cores == 0);
887 
888  // Register a new category not associated to any device
889  const std::string categoryWoDeviceName = "some_category_without_device";
890  const Category* categoryWoDevice = nullptr;
891  CHECK_NOTHROW(categoryWoDevice = counterDirectory.RegisterCategory(categoryWoDeviceName));
892  CHECK(counterDirectory.GetCategoryCount() == 2);
893  CHECK(categoryWoDevice);
894  CHECK(categoryWoDevice->m_Name == categoryWoDeviceName);
895  CHECK(categoryWoDevice->m_Counters.empty());
896 
897  // Register a new category associated to an invalid device name (already exist)
898  const Category* categoryInvalidDeviceName = nullptr;
899  CHECK_THROWS_AS(categoryInvalidDeviceName =
900  counterDirectory.RegisterCategory(categoryWoDeviceName),
902  CHECK(counterDirectory.GetCategoryCount() == 2);
903  CHECK(!categoryInvalidDeviceName);
904 
905  // Register a new category associated to a valid device
906  const std::string categoryWValidDeviceName = "some_category_with_valid_device";
907  const Category* categoryWValidDevice = nullptr;
908  CHECK_NOTHROW(categoryWValidDevice =
909  counterDirectory.RegisterCategory(categoryWValidDeviceName));
910  CHECK(counterDirectory.GetCategoryCount() == 3);
911  CHECK(categoryWValidDevice);
912  CHECK(categoryWValidDevice != category);
913  CHECK(categoryWValidDevice->m_Name == categoryWValidDeviceName);
914 
915  // Register a counter set for testing
916  const std::string counterSetName = "some_counter_set";
917  const CounterSet* counterSet = nullptr;
918  CHECK_NOTHROW(counterSet = counterDirectory.RegisterCounterSet(counterSetName));
919  CHECK(counterDirectory.GetCounterSetCount() == 1);
920  CHECK(counterSet);
921  CHECK(counterSet->m_Uid >= 1);
922  CHECK(counterSet->m_Name == counterSetName);
923  CHECK(counterSet->m_Count == 0);
924 
925  // Register a new category not associated to any counter set
926  const std::string categoryWoCounterSetName = "some_category_without_counter_set";
927  const Category* categoryWoCounterSet = nullptr;
928  CHECK_NOTHROW(categoryWoCounterSet =
929  counterDirectory.RegisterCategory(categoryWoCounterSetName));
930  CHECK(counterDirectory.GetCategoryCount() == 4);
931  CHECK(categoryWoCounterSet);
932  CHECK(categoryWoCounterSet->m_Name == categoryWoCounterSetName);
933 
934  // Register a new category associated to a valid counter set
935  const std::string categoryWValidCounterSetName = "some_category_with_valid_counter_set";
936  const Category* categoryWValidCounterSet = nullptr;
937  CHECK_NOTHROW(categoryWValidCounterSet = counterDirectory.RegisterCategory(categoryWValidCounterSetName));
938  CHECK(counterDirectory.GetCategoryCount() == 5);
939  CHECK(categoryWValidCounterSet);
940  CHECK(categoryWValidCounterSet != category);
941  CHECK(categoryWValidCounterSet->m_Name == categoryWValidCounterSetName);
942 
943  // Register a new category associated to a valid device and counter set
944  const std::string categoryWValidDeviceAndValidCounterSetName = "some_category_with_valid_device_and_counter_set";
945  const Category* categoryWValidDeviceAndValidCounterSet = nullptr;
946  CHECK_NOTHROW(categoryWValidDeviceAndValidCounterSet = counterDirectory.RegisterCategory(
947  categoryWValidDeviceAndValidCounterSetName));
948  CHECK(counterDirectory.GetCategoryCount() == 6);
949  CHECK(categoryWValidDeviceAndValidCounterSet);
950  CHECK(categoryWValidDeviceAndValidCounterSet != category);
951  CHECK(categoryWValidDeviceAndValidCounterSet->m_Name == categoryWValidDeviceAndValidCounterSetName);
952 }
953 
954 TEST_CASE("CheckCounterDirectoryRegisterDevice")
955 {
956  CounterDirectory counterDirectory;
957  CHECK(counterDirectory.GetCategoryCount() == 0);
958  CHECK(counterDirectory.GetDeviceCount() == 0);
959  CHECK(counterDirectory.GetCounterSetCount() == 0);
960  CHECK(counterDirectory.GetCounterCount() == 0);
961 
962  // Register a device with an invalid name
963  const Device* noDevice = nullptr;
964  CHECK_THROWS_AS(noDevice = counterDirectory.RegisterDevice(""), armnn::InvalidArgumentException);
965  CHECK(counterDirectory.GetDeviceCount() == 0);
966  CHECK(!noDevice);
967 
968  // Register a device with an invalid name
969  CHECK_THROWS_AS(noDevice = counterDirectory.RegisterDevice("inv@lid nam€"), armnn::InvalidArgumentException);
970  CHECK(counterDirectory.GetDeviceCount() == 0);
971  CHECK(!noDevice);
972 
973  // Register a new device with no cores or parent category
974  const std::string deviceName = "some_device";
975  const Device* device = nullptr;
976  CHECK_NOTHROW(device = counterDirectory.RegisterDevice(deviceName));
977  CHECK(counterDirectory.GetDeviceCount() == 1);
978  CHECK(device);
979  CHECK(device->m_Name == deviceName);
980  CHECK(device->m_Uid >= 1);
981  CHECK(device->m_Cores == 0);
982 
983  // Try getting an unregistered device
984  const Device* unregisteredDevice = counterDirectory.GetDevice(9999);
985  CHECK(!unregisteredDevice);
986 
987  // Get the registered device
988  const Device* registeredDevice = counterDirectory.GetDevice(device->m_Uid);
989  CHECK(counterDirectory.GetDeviceCount() == 1);
990  CHECK(registeredDevice);
991  CHECK(registeredDevice == device);
992 
993  // Register a device with the name of a device already registered
994  const Device* deviceSameName = nullptr;
995  CHECK_THROWS_AS(deviceSameName = counterDirectory.RegisterDevice(deviceName), armnn::InvalidArgumentException);
996  CHECK(counterDirectory.GetDeviceCount() == 1);
997  CHECK(!deviceSameName);
998 
999  // Register a new device with cores and no parent category
1000  const std::string deviceWCoresName = "some_device_with_cores";
1001  const Device* deviceWCores = nullptr;
1002  CHECK_NOTHROW(deviceWCores = counterDirectory.RegisterDevice(deviceWCoresName, 2));
1003  CHECK(counterDirectory.GetDeviceCount() == 2);
1004  CHECK(deviceWCores);
1005  CHECK(deviceWCores->m_Name == deviceWCoresName);
1006  CHECK(deviceWCores->m_Uid >= 1);
1007  CHECK(deviceWCores->m_Uid > device->m_Uid);
1008  CHECK(deviceWCores->m_Cores == 2);
1009 
1010  // Get the registered device
1011  const Device* registeredDeviceWCores = counterDirectory.GetDevice(deviceWCores->m_Uid);
1012  CHECK(counterDirectory.GetDeviceCount() == 2);
1013  CHECK(registeredDeviceWCores);
1014  CHECK(registeredDeviceWCores == deviceWCores);
1015  CHECK(registeredDeviceWCores != device);
1016 
1017  // Register a new device with cores and invalid parent category
1018  const std::string deviceWCoresWInvalidParentCategoryName = "some_device_with_cores_with_invalid_parent_category";
1019  const Device* deviceWCoresWInvalidParentCategory = nullptr;
1020  CHECK_THROWS_AS(deviceWCoresWInvalidParentCategory =
1021  counterDirectory.RegisterDevice(deviceWCoresWInvalidParentCategoryName, 3, std::string("")),
1023  CHECK(counterDirectory.GetDeviceCount() == 2);
1024  CHECK(!deviceWCoresWInvalidParentCategory);
1025 
1026  // Register a new device with cores and invalid parent category
1027  const std::string deviceWCoresWInvalidParentCategoryName2 = "some_device_with_cores_with_invalid_parent_category2";
1028  const Device* deviceWCoresWInvalidParentCategory2 = nullptr;
1029  CHECK_THROWS_AS(deviceWCoresWInvalidParentCategory2 = counterDirectory.RegisterDevice(
1030  deviceWCoresWInvalidParentCategoryName2, 3, std::string("invalid_parent_category")),
1032  CHECK(counterDirectory.GetDeviceCount() == 2);
1033  CHECK(!deviceWCoresWInvalidParentCategory2);
1034 
1035  // Register a category for testing
1036  const std::string categoryName = "some_category";
1037  const Category* category = nullptr;
1038  CHECK_NOTHROW(category = counterDirectory.RegisterCategory(categoryName));
1039  CHECK(counterDirectory.GetCategoryCount() == 1);
1040  CHECK(category);
1041  CHECK(category->m_Name == categoryName);
1042  CHECK(category->m_Counters.empty());
1043 
1044  // Register a new device with cores and valid parent category
1045  const std::string deviceWCoresWValidParentCategoryName = "some_device_with_cores_with_valid_parent_category";
1046  const Device* deviceWCoresWValidParentCategory = nullptr;
1047  CHECK_NOTHROW(deviceWCoresWValidParentCategory =
1048  counterDirectory.RegisterDevice(deviceWCoresWValidParentCategoryName, 4, categoryName));
1049  CHECK(counterDirectory.GetDeviceCount() == 3);
1050  CHECK(deviceWCoresWValidParentCategory);
1051  CHECK(deviceWCoresWValidParentCategory->m_Name == deviceWCoresWValidParentCategoryName);
1052  CHECK(deviceWCoresWValidParentCategory->m_Uid >= 1);
1053  CHECK(deviceWCoresWValidParentCategory->m_Uid > device->m_Uid);
1054  CHECK(deviceWCoresWValidParentCategory->m_Uid > deviceWCores->m_Uid);
1055  CHECK(deviceWCoresWValidParentCategory->m_Cores == 4);
1056 }
1057 
1058 TEST_CASE("CheckCounterDirectoryRegisterCounterSet")
1059 {
1060  CounterDirectory counterDirectory;
1061  CHECK(counterDirectory.GetCategoryCount() == 0);
1062  CHECK(counterDirectory.GetDeviceCount() == 0);
1063  CHECK(counterDirectory.GetCounterSetCount() == 0);
1064  CHECK(counterDirectory.GetCounterCount() == 0);
1065 
1066  // Register a counter set with an invalid name
1067  const CounterSet* noCounterSet = nullptr;
1068  CHECK_THROWS_AS(noCounterSet = counterDirectory.RegisterCounterSet(""), armnn::InvalidArgumentException);
1069  CHECK(counterDirectory.GetCounterSetCount() == 0);
1070  CHECK(!noCounterSet);
1071 
1072  // Register a counter set with an invalid name
1073  CHECK_THROWS_AS(noCounterSet = counterDirectory.RegisterCounterSet("invalid name"),
1075  CHECK(counterDirectory.GetCounterSetCount() == 0);
1076  CHECK(!noCounterSet);
1077 
1078  // Register a new counter set with no count or parent category
1079  const std::string counterSetName = "some_counter_set";
1080  const CounterSet* counterSet = nullptr;
1081  CHECK_NOTHROW(counterSet = counterDirectory.RegisterCounterSet(counterSetName));
1082  CHECK(counterDirectory.GetCounterSetCount() == 1);
1083  CHECK(counterSet);
1084  CHECK(counterSet->m_Name == counterSetName);
1085  CHECK(counterSet->m_Uid >= 1);
1086  CHECK(counterSet->m_Count == 0);
1087 
1088  // Try getting an unregistered counter set
1089  const CounterSet* unregisteredCounterSet = counterDirectory.GetCounterSet(9999);
1090  CHECK(!unregisteredCounterSet);
1091 
1092  // Get the registered counter set
1093  const CounterSet* registeredCounterSet = counterDirectory.GetCounterSet(counterSet->m_Uid);
1094  CHECK(counterDirectory.GetCounterSetCount() == 1);
1095  CHECK(registeredCounterSet);
1096  CHECK(registeredCounterSet == counterSet);
1097 
1098  // Register a counter set with the name of a counter set already registered
1099  const CounterSet* counterSetSameName = nullptr;
1100  CHECK_THROWS_AS(counterSetSameName = counterDirectory.RegisterCounterSet(counterSetName),
1102  CHECK(counterDirectory.GetCounterSetCount() == 1);
1103  CHECK(!counterSetSameName);
1104 
1105  // Register a new counter set with count and no parent category
1106  const std::string counterSetWCountName = "some_counter_set_with_count";
1107  const CounterSet* counterSetWCount = nullptr;
1108  CHECK_NOTHROW(counterSetWCount = counterDirectory.RegisterCounterSet(counterSetWCountName, 37));
1109  CHECK(counterDirectory.GetCounterSetCount() == 2);
1110  CHECK(counterSetWCount);
1111  CHECK(counterSetWCount->m_Name == counterSetWCountName);
1112  CHECK(counterSetWCount->m_Uid >= 1);
1113  CHECK(counterSetWCount->m_Uid > counterSet->m_Uid);
1114  CHECK(counterSetWCount->m_Count == 37);
1115 
1116  // Get the registered counter set
1117  const CounterSet* registeredCounterSetWCount = counterDirectory.GetCounterSet(counterSetWCount->m_Uid);
1118  CHECK(counterDirectory.GetCounterSetCount() == 2);
1119  CHECK(registeredCounterSetWCount);
1120  CHECK(registeredCounterSetWCount == counterSetWCount);
1121  CHECK(registeredCounterSetWCount != counterSet);
1122 
1123  // Register a new counter set with count and invalid parent category
1124  const std::string counterSetWCountWInvalidParentCategoryName = "some_counter_set_with_count_"
1125  "with_invalid_parent_category";
1126  const CounterSet* counterSetWCountWInvalidParentCategory = nullptr;
1127  CHECK_THROWS_AS(counterSetWCountWInvalidParentCategory = counterDirectory.RegisterCounterSet(
1128  counterSetWCountWInvalidParentCategoryName, 42, std::string("")),
1130  CHECK(counterDirectory.GetCounterSetCount() == 2);
1131  CHECK(!counterSetWCountWInvalidParentCategory);
1132 
1133  // Register a new counter set with count and invalid parent category
1134  const std::string counterSetWCountWInvalidParentCategoryName2 = "some_counter_set_with_count_"
1135  "with_invalid_parent_category2";
1136  const CounterSet* counterSetWCountWInvalidParentCategory2 = nullptr;
1137  CHECK_THROWS_AS(counterSetWCountWInvalidParentCategory2 = counterDirectory.RegisterCounterSet(
1138  counterSetWCountWInvalidParentCategoryName2, 42, std::string("invalid_parent_category")),
1140  CHECK(counterDirectory.GetCounterSetCount() == 2);
1141  CHECK(!counterSetWCountWInvalidParentCategory2);
1142 
1143  // Register a category for testing
1144  const std::string categoryName = "some_category";
1145  const Category* category = nullptr;
1146  CHECK_NOTHROW(category = counterDirectory.RegisterCategory(categoryName));
1147  CHECK(counterDirectory.GetCategoryCount() == 1);
1148  CHECK(category);
1149  CHECK(category->m_Name == categoryName);
1150  CHECK(category->m_Counters.empty());
1151 
1152  // Register a new counter set with count and valid parent category
1153  const std::string counterSetWCountWValidParentCategoryName = "some_counter_set_with_count_"
1154  "with_valid_parent_category";
1155  const CounterSet* counterSetWCountWValidParentCategory = nullptr;
1156  CHECK_NOTHROW(counterSetWCountWValidParentCategory = counterDirectory.RegisterCounterSet(
1157  counterSetWCountWValidParentCategoryName, 42, categoryName));
1158  CHECK(counterDirectory.GetCounterSetCount() == 3);
1159  CHECK(counterSetWCountWValidParentCategory);
1160  CHECK(counterSetWCountWValidParentCategory->m_Name == counterSetWCountWValidParentCategoryName);
1161  CHECK(counterSetWCountWValidParentCategory->m_Uid >= 1);
1162  CHECK(counterSetWCountWValidParentCategory->m_Uid > counterSet->m_Uid);
1163  CHECK(counterSetWCountWValidParentCategory->m_Uid > counterSetWCount->m_Uid);
1164  CHECK(counterSetWCountWValidParentCategory->m_Count == 42);
1165 
1166  // Register a counter set associated to a category with invalid name
1167  const std::string counterSetSameCategoryName = "some_counter_set_with_invalid_parent_category";
1168  const std::string invalidCategoryName = "";
1169  const CounterSet* counterSetSameCategory = nullptr;
1170  CHECK_THROWS_AS(counterSetSameCategory =
1171  counterDirectory.RegisterCounterSet(counterSetSameCategoryName, 0, invalidCategoryName),
1173  CHECK(counterDirectory.GetCounterSetCount() == 3);
1174  CHECK(!counterSetSameCategory);
1175 }
1176 
1177 TEST_CASE("CheckCounterDirectoryRegisterCounter")
1178 {
1179  CounterDirectory counterDirectory;
1180  CHECK(counterDirectory.GetCategoryCount() == 0);
1181  CHECK(counterDirectory.GetDeviceCount() == 0);
1182  CHECK(counterDirectory.GetCounterSetCount() == 0);
1183  CHECK(counterDirectory.GetCounterCount() == 0);
1184 
1185  // Register a counter with an invalid parent category name
1186  const Counter* noCounter = nullptr;
1187  CHECK_THROWS_AS(noCounter =
1188  counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1189  0,
1190  "",
1191  0,
1192  1,
1193  123.45f,
1194  "valid ",
1195  "name"),
1197  CHECK(counterDirectory.GetCounterCount() == 0);
1198  CHECK(!noCounter);
1199 
1200  // Register a counter with an invalid parent category name
1201  CHECK_THROWS_AS(noCounter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1202  1,
1203  "invalid parent category",
1204  0,
1205  1,
1206  123.45f,
1207  "valid name",
1208  "valid description"),
1210  CHECK(counterDirectory.GetCounterCount() == 0);
1211  CHECK(!noCounter);
1212 
1213  // Register a counter with an invalid class
1214  CHECK_THROWS_AS(noCounter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1215  2,
1216  "valid_parent_category",
1217  2,
1218  1,
1219  123.45f,
1220  "valid "
1221  "name",
1222  "valid description"),
1224  CHECK(counterDirectory.GetCounterCount() == 0);
1225  CHECK(!noCounter);
1226 
1227  // Register a counter with an invalid interpolation
1228  CHECK_THROWS_AS(noCounter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1229  4,
1230  "valid_parent_category",
1231  0,
1232  3,
1233  123.45f,
1234  "valid "
1235  "name",
1236  "valid description"),
1238  CHECK(counterDirectory.GetCounterCount() == 0);
1239  CHECK(!noCounter);
1240 
1241  // Register a counter with an invalid multiplier
1242  CHECK_THROWS_AS(noCounter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1243  5,
1244  "valid_parent_category",
1245  0,
1246  1,
1247  .0f,
1248  "valid "
1249  "name",
1250  "valid description"),
1252  CHECK(counterDirectory.GetCounterCount() == 0);
1253  CHECK(!noCounter);
1254 
1255  // Register a counter with an invalid name
1256  CHECK_THROWS_AS(
1257  noCounter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1258  6,
1259  "valid_parent_category",
1260  0,
1261  1,
1262  123.45f,
1263  "",
1264  "valid description"),
1266  CHECK(counterDirectory.GetCounterCount() == 0);
1267  CHECK(!noCounter);
1268 
1269  // Register a counter with an invalid name
1270  CHECK_THROWS_AS(noCounter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1271  7,
1272  "valid_parent_category",
1273  0,
1274  1,
1275  123.45f,
1276  "invalid nam€",
1277  "valid description"),
1279  CHECK(counterDirectory.GetCounterCount() == 0);
1280  CHECK(!noCounter);
1281 
1282  // Register a counter with an invalid description
1283  CHECK_THROWS_AS(noCounter =
1284  counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1285  8,
1286  "valid_parent_category",
1287  0,
1288  1,
1289  123.45f,
1290  "valid name",
1291  ""),
1293  CHECK(counterDirectory.GetCounterCount() == 0);
1294  CHECK(!noCounter);
1295 
1296  // Register a counter with an invalid description
1297  CHECK_THROWS_AS(noCounter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1298  9,
1299  "valid_parent_category",
1300  0,
1301  1,
1302  123.45f,
1303  "valid "
1304  "name",
1305  "inv@lid description"),
1307  CHECK(counterDirectory.GetCounterCount() == 0);
1308  CHECK(!noCounter);
1309 
1310  // Register a counter with an invalid unit2
1311  CHECK_THROWS_AS(noCounter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1312  10,
1313  "valid_parent_category",
1314  0,
1315  1,
1316  123.45f,
1317  "valid name",
1318  "valid description",
1319  std::string("Mb/s2")),
1321  CHECK(counterDirectory.GetCounterCount() == 0);
1322  CHECK(!noCounter);
1323 
1324  // Register a counter with a non-existing parent category name
1325  CHECK_THROWS_AS(noCounter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1326  11,
1327  "invalid_parent_category",
1328  0,
1329  1,
1330  123.45f,
1331  "valid name",
1332  "valid description"),
1334  CHECK(counterDirectory.GetCounterCount() == 0);
1335  CHECK(!noCounter);
1336 
1337  // Try getting an unregistered counter
1338  const Counter* unregisteredCounter = counterDirectory.GetCounter(9999);
1339  CHECK(!unregisteredCounter);
1340 
1341  // Register a category for testing
1342  const std::string categoryName = "some_category";
1343  const Category* category = nullptr;
1344  CHECK_NOTHROW(category = counterDirectory.RegisterCategory(categoryName));
1345  CHECK(counterDirectory.GetCategoryCount() == 1);
1346  CHECK(category);
1347  CHECK(category->m_Name == categoryName);
1348  CHECK(category->m_Counters.empty());
1349 
1350  // Register a counter with a valid parent category name
1351  const Counter* counter = nullptr;
1352  CHECK_NOTHROW(
1353  counter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1354  12,
1355  categoryName,
1356  0,
1357  1,
1358  123.45f,
1359  "valid name",
1360  "valid description"));
1361  CHECK(counterDirectory.GetCounterCount() == 1);
1362  CHECK(counter);
1363  CHECK(counter->m_MaxCounterUid == counter->m_Uid);
1364  CHECK(counter->m_Class == 0);
1365  CHECK(counter->m_Interpolation == 1);
1366  CHECK(counter->m_Multiplier == 123.45f);
1367  CHECK(counter->m_Name == "valid name");
1368  CHECK(counter->m_Description == "valid description");
1369  CHECK(counter->m_Units == "");
1370  CHECK(counter->m_DeviceUid == 0);
1371  CHECK(counter->m_CounterSetUid == 0);
1372  CHECK(category->m_Counters.size() == 1);
1373  CHECK(category->m_Counters.back() == counter->m_Uid);
1374 
1375  // Register a counter with a name of a counter already registered for the given parent category name
1376  const Counter* counterSameName = nullptr;
1377  CHECK_THROWS_AS(counterSameName =
1378  counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1379  13,
1380  categoryName,
1381  0,
1382  0,
1383  1.0f,
1384  "valid name",
1385  "valid description",
1386  std::string("description")),
1388  CHECK(counterDirectory.GetCounterCount() == 1);
1389  CHECK(!counterSameName);
1390 
1391  // Register a counter with a valid parent category name and units
1392  const Counter* counterWUnits = nullptr;
1393  CHECK_NOTHROW(counterWUnits = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1394  14,
1395  categoryName,
1396  0,
1397  1,
1398  123.45f,
1399  "valid name 2",
1400  "valid description",
1401  std::string("Mnnsq2"))); // Units
1402  CHECK(counterDirectory.GetCounterCount() == 2);
1403  CHECK(counterWUnits);
1404  CHECK(counterWUnits->m_Uid > counter->m_Uid);
1405  CHECK(counterWUnits->m_MaxCounterUid == counterWUnits->m_Uid);
1406  CHECK(counterWUnits->m_Class == 0);
1407  CHECK(counterWUnits->m_Interpolation == 1);
1408  CHECK(counterWUnits->m_Multiplier == 123.45f);
1409  CHECK(counterWUnits->m_Name == "valid name 2");
1410  CHECK(counterWUnits->m_Description == "valid description");
1411  CHECK(counterWUnits->m_Units == "Mnnsq2");
1412  CHECK(counterWUnits->m_DeviceUid == 0);
1413  CHECK(counterWUnits->m_CounterSetUid == 0);
1414  CHECK(category->m_Counters.size() == 2);
1415  CHECK(category->m_Counters.back() == counterWUnits->m_Uid);
1416 
1417  // Register a counter with a valid parent category name and not associated with a device
1418  const Counter* counterWoDevice = nullptr;
1419  CHECK_NOTHROW(counterWoDevice = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1420  26,
1421  categoryName,
1422  0,
1423  1,
1424  123.45f,
1425  "valid name 3",
1426  "valid description",
1427  armnn::EmptyOptional(),// Units
1428  armnn::EmptyOptional(),// Number of cores
1429  0)); // Device UID
1430  CHECK(counterDirectory.GetCounterCount() == 3);
1431  CHECK(counterWoDevice);
1432  CHECK(counterWoDevice->m_Uid > counter->m_Uid);
1433  CHECK(counterWoDevice->m_MaxCounterUid == counterWoDevice->m_Uid);
1434  CHECK(counterWoDevice->m_Class == 0);
1435  CHECK(counterWoDevice->m_Interpolation == 1);
1436  CHECK(counterWoDevice->m_Multiplier == 123.45f);
1437  CHECK(counterWoDevice->m_Name == "valid name 3");
1438  CHECK(counterWoDevice->m_Description == "valid description");
1439  CHECK(counterWoDevice->m_Units == "");
1440  CHECK(counterWoDevice->m_DeviceUid == 0);
1441  CHECK(counterWoDevice->m_CounterSetUid == 0);
1442  CHECK(category->m_Counters.size() == 3);
1443  CHECK(category->m_Counters.back() == counterWoDevice->m_Uid);
1444 
1445  // Register a counter with a valid parent category name and associated to an invalid device
1446  CHECK_THROWS_AS(noCounter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1447  15,
1448  categoryName,
1449  0,
1450  1,
1451  123.45f,
1452  "valid name 4",
1453  "valid description",
1454  armnn::EmptyOptional(), // Units
1455  armnn::EmptyOptional(), // Number of cores
1456  100), // Device UID
1458  CHECK(counterDirectory.GetCounterCount() == 3);
1459  CHECK(!noCounter);
1460 
1461  // Register a device for testing
1462  const std::string deviceName = "some_device";
1463  const Device* device = nullptr;
1464  CHECK_NOTHROW(device = counterDirectory.RegisterDevice(deviceName));
1465  CHECK(counterDirectory.GetDeviceCount() == 1);
1466  CHECK(device);
1467  CHECK(device->m_Name == deviceName);
1468  CHECK(device->m_Uid >= 1);
1469  CHECK(device->m_Cores == 0);
1470 
1471  // Register a counter with a valid parent category name and associated to a device
1472  const Counter* counterWDevice = nullptr;
1473  CHECK_NOTHROW(counterWDevice = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1474  16,
1475  categoryName,
1476  0,
1477  1,
1478  123.45f,
1479  "valid name 5",
1480  std::string("valid description"),
1481  armnn::EmptyOptional(), // Units
1482  armnn::EmptyOptional(), // Number of cores
1483  device->m_Uid)); // Device UID
1484  CHECK(counterDirectory.GetCounterCount() == 4);
1485  CHECK(counterWDevice);
1486  CHECK(counterWDevice->m_Uid > counter->m_Uid);
1487  CHECK(counterWDevice->m_MaxCounterUid == counterWDevice->m_Uid);
1488  CHECK(counterWDevice->m_Class == 0);
1489  CHECK(counterWDevice->m_Interpolation == 1);
1490  CHECK(counterWDevice->m_Multiplier == 123.45f);
1491  CHECK(counterWDevice->m_Name == "valid name 5");
1492  CHECK(counterWDevice->m_Description == "valid description");
1493  CHECK(counterWDevice->m_Units == "");
1494  CHECK(counterWDevice->m_DeviceUid == device->m_Uid);
1495  CHECK(counterWDevice->m_CounterSetUid == 0);
1496  CHECK(category->m_Counters.size() == 4);
1497  CHECK(category->m_Counters.back() == counterWDevice->m_Uid);
1498 
1499  // Register a counter with a valid parent category name and not associated with a counter set
1500  const Counter* counterWoCounterSet = nullptr;
1501  CHECK_NOTHROW(counterWoCounterSet = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1502  17,
1503  categoryName,
1504  0,
1505  1,
1506  123.45f,
1507  "valid name 6",
1508  "valid description",
1509  armnn::EmptyOptional(),// Units
1510  armnn::EmptyOptional(),// No of cores
1511  armnn::EmptyOptional(),// Device UID
1512  0)); // CounterSet UID
1513  CHECK(counterDirectory.GetCounterCount() == 5);
1514  CHECK(counterWoCounterSet);
1515  CHECK(counterWoCounterSet->m_Uid > counter->m_Uid);
1516  CHECK(counterWoCounterSet->m_MaxCounterUid == counterWoCounterSet->m_Uid);
1517  CHECK(counterWoCounterSet->m_Class == 0);
1518  CHECK(counterWoCounterSet->m_Interpolation == 1);
1519  CHECK(counterWoCounterSet->m_Multiplier == 123.45f);
1520  CHECK(counterWoCounterSet->m_Name == "valid name 6");
1521  CHECK(counterWoCounterSet->m_Description == "valid description");
1522  CHECK(counterWoCounterSet->m_Units == "");
1523  CHECK(counterWoCounterSet->m_DeviceUid == 0);
1524  CHECK(counterWoCounterSet->m_CounterSetUid == 0);
1525  CHECK(category->m_Counters.size() == 5);
1526  CHECK(category->m_Counters.back() == counterWoCounterSet->m_Uid);
1527 
1528  // Register a counter with a valid parent category name and associated to an invalid counter set
1529  CHECK_THROWS_AS(noCounter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1530  18,
1531  categoryName,
1532  0,
1533  1,
1534  123.45f,
1535  "valid ",
1536  "name 7",
1537  std::string("valid description"),
1538  armnn::EmptyOptional(), // Units
1539  armnn::EmptyOptional(), // Number of cores
1540  100), // Counter set UID
1542  CHECK(counterDirectory.GetCounterCount() == 5);
1543  CHECK(!noCounter);
1544 
1545  // Register a counter with a valid parent category name and with a given number of cores
1546  const Counter* counterWNumberOfCores = nullptr;
1547  uint16_t numberOfCores = 15;
1548  CHECK_NOTHROW(counterWNumberOfCores = counterDirectory.RegisterCounter(
1549  armnn::profiling::BACKEND_ID, 50,
1550  categoryName, 0, 1, 123.45f, "valid name 8", "valid description",
1551  armnn::EmptyOptional(), // Units
1552  numberOfCores, // Number of cores
1553  armnn::EmptyOptional(), // Device UID
1554  armnn::EmptyOptional())); // Counter set UID
1555  CHECK(counterDirectory.GetCounterCount() == 20);
1556  CHECK(counterWNumberOfCores);
1557  CHECK(counterWNumberOfCores->m_Uid > counter->m_Uid);
1558  CHECK(counterWNumberOfCores->m_MaxCounterUid == counterWNumberOfCores->m_Uid + numberOfCores - 1);
1559  CHECK(counterWNumberOfCores->m_Class == 0);
1560  CHECK(counterWNumberOfCores->m_Interpolation == 1);
1561  CHECK(counterWNumberOfCores->m_Multiplier == 123.45f);
1562  CHECK(counterWNumberOfCores->m_Name == "valid name 8");
1563  CHECK(counterWNumberOfCores->m_Description == "valid description");
1564  CHECK(counterWNumberOfCores->m_Units == "");
1565  CHECK(counterWNumberOfCores->m_DeviceUid == 0);
1566  CHECK(counterWNumberOfCores->m_CounterSetUid == 0);
1567  CHECK(category->m_Counters.size() == 20);
1568  for (size_t i = 0; i < numberOfCores; i++)
1569  {
1570  CHECK(category->m_Counters[category->m_Counters.size() - numberOfCores + i] ==
1571  counterWNumberOfCores->m_Uid + i);
1572  }
1573 
1574  // Register a multi-core device for testing
1575  const std::string multiCoreDeviceName = "some_multi_core_device";
1576  const Device* multiCoreDevice = nullptr;
1577  CHECK_NOTHROW(multiCoreDevice = counterDirectory.RegisterDevice(multiCoreDeviceName, 4));
1578  CHECK(counterDirectory.GetDeviceCount() == 2);
1579  CHECK(multiCoreDevice);
1580  CHECK(multiCoreDevice->m_Name == multiCoreDeviceName);
1581  CHECK(multiCoreDevice->m_Uid >= 1);
1582  CHECK(multiCoreDevice->m_Cores == 4);
1583 
1584  // Register a counter with a valid parent category name and associated to the multi-core device
1585  const Counter* counterWMultiCoreDevice = nullptr;
1586  CHECK_NOTHROW(counterWMultiCoreDevice = counterDirectory.RegisterCounter(
1587  armnn::profiling::BACKEND_ID, 19, categoryName, 0, 1,
1588  123.45f, "valid name 9", "valid description",
1589  armnn::EmptyOptional(), // Units
1590  armnn::EmptyOptional(), // Number of cores
1591  multiCoreDevice->m_Uid, // Device UID
1592  armnn::EmptyOptional())); // Counter set UID
1593  CHECK(counterDirectory.GetCounterCount() == 24);
1594  CHECK(counterWMultiCoreDevice);
1595  CHECK(counterWMultiCoreDevice->m_Uid > counter->m_Uid);
1596  CHECK(counterWMultiCoreDevice->m_MaxCounterUid ==
1597  counterWMultiCoreDevice->m_Uid + multiCoreDevice->m_Cores - 1);
1598  CHECK(counterWMultiCoreDevice->m_Class == 0);
1599  CHECK(counterWMultiCoreDevice->m_Interpolation == 1);
1600  CHECK(counterWMultiCoreDevice->m_Multiplier == 123.45f);
1601  CHECK(counterWMultiCoreDevice->m_Name == "valid name 9");
1602  CHECK(counterWMultiCoreDevice->m_Description == "valid description");
1603  CHECK(counterWMultiCoreDevice->m_Units == "");
1604  CHECK(counterWMultiCoreDevice->m_DeviceUid == multiCoreDevice->m_Uid);
1605  CHECK(counterWMultiCoreDevice->m_CounterSetUid == 0);
1606  CHECK(category->m_Counters.size() == 24);
1607  for (size_t i = 0; i < 4; i++)
1608  {
1609  CHECK(category->m_Counters[category->m_Counters.size() - 4 + i] == counterWMultiCoreDevice->m_Uid + i);
1610  }
1611 
1612  // Register a multi-core device associate to a parent category for testing
1613  const std::string multiCoreDeviceNameWParentCategory = "some_multi_core_device_with_parent_category";
1614  const Device* multiCoreDeviceWParentCategory = nullptr;
1615  CHECK_NOTHROW(multiCoreDeviceWParentCategory =
1616  counterDirectory.RegisterDevice(multiCoreDeviceNameWParentCategory, 2, categoryName));
1617  CHECK(counterDirectory.GetDeviceCount() == 3);
1618  CHECK(multiCoreDeviceWParentCategory);
1619  CHECK(multiCoreDeviceWParentCategory->m_Name == multiCoreDeviceNameWParentCategory);
1620  CHECK(multiCoreDeviceWParentCategory->m_Uid >= 1);
1621  CHECK(multiCoreDeviceWParentCategory->m_Cores == 2);
1622 
1623  // Register a counter with a valid parent category name and getting the number of cores of the multi-core device
1624  // associated to that category
1625  const Counter* counterWMultiCoreDeviceWParentCategory = nullptr;
1626  uint16_t numberOfCourse = multiCoreDeviceWParentCategory->m_Cores;
1627  CHECK_NOTHROW(counterWMultiCoreDeviceWParentCategory =
1628  counterDirectory.RegisterCounter(
1629  armnn::profiling::BACKEND_ID,
1630  100,
1631  categoryName,
1632  0,
1633  1,
1634  123.45f,
1635  "valid name 10",
1636  "valid description",
1637  armnn::EmptyOptional(), // Units
1638  numberOfCourse, // Number of cores
1639  armnn::EmptyOptional(), // Device UID
1640  armnn::EmptyOptional()));// Counter set UID
1641  CHECK(counterDirectory.GetCounterCount() == 26);
1642  CHECK(counterWMultiCoreDeviceWParentCategory);
1643  CHECK(counterWMultiCoreDeviceWParentCategory->m_Uid > counter->m_Uid);
1644  CHECK(counterWMultiCoreDeviceWParentCategory->m_MaxCounterUid ==
1645  counterWMultiCoreDeviceWParentCategory->m_Uid + multiCoreDeviceWParentCategory->m_Cores - 1);
1646  CHECK(counterWMultiCoreDeviceWParentCategory->m_Class == 0);
1647  CHECK(counterWMultiCoreDeviceWParentCategory->m_Interpolation == 1);
1648  CHECK(counterWMultiCoreDeviceWParentCategory->m_Multiplier == 123.45f);
1649  CHECK(counterWMultiCoreDeviceWParentCategory->m_Name == "valid name 10");
1650  CHECK(counterWMultiCoreDeviceWParentCategory->m_Description == "valid description");
1651  CHECK(counterWMultiCoreDeviceWParentCategory->m_Units == "");
1652  CHECK(category->m_Counters.size() == 26);
1653  for (size_t i = 0; i < 2; i++)
1654  {
1655  CHECK(category->m_Counters[category->m_Counters.size() - 2 + i] ==
1656  counterWMultiCoreDeviceWParentCategory->m_Uid + i);
1657  }
1658 
1659  // Register a counter set for testing
1660  const std::string counterSetName = "some_counter_set";
1661  const CounterSet* counterSet = nullptr;
1662  CHECK_NOTHROW(counterSet = counterDirectory.RegisterCounterSet(counterSetName));
1663  CHECK(counterDirectory.GetCounterSetCount() == 1);
1664  CHECK(counterSet);
1665  CHECK(counterSet->m_Name == counterSetName);
1666  CHECK(counterSet->m_Uid >= 1);
1667  CHECK(counterSet->m_Count == 0);
1668 
1669  // Register a counter with a valid parent category name and associated to a counter set
1670  const Counter* counterWCounterSet = nullptr;
1671  CHECK_NOTHROW(counterWCounterSet = counterDirectory.RegisterCounter(
1672  armnn::profiling::BACKEND_ID, 300,
1673  categoryName, 0, 1, 123.45f, "valid name 11", "valid description",
1674  armnn::EmptyOptional(), // Units
1675  0, // Number of cores
1676  armnn::EmptyOptional(), // Device UID
1677  counterSet->m_Uid)); // Counter set UID
1678  CHECK(counterDirectory.GetCounterCount() == 27);
1679  CHECK(counterWCounterSet);
1680  CHECK(counterWCounterSet->m_Uid > counter->m_Uid);
1681  CHECK(counterWCounterSet->m_MaxCounterUid == counterWCounterSet->m_Uid);
1682  CHECK(counterWCounterSet->m_Class == 0);
1683  CHECK(counterWCounterSet->m_Interpolation == 1);
1684  CHECK(counterWCounterSet->m_Multiplier == 123.45f);
1685  CHECK(counterWCounterSet->m_Name == "valid name 11");
1686  CHECK(counterWCounterSet->m_Description == "valid description");
1687  CHECK(counterWCounterSet->m_Units == "");
1688  CHECK(counterWCounterSet->m_DeviceUid == 0);
1689  CHECK(counterWCounterSet->m_CounterSetUid == counterSet->m_Uid);
1690  CHECK(category->m_Counters.size() == 27);
1691  CHECK(category->m_Counters.back() == counterWCounterSet->m_Uid);
1692 
1693  // Register a counter with a valid parent category name and associated to a device and a counter set
1694  const Counter* counterWDeviceWCounterSet = nullptr;
1695  CHECK_NOTHROW(counterWDeviceWCounterSet = counterDirectory.RegisterCounter(
1696  armnn::profiling::BACKEND_ID, 23,
1697  categoryName, 0, 1, 123.45f, "valid name 12", "valid description",
1698  armnn::EmptyOptional(), // Units
1699  1, // Number of cores
1700  device->m_Uid, // Device UID
1701  counterSet->m_Uid)); // Counter set UID
1702  CHECK(counterDirectory.GetCounterCount() == 28);
1703  CHECK(counterWDeviceWCounterSet);
1704  CHECK(counterWDeviceWCounterSet->m_Uid > counter->m_Uid);
1705  CHECK(counterWDeviceWCounterSet->m_MaxCounterUid == counterWDeviceWCounterSet->m_Uid);
1706  CHECK(counterWDeviceWCounterSet->m_Class == 0);
1707  CHECK(counterWDeviceWCounterSet->m_Interpolation == 1);
1708  CHECK(counterWDeviceWCounterSet->m_Multiplier == 123.45f);
1709  CHECK(counterWDeviceWCounterSet->m_Name == "valid name 12");
1710  CHECK(counterWDeviceWCounterSet->m_Description == "valid description");
1711  CHECK(counterWDeviceWCounterSet->m_Units == "");
1712  CHECK(counterWDeviceWCounterSet->m_DeviceUid == device->m_Uid);
1713  CHECK(counterWDeviceWCounterSet->m_CounterSetUid == counterSet->m_Uid);
1714  CHECK(category->m_Counters.size() == 28);
1715  CHECK(category->m_Counters.back() == counterWDeviceWCounterSet->m_Uid);
1716 
1717  // Register another category for testing
1718  const std::string anotherCategoryName = "some_other_category";
1719  const Category* anotherCategory = nullptr;
1720  CHECK_NOTHROW(anotherCategory = counterDirectory.RegisterCategory(anotherCategoryName));
1721  CHECK(counterDirectory.GetCategoryCount() == 2);
1722  CHECK(anotherCategory);
1723  CHECK(anotherCategory != category);
1724  CHECK(anotherCategory->m_Name == anotherCategoryName);
1725  CHECK(anotherCategory->m_Counters.empty());
1726 
1727  // Register a counter to the other category
1728  const Counter* anotherCounter = nullptr;
1729  CHECK_NOTHROW(anotherCounter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID, 24,
1730  anotherCategoryName, 1, 0, .00043f,
1731  "valid name", "valid description",
1732  armnn::EmptyOptional(), // Units
1733  armnn::EmptyOptional(), // Number of cores
1734  device->m_Uid, // Device UID
1735  counterSet->m_Uid)); // Counter set UID
1736  CHECK(counterDirectory.GetCounterCount() == 29);
1737  CHECK(anotherCounter);
1738  CHECK(anotherCounter->m_MaxCounterUid == anotherCounter->m_Uid);
1739  CHECK(anotherCounter->m_Class == 1);
1740  CHECK(anotherCounter->m_Interpolation == 0);
1741  CHECK(anotherCounter->m_Multiplier == .00043f);
1742  CHECK(anotherCounter->m_Name == "valid name");
1743  CHECK(anotherCounter->m_Description == "valid description");
1744  CHECK(anotherCounter->m_Units == "");
1745  CHECK(anotherCounter->m_DeviceUid == device->m_Uid);
1746  CHECK(anotherCounter->m_CounterSetUid == counterSet->m_Uid);
1747  CHECK(anotherCategory->m_Counters.size() == 1);
1748  CHECK(anotherCategory->m_Counters.back() == anotherCounter->m_Uid);
1749 }
1750 
1751 TEST_CASE("CounterSelectionCommandHandlerParseData")
1752 {
1753  ProfilingStateMachine profilingStateMachine;
1754 
1755  class TestCaptureThread : public IPeriodicCounterCapture
1756  {
1757  void Start() override
1758  {}
1759  void Stop() override
1760  {}
1761  };
1762 
1763  class TestReadCounterValues : public IReadCounterValues
1764  {
1765  bool IsCounterRegistered(uint16_t counterUid) const override
1766  {
1767  armnn::IgnoreUnused(counterUid);
1768  return true;
1769  }
1770  uint16_t GetCounterCount() const override
1771  {
1772  return 0;
1773  }
1774  uint32_t GetAbsoluteCounterValue(uint16_t counterUid) const override
1775  {
1776  armnn::IgnoreUnused(counterUid);
1777  return 0;
1778  }
1779  uint32_t GetDeltaCounterValue(uint16_t counterUid) override
1780  {
1781  armnn::IgnoreUnused(counterUid);
1782  return 0;
1783  }
1784  };
1785  const uint32_t familyId = 0;
1786  const uint32_t packetId = 0x40000;
1787 
1788  uint32_t version = 1;
1789  const std::unordered_map<armnn::BackendId,
1790  std::shared_ptr<armnn::profiling::IBackendProfilingContext>> backendProfilingContext;
1791  CounterIdMap counterIdMap;
1792  Holder holder;
1793  TestCaptureThread captureThread;
1794  TestReadCounterValues readCounterValues;
1795  MockBufferManager mockBuffer(512);
1796  SendCounterPacket sendCounterPacket(mockBuffer);
1797  SendThread sendThread(profilingStateMachine, mockBuffer, sendCounterPacket);
1798 
1799  uint32_t sizeOfUint32 = armnn::numeric_cast<uint32_t>(sizeof(uint32_t));
1800  uint32_t sizeOfUint16 = armnn::numeric_cast<uint32_t>(sizeof(uint16_t));
1801 
1802  // Data with period and counters
1803  uint32_t period1 = armnn::LOWEST_CAPTURE_PERIOD;
1804  uint32_t dataLength1 = 8;
1805  uint32_t offset = 0;
1806 
1807  std::unique_ptr<unsigned char[]> uniqueData1 = std::make_unique<unsigned char[]>(dataLength1);
1808  unsigned char* data1 = reinterpret_cast<unsigned char*>(uniqueData1.get());
1809 
1810  WriteUint32(data1, offset, period1);
1811  offset += sizeOfUint32;
1812  WriteUint16(data1, offset, 4000);
1813  offset += sizeOfUint16;
1814  WriteUint16(data1, offset, 5000);
1815 
1816  arm::pipe::Packet packetA(packetId, dataLength1, uniqueData1);
1817 
1818  PeriodicCounterSelectionCommandHandler commandHandler(familyId, packetId, version, backendProfilingContext,
1819  counterIdMap, holder, 10000u, captureThread,
1820  readCounterValues, sendCounterPacket, profilingStateMachine);
1821 
1822  profilingStateMachine.TransitionToState(ProfilingState::Uninitialised);
1823  CHECK_THROWS_AS(commandHandler(packetA), armnn::RuntimeException);
1824  profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
1825  CHECK_THROWS_AS(commandHandler(packetA), armnn::RuntimeException);
1826  profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
1827  CHECK_THROWS_AS(commandHandler(packetA), armnn::RuntimeException);
1828  profilingStateMachine.TransitionToState(ProfilingState::Active);
1829  CHECK_NOTHROW(commandHandler(packetA));
1830 
1831  const std::vector<uint16_t> counterIdsA = holder.GetCaptureData().GetCounterIds();
1832 
1833  CHECK(holder.GetCaptureData().GetCapturePeriod() == period1);
1834  CHECK(counterIdsA.size() == 2);
1835  CHECK(counterIdsA[0] == 4000);
1836  CHECK(counterIdsA[1] == 5000);
1837 
1838  auto readBuffer = mockBuffer.GetReadableBuffer();
1839 
1840  offset = 0;
1841 
1842  uint32_t headerWord0 = ReadUint32(readBuffer, offset);
1843  offset += sizeOfUint32;
1844  uint32_t headerWord1 = ReadUint32(readBuffer, offset);
1845  offset += sizeOfUint32;
1846  uint32_t period = ReadUint32(readBuffer, offset);
1847 
1848  CHECK(((headerWord0 >> 26) & 0x3F) == 0); // packet family
1849  CHECK(((headerWord0 >> 16) & 0x3FF) == 4); // packet id
1850  CHECK(headerWord1 == 8); // data length
1851  CHECK(period == armnn::LOWEST_CAPTURE_PERIOD); // capture period
1852 
1853  uint16_t counterId = 0;
1854  offset += sizeOfUint32;
1855  counterId = ReadUint16(readBuffer, offset);
1856  CHECK(counterId == 4000);
1857  offset += sizeOfUint16;
1858  counterId = ReadUint16(readBuffer, offset);
1859  CHECK(counterId == 5000);
1860 
1861  mockBuffer.MarkRead(readBuffer);
1862 
1863  // Data with period only
1864  uint32_t period2 = 9000; // We'll specify a value below LOWEST_CAPTURE_PERIOD. It should be pulled upwards.
1865  uint32_t dataLength2 = 4;
1866 
1867  std::unique_ptr<unsigned char[]> uniqueData2 = std::make_unique<unsigned char[]>(dataLength2);
1868 
1869  WriteUint32(reinterpret_cast<unsigned char*>(uniqueData2.get()), 0, period2);
1870 
1871  arm::pipe::Packet packetB(packetId, dataLength2, uniqueData2);
1872 
1873  commandHandler(packetB);
1874 
1875  const std::vector<uint16_t> counterIdsB = holder.GetCaptureData().GetCounterIds();
1876 
1877  // Value should have been pulled up from 9000 to LOWEST_CAPTURE_PERIOD.
1879  CHECK(counterIdsB.size() == 0);
1880 
1881  readBuffer = mockBuffer.GetReadableBuffer();
1882 
1883  offset = 0;
1884 
1885  headerWord0 = ReadUint32(readBuffer, offset);
1886  offset += sizeOfUint32;
1887  headerWord1 = ReadUint32(readBuffer, offset);
1888  offset += sizeOfUint32;
1889  period = ReadUint32(readBuffer, offset);
1890 
1891  CHECK(((headerWord0 >> 26) & 0x3F) == 0); // packet family
1892  CHECK(((headerWord0 >> 16) & 0x3FF) == 4); // packet id
1893  CHECK(headerWord1 == 4); // data length
1894  CHECK(period == armnn::LOWEST_CAPTURE_PERIOD); // capture period
1895 }
1896 
1897 TEST_CASE("CheckTimelineActivationAndDeactivation")
1898 {
1899  class TestReportStructure : public IReportStructure
1900  {
1901  public:
1902  virtual void ReportStructure() override
1903  {
1904  m_ReportStructureCalled = true;
1905  }
1906 
1907  bool m_ReportStructureCalled = false;
1908  };
1909 
1910  class TestNotifyBackends : public INotifyBackends
1911  {
1912  public:
1913  TestNotifyBackends() : m_timelineReporting(false) {}
1914  virtual void NotifyBackendsForTimelineReporting() override
1915  {
1916  m_TestNotifyBackendsCalled = m_timelineReporting.load();
1917  }
1918 
1919  bool m_TestNotifyBackendsCalled = false;
1920  std::atomic<bool> m_timelineReporting;
1921  };
1922 
1923  arm::pipe::PacketVersionResolver packetVersionResolver;
1924 
1925  BufferManager bufferManager(512);
1926  SendTimelinePacket sendTimelinePacket(bufferManager);
1927  ProfilingStateMachine stateMachine;
1928  TestReportStructure testReportStructure;
1929  TestNotifyBackends testNotifyBackends;
1930 
1931  profiling::ActivateTimelineReportingCommandHandler activateTimelineReportingCommandHandler(0,
1932  6,
1933  packetVersionResolver.ResolvePacketVersion(0, 6)
1934  .GetEncodedValue(),
1935  sendTimelinePacket,
1936  stateMachine,
1937  testReportStructure,
1938  testNotifyBackends.m_timelineReporting,
1939  testNotifyBackends);
1940 
1941  // Write an "ActivateTimelineReporting" packet into the mock profiling connection, to simulate an input from an
1942  // external profiling service
1943  const uint32_t packetFamily1 = 0;
1944  const uint32_t packetId1 = 6;
1945  uint32_t packetHeader1 = ConstructHeader(packetFamily1, packetId1);
1946 
1947  // Create the ActivateTimelineReportingPacket
1948  arm::pipe::Packet ActivateTimelineReportingPacket(packetHeader1); // Length == 0
1949 
1950  CHECK_THROWS_AS(
1951  activateTimelineReportingCommandHandler.operator()(ActivateTimelineReportingPacket), armnn::Exception);
1952 
1953  stateMachine.TransitionToState(ProfilingState::NotConnected);
1954  CHECK_THROWS_AS(
1955  activateTimelineReportingCommandHandler.operator()(ActivateTimelineReportingPacket), armnn::Exception);
1956 
1957  stateMachine.TransitionToState(ProfilingState::WaitingForAck);
1958  CHECK_THROWS_AS(
1959  activateTimelineReportingCommandHandler.operator()(ActivateTimelineReportingPacket), armnn::Exception);
1960 
1961  stateMachine.TransitionToState(ProfilingState::Active);
1962  activateTimelineReportingCommandHandler.operator()(ActivateTimelineReportingPacket);
1963 
1964  CHECK(testReportStructure.m_ReportStructureCalled);
1965  CHECK(testNotifyBackends.m_TestNotifyBackendsCalled);
1966  CHECK(testNotifyBackends.m_timelineReporting.load());
1967 
1968  DeactivateTimelineReportingCommandHandler deactivateTimelineReportingCommandHandler(0,
1969  7,
1970  packetVersionResolver.ResolvePacketVersion(0, 7).GetEncodedValue(),
1971  testNotifyBackends.m_timelineReporting,
1972  stateMachine,
1973  testNotifyBackends);
1974 
1975  const uint32_t packetFamily2 = 0;
1976  const uint32_t packetId2 = 7;
1977  uint32_t packetHeader2 = ConstructHeader(packetFamily2, packetId2);
1978 
1979  // Create the DeactivateTimelineReportingPacket
1980  arm::pipe::Packet deactivateTimelineReportingPacket(packetHeader2); // Length == 0
1981 
1982  stateMachine.Reset();
1983  CHECK_THROWS_AS(
1984  deactivateTimelineReportingCommandHandler.operator()(deactivateTimelineReportingPacket), armnn::Exception);
1985 
1986  stateMachine.TransitionToState(ProfilingState::NotConnected);
1987  CHECK_THROWS_AS(
1988  deactivateTimelineReportingCommandHandler.operator()(deactivateTimelineReportingPacket), armnn::Exception);
1989 
1990  stateMachine.TransitionToState(ProfilingState::WaitingForAck);
1991  CHECK_THROWS_AS(
1992  deactivateTimelineReportingCommandHandler.operator()(deactivateTimelineReportingPacket), armnn::Exception);
1993 
1994  stateMachine.TransitionToState(ProfilingState::Active);
1995  deactivateTimelineReportingCommandHandler.operator()(deactivateTimelineReportingPacket);
1996 
1997  CHECK(!testNotifyBackends.m_TestNotifyBackendsCalled);
1998  CHECK(!testNotifyBackends.m_timelineReporting.load());
1999 }
2000 
2001 TEST_CASE("CheckProfilingServiceNotActive")
2002 {
2003  using namespace armnn;
2004  using namespace armnn::profiling;
2005 
2006  // Create runtime in which the test will run
2008  options.m_ProfilingOptions.m_EnableProfiling = true;
2009 
2010  armnn::RuntimeImpl runtime(options);
2011  profiling::ProfilingServiceRuntimeHelper profilingServiceHelper(GetProfilingService(&runtime));
2012  profilingServiceHelper.ForceTransitionToState(ProfilingState::NotConnected);
2013  profilingServiceHelper.ForceTransitionToState(ProfilingState::WaitingForAck);
2014  profilingServiceHelper.ForceTransitionToState(ProfilingState::Active);
2015 
2016  profiling::BufferManager& bufferManager = profilingServiceHelper.GetProfilingBufferManager();
2017  auto readableBuffer = bufferManager.GetReadableBuffer();
2018 
2019  // Profiling is enabled, the post-optimisation structure should be created
2020  CHECK(readableBuffer == nullptr);
2021 }
2022 
2023 TEST_CASE("CheckConnectionAcknowledged")
2024 {
2025  const uint32_t packetFamilyId = 0;
2026  const uint32_t connectionPacketId = 0x10000;
2027  const uint32_t version = 1;
2028 
2029  uint32_t sizeOfUint32 = armnn::numeric_cast<uint32_t>(sizeof(uint32_t));
2030  uint32_t sizeOfUint16 = armnn::numeric_cast<uint32_t>(sizeof(uint16_t));
2031 
2032  // Data with period and counters
2033  uint32_t period1 = 10;
2034  uint32_t dataLength1 = 8;
2035  uint32_t offset = 0;
2036 
2037  std::unique_ptr<unsigned char[]> uniqueData1 = std::make_unique<unsigned char[]>(dataLength1);
2038  unsigned char* data1 = reinterpret_cast<unsigned char*>(uniqueData1.get());
2039 
2040  WriteUint32(data1, offset, period1);
2041  offset += sizeOfUint32;
2042  WriteUint16(data1, offset, 4000);
2043  offset += sizeOfUint16;
2044  WriteUint16(data1, offset, 5000);
2045 
2046  arm::pipe::Packet packetA(connectionPacketId, dataLength1, uniqueData1);
2047 
2048  ProfilingStateMachine profilingState(ProfilingState::Uninitialised);
2049  CHECK(profilingState.GetCurrentState() == ProfilingState::Uninitialised);
2050  CounterDirectory counterDirectory;
2051  MockBufferManager mockBuffer(1024);
2052  SendCounterPacket sendCounterPacket(mockBuffer);
2053  SendThread sendThread(profilingState, mockBuffer, sendCounterPacket);
2054  SendTimelinePacket sendTimelinePacket(mockBuffer);
2055  MockProfilingServiceStatus mockProfilingServiceStatus;
2056 
2057  ConnectionAcknowledgedCommandHandler commandHandler(packetFamilyId,
2058  connectionPacketId,
2059  version,
2060  counterDirectory,
2061  sendCounterPacket,
2062  sendTimelinePacket,
2063  profilingState,
2064  mockProfilingServiceStatus);
2065 
2066  // command handler received packet on ProfilingState::Uninitialised
2067  CHECK_THROWS_AS(commandHandler(packetA), armnn::Exception);
2068 
2069  profilingState.TransitionToState(ProfilingState::NotConnected);
2070  CHECK(profilingState.GetCurrentState() == ProfilingState::NotConnected);
2071  // command handler received packet on ProfilingState::NotConnected
2072  CHECK_THROWS_AS(commandHandler(packetA), armnn::Exception);
2073 
2074  profilingState.TransitionToState(ProfilingState::WaitingForAck);
2075  CHECK(profilingState.GetCurrentState() == ProfilingState::WaitingForAck);
2076  // command handler received packet on ProfilingState::WaitingForAck
2077  CHECK_NOTHROW(commandHandler(packetA));
2078  CHECK(profilingState.GetCurrentState() == ProfilingState::Active);
2079 
2080  // command handler received packet on ProfilingState::Active
2081  CHECK_NOTHROW(commandHandler(packetA));
2082  CHECK(profilingState.GetCurrentState() == ProfilingState::Active);
2083 
2084  // command handler received different packet
2085  const uint32_t differentPacketId = 0x40000;
2086  arm::pipe::Packet packetB(differentPacketId, dataLength1, uniqueData1);
2087  profilingState.TransitionToState(ProfilingState::NotConnected);
2088  profilingState.TransitionToState(ProfilingState::WaitingForAck);
2089  ConnectionAcknowledgedCommandHandler differentCommandHandler(packetFamilyId,
2090  differentPacketId,
2091  version,
2092  counterDirectory,
2093  sendCounterPacket,
2094  sendTimelinePacket,
2095  profilingState,
2096  mockProfilingServiceStatus);
2097  CHECK_THROWS_AS(differentCommandHandler(packetB), armnn::Exception);
2098 }
2099 
2100 TEST_CASE("CheckSocketConnectionException")
2101 {
2102  // Check that creating a SocketProfilingConnection armnnProfiling in an exception as the Gator UDS doesn't exist.
2103  CHECK_THROWS_AS(new SocketProfilingConnection(), arm::pipe::SocketConnectionException);
2104 }
2105 
2106 TEST_CASE("CheckSocketConnectionException2")
2107 {
2108  try
2109  {
2111  }
2112  catch (const arm::pipe::SocketConnectionException& ex)
2113  {
2114  CHECK(ex.GetSocketFd() == 0);
2115  CHECK(ex.GetErrorNo() == ECONNREFUSED);
2116  CHECK(ex.what()
2117  == std::string("SocketProfilingConnection: Cannot connect to stream socket: Connection refused"));
2118  }
2119 }
2120 
2121 TEST_CASE("SwTraceIsValidCharTest")
2122 {
2123  // Only ASCII 7-bit encoding supported
2124  for (unsigned char c = 0; c < 128; c++)
2125  {
2126  CHECK(arm::pipe::SwTraceCharPolicy::IsValidChar(c));
2127  }
2128 
2129  // Not ASCII
2130  for (unsigned char c = 255; c >= 128; c++)
2131  {
2132  CHECK(!arm::pipe::SwTraceCharPolicy::IsValidChar(c));
2133  }
2134 }
2135 
2136 TEST_CASE("SwTraceIsValidNameCharTest")
2137 {
2138  // Only alpha-numeric and underscore ASCII 7-bit encoding supported
2139  const unsigned char validChars[] = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_";
2140  for (unsigned char i = 0; i < sizeof(validChars) / sizeof(validChars[0]) - 1; i++)
2141  {
2142  CHECK(arm::pipe::SwTraceNameCharPolicy::IsValidChar(validChars[i]));
2143  }
2144 
2145  // Non alpha-numeric chars
2146  for (unsigned char c = 0; c < 48; c++)
2147  {
2148  CHECK(!arm::pipe::SwTraceNameCharPolicy::IsValidChar(c));
2149  }
2150  for (unsigned char c = 58; c < 65; c++)
2151  {
2152  CHECK(!arm::pipe::SwTraceNameCharPolicy::IsValidChar(c));
2153  }
2154  for (unsigned char c = 91; c < 95; c++)
2155  {
2156  CHECK(!arm::pipe::SwTraceNameCharPolicy::IsValidChar(c));
2157  }
2158  for (unsigned char c = 96; c < 97; c++)
2159  {
2160  CHECK(!arm::pipe::SwTraceNameCharPolicy::IsValidChar(c));
2161  }
2162  for (unsigned char c = 123; c < 128; c++)
2163  {
2164  CHECK(!arm::pipe::SwTraceNameCharPolicy::IsValidChar(c));
2165  }
2166 
2167  // Not ASCII
2168  for (unsigned char c = 255; c >= 128; c++)
2169  {
2170  CHECK(!arm::pipe::SwTraceNameCharPolicy::IsValidChar(c));
2171  }
2172 }
2173 
2174 TEST_CASE("IsValidSwTraceStringTest")
2175 {
2176  // Valid SWTrace strings
2177  CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceCharPolicy>(""));
2178  CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceCharPolicy>("_"));
2179  CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceCharPolicy>("0123"));
2180  CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceCharPolicy>("valid_string"));
2181  CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceCharPolicy>("VALID_string_456"));
2182  CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceCharPolicy>(" "));
2183  CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceCharPolicy>("valid string"));
2184  CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceCharPolicy>("!$%"));
2185  CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceCharPolicy>("valid|\\~string#123"));
2186 
2187  // Invalid SWTrace strings
2188  CHECK(!arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceCharPolicy>("€£"));
2189  CHECK(!arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceCharPolicy>("invalid‡string"));
2190  CHECK(!arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceCharPolicy>("12Ž34"));
2191 }
2192 
2193 TEST_CASE("IsValidSwTraceNameStringTest")
2194 {
2195  // Valid SWTrace name strings
2196  CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceNameCharPolicy>(""));
2197  CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceNameCharPolicy>("_"));
2198  CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceNameCharPolicy>("0123"));
2199  CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceNameCharPolicy>("valid_string"));
2200  CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceNameCharPolicy>("VALID_string_456"));
2201 
2202  // Invalid SWTrace name strings
2203  CHECK(!arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceNameCharPolicy>(" "));
2204  CHECK(!arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceNameCharPolicy>("invalid string"));
2205  CHECK(!arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceNameCharPolicy>("!$%"));
2206  CHECK(!arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceNameCharPolicy>("invalid|\\~string#123"));
2207  CHECK(!arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceNameCharPolicy>("€£"));
2208  CHECK(!arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceNameCharPolicy>("invalid‡string"));
2209  CHECK(!arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceNameCharPolicy>("12Ž34"));
2210 }
2211 
2212 template <typename SwTracePolicy>
2213 void StringToSwTraceStringTestHelper(const std::string& testString, std::vector<uint32_t> buffer, size_t expectedSize)
2214 {
2215  // Convert the test string to a SWTrace string
2216  CHECK(arm::pipe::StringToSwTraceString<SwTracePolicy>(testString, buffer));
2217 
2218  // The buffer must contain at least the length of the string
2219  CHECK(!buffer.empty());
2220 
2221  // The buffer must be of the expected size (in words)
2222  CHECK(buffer.size() == expectedSize);
2223 
2224  // The first word of the byte must be the length of the string including the null-terminator
2225  CHECK(buffer[0] == testString.size() + 1);
2226 
2227  // The contents of the buffer must match the test string
2228  CHECK(std::memcmp(testString.data(), buffer.data() + 1, testString.size()) == 0);
2229 
2230  // The buffer must include the null-terminator at the end of the string
2231  size_t nullTerminatorIndex = sizeof(uint32_t) + testString.size();
2232  CHECK(reinterpret_cast<unsigned char*>(buffer.data())[nullTerminatorIndex] == '\0');
2233 }
2234 
2235 TEST_CASE("StringToSwTraceStringTest")
2236 {
2237  std::vector<uint32_t> buffer;
2238 
2239  // Valid SWTrace strings (expected size in words)
2240  StringToSwTraceStringTestHelper<arm::pipe::SwTraceCharPolicy>("", buffer, 2);
2241  StringToSwTraceStringTestHelper<arm::pipe::SwTraceCharPolicy>("_", buffer, 2);
2242  StringToSwTraceStringTestHelper<arm::pipe::SwTraceCharPolicy>("0123", buffer, 3);
2243  StringToSwTraceStringTestHelper<arm::pipe::SwTraceCharPolicy>("valid_string", buffer, 5);
2244  StringToSwTraceStringTestHelper<arm::pipe::SwTraceCharPolicy>("VALID_string_456", buffer, 6);
2245  StringToSwTraceStringTestHelper<arm::pipe::SwTraceCharPolicy>(" ", buffer, 2);
2246  StringToSwTraceStringTestHelper<arm::pipe::SwTraceCharPolicy>("valid string", buffer, 5);
2247  StringToSwTraceStringTestHelper<arm::pipe::SwTraceCharPolicy>("!$%", buffer, 2);
2248  StringToSwTraceStringTestHelper<arm::pipe::SwTraceCharPolicy>("valid|\\~string#123", buffer, 6);
2249 
2250  // Invalid SWTrace strings
2251  CHECK(!arm::pipe::StringToSwTraceString<arm::pipe::SwTraceCharPolicy>("€£", buffer));
2252  CHECK(buffer.empty());
2253  CHECK(!arm::pipe::StringToSwTraceString<arm::pipe::SwTraceCharPolicy>("invalid‡string", buffer));
2254  CHECK(buffer.empty());
2255  CHECK(!arm::pipe::StringToSwTraceString<arm::pipe::SwTraceCharPolicy>("12Ž34", buffer));
2256  CHECK(buffer.empty());
2257 }
2258 
2259 TEST_CASE("StringToSwTraceNameStringTest")
2260 {
2261  std::vector<uint32_t> buffer;
2262 
2263  // Valid SWTrace namestrings (expected size in words)
2264  StringToSwTraceStringTestHelper<arm::pipe::SwTraceNameCharPolicy>("", buffer, 2);
2265  StringToSwTraceStringTestHelper<arm::pipe::SwTraceNameCharPolicy>("_", buffer, 2);
2266  StringToSwTraceStringTestHelper<arm::pipe::SwTraceNameCharPolicy>("0123", buffer, 3);
2267  StringToSwTraceStringTestHelper<arm::pipe::SwTraceNameCharPolicy>("valid_string", buffer, 5);
2268  StringToSwTraceStringTestHelper<arm::pipe::SwTraceNameCharPolicy>("VALID_string_456", buffer, 6);
2269 
2270  // Invalid SWTrace namestrings
2271  CHECK(!arm::pipe::StringToSwTraceString<arm::pipe::SwTraceNameCharPolicy>(" ", buffer));
2272  CHECK(buffer.empty());
2273  CHECK(!arm::pipe::StringToSwTraceString<arm::pipe::SwTraceNameCharPolicy>("invalid string", buffer));
2274  CHECK(buffer.empty());
2275  CHECK(!arm::pipe::StringToSwTraceString<arm::pipe::SwTraceNameCharPolicy>("!$%", buffer));
2276  CHECK(buffer.empty());
2277  CHECK(!arm::pipe::StringToSwTraceString<arm::pipe::SwTraceNameCharPolicy>("invalid|\\~string#123", buffer));
2278  CHECK(buffer.empty());
2279  CHECK(!arm::pipe::StringToSwTraceString<arm::pipe::SwTraceNameCharPolicy>("€£", buffer));
2280  CHECK(buffer.empty());
2281  CHECK(!arm::pipe::StringToSwTraceString<arm::pipe::SwTraceNameCharPolicy>("invalid‡string", buffer));
2282  CHECK(buffer.empty());
2283  CHECK(!arm::pipe::StringToSwTraceString<arm::pipe::SwTraceNameCharPolicy>("12Ž34", buffer));
2284  CHECK(buffer.empty());
2285 }
2286 
2287 TEST_CASE("CheckPeriodicCounterCaptureThread")
2288 {
2289  class CaptureReader : public IReadCounterValues
2290  {
2291  public:
2292  CaptureReader(uint16_t counterSize)
2293  {
2294  for (uint16_t i = 0; i < counterSize; ++i)
2295  {
2296  m_Data[i] = 0;
2297  }
2298  m_CounterSize = counterSize;
2299  }
2300  //not used
2301  bool IsCounterRegistered(uint16_t counterUid) const override
2302  {
2303  armnn::IgnoreUnused(counterUid);
2304  return false;
2305  }
2306 
2307  uint16_t GetCounterCount() const override
2308  {
2309  return m_CounterSize;
2310  }
2311 
2312  uint32_t GetAbsoluteCounterValue(uint16_t counterUid) const override
2313  {
2314  if (counterUid > m_CounterSize)
2315  {
2316  FAIL("Invalid counter Uid");
2317  }
2318  return m_Data.at(counterUid).load();
2319  }
2320 
2321  uint32_t GetDeltaCounterValue(uint16_t counterUid) override
2322  {
2323  if (counterUid > m_CounterSize)
2324  {
2325  FAIL("Invalid counter Uid");
2326  }
2327  return m_Data.at(counterUid).load();
2328  }
2329 
2330  void SetCounterValue(uint16_t counterUid, uint32_t value)
2331  {
2332  if (counterUid > m_CounterSize)
2333  {
2334  FAIL("Invalid counter Uid");
2335  }
2336  m_Data.at(counterUid).store(value);
2337  }
2338 
2339  private:
2340  std::unordered_map<uint16_t, std::atomic<uint32_t>> m_Data;
2341  uint16_t m_CounterSize;
2342  };
2343 
2344  ProfilingStateMachine profilingStateMachine;
2345 
2346  const std::unordered_map<armnn::BackendId,
2347  std::shared_ptr<armnn::profiling::IBackendProfilingContext>> backendProfilingContext;
2348  CounterIdMap counterIdMap;
2349  Holder data;
2350  std::vector<uint16_t> captureIds1 = { 0, 1 };
2351  std::vector<uint16_t> captureIds2;
2352 
2353  MockBufferManager mockBuffer(512);
2354  SendCounterPacket sendCounterPacket(mockBuffer);
2355  SendThread sendThread(profilingStateMachine, mockBuffer, sendCounterPacket);
2356 
2357  std::vector<uint16_t> counterIds;
2358  CaptureReader captureReader(2);
2359 
2360  unsigned int valueA = 10;
2361  unsigned int valueB = 15;
2362  unsigned int numSteps = 5;
2363 
2364  PeriodicCounterCapture periodicCounterCapture(std::ref(data), std::ref(sendCounterPacket), captureReader,
2365  counterIdMap, backendProfilingContext);
2366 
2367  for (unsigned int i = 0; i < numSteps; ++i)
2368  {
2369  data.SetCaptureData(1, captureIds1, {});
2370  captureReader.SetCounterValue(0, valueA * (i + 1));
2371  captureReader.SetCounterValue(1, valueB * (i + 1));
2372 
2373  periodicCounterCapture.Start();
2374  periodicCounterCapture.Stop();
2375  }
2376 
2377  auto buffer = mockBuffer.GetReadableBuffer();
2378 
2379  uint32_t headerWord0 = ReadUint32(buffer, 0);
2380  uint32_t headerWord1 = ReadUint32(buffer, 4);
2381 
2382  CHECK(((headerWord0 >> 26) & 0x0000003F) == 3); // packet family
2383  CHECK(((headerWord0 >> 19) & 0x0000007F) == 0); // packet class
2384  CHECK(((headerWord0 >> 16) & 0x00000007) == 0); // packet type
2385  CHECK(headerWord1 == 20);
2386 
2387  uint32_t offset = 16;
2388  uint16_t readIndex = ReadUint16(buffer, offset);
2389  CHECK(0 == readIndex);
2390 
2391  offset += 2;
2392  uint32_t readValue = ReadUint32(buffer, offset);
2393  CHECK((valueA * numSteps) == readValue);
2394 
2395  offset += 4;
2396  readIndex = ReadUint16(buffer, offset);
2397  CHECK(1 == readIndex);
2398 
2399  offset += 2;
2400  readValue = ReadUint32(buffer, offset);
2401  CHECK((valueB * numSteps) == readValue);
2402 }
2403 
2404 TEST_CASE("RequestCounterDirectoryCommandHandlerTest1")
2405 {
2406  const uint32_t familyId = 0;
2407  const uint32_t packetId = 3;
2408  const uint32_t version = 1;
2409  ProfilingStateMachine profilingStateMachine;
2410  CounterDirectory counterDirectory;
2411  MockBufferManager mockBuffer1(1024);
2412  SendCounterPacket sendCounterPacket(mockBuffer1);
2413  SendThread sendThread(profilingStateMachine, mockBuffer1, sendCounterPacket);
2414  MockBufferManager mockBuffer2(1024);
2415  SendTimelinePacket sendTimelinePacket(mockBuffer2);
2416  RequestCounterDirectoryCommandHandler commandHandler(familyId, packetId, version, counterDirectory,
2417  sendCounterPacket, sendTimelinePacket, profilingStateMachine);
2418 
2419  const uint32_t wrongPacketId = 47;
2420  const uint32_t wrongHeader = (wrongPacketId & 0x000003FF) << 16;
2421 
2422  arm::pipe::Packet wrongPacket(wrongHeader);
2423 
2424  profilingStateMachine.TransitionToState(ProfilingState::Uninitialised);
2425  CHECK_THROWS_AS(commandHandler(wrongPacket), armnn::RuntimeException); // Wrong profiling state
2426  profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
2427  CHECK_THROWS_AS(commandHandler(wrongPacket), armnn::RuntimeException); // Wrong profiling state
2428  profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
2429  CHECK_THROWS_AS(commandHandler(wrongPacket), armnn::RuntimeException); // Wrong profiling state
2430  profilingStateMachine.TransitionToState(ProfilingState::Active);
2431  CHECK_THROWS_AS(commandHandler(wrongPacket), armnn::InvalidArgumentException); // Wrong packet
2432 
2433  const uint32_t rightHeader = (packetId & 0x000003FF) << 16;
2434 
2435  arm::pipe::Packet rightPacket(rightHeader);
2436 
2437  CHECK_NOTHROW(commandHandler(rightPacket)); // Right packet
2438 
2439  auto readBuffer1 = mockBuffer1.GetReadableBuffer();
2440 
2441  uint32_t header1Word0 = ReadUint32(readBuffer1, 0);
2442  uint32_t header1Word1 = ReadUint32(readBuffer1, 4);
2443 
2444  // Counter directory packet
2445  CHECK(((header1Word0 >> 26) & 0x0000003F) == 0); // packet family
2446  CHECK(((header1Word0 >> 16) & 0x000003FF) == 2); // packet id
2447  CHECK(header1Word1 == 24); // data length
2448 
2449  uint32_t bodyHeader1Word0 = ReadUint32(readBuffer1, 8);
2450  uint16_t deviceRecordCount = armnn::numeric_cast<uint16_t>(bodyHeader1Word0 >> 16);
2451  CHECK(deviceRecordCount == 0); // device_records_count
2452 
2453  auto readBuffer2 = mockBuffer2.GetReadableBuffer();
2454 
2455  uint32_t header2Word0 = ReadUint32(readBuffer2, 0);
2456  uint32_t header2Word1 = ReadUint32(readBuffer2, 4);
2457 
2458  // Timeline message directory packet
2459  CHECK(((header2Word0 >> 26) & 0x0000003F) == 1); // packet family
2460  CHECK(((header2Word0 >> 16) & 0x000003FF) == 0); // packet id
2461  CHECK(header2Word1 == 443); // data length
2462 }
2463 
2464 TEST_CASE("RequestCounterDirectoryCommandHandlerTest2")
2465 {
2466  const uint32_t familyId = 0;
2467  const uint32_t packetId = 3;
2468  const uint32_t version = 1;
2469  ProfilingStateMachine profilingStateMachine;
2470  CounterDirectory counterDirectory;
2471  MockBufferManager mockBuffer1(1024);
2472  SendCounterPacket sendCounterPacket(mockBuffer1);
2473  SendThread sendThread(profilingStateMachine, mockBuffer1, sendCounterPacket);
2474  MockBufferManager mockBuffer2(1024);
2475  SendTimelinePacket sendTimelinePacket(mockBuffer2);
2476  RequestCounterDirectoryCommandHandler commandHandler(familyId, packetId, version, counterDirectory,
2477  sendCounterPacket, sendTimelinePacket, profilingStateMachine);
2478  const uint32_t header = (packetId & 0x000003FF) << 16;
2479  const arm::pipe::Packet packet(header);
2480 
2481  const Device* device = counterDirectory.RegisterDevice("deviceA", 1);
2482  CHECK(device != nullptr);
2483  const CounterSet* counterSet = counterDirectory.RegisterCounterSet("countersetA");
2484  CHECK(counterSet != nullptr);
2485  counterDirectory.RegisterCategory("categoryA");
2486  counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID, 24,
2487  "categoryA", 0, 1, 2.0f, "counterA", "descA");
2488  counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID, 25,
2489  "categoryA", 1, 1, 3.0f, "counterB", "descB");
2490 
2491  profilingStateMachine.TransitionToState(ProfilingState::Uninitialised);
2492  CHECK_THROWS_AS(commandHandler(packet), armnn::RuntimeException); // Wrong profiling state
2493  profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
2494  CHECK_THROWS_AS(commandHandler(packet), armnn::RuntimeException); // Wrong profiling state
2495  profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
2496  CHECK_THROWS_AS(commandHandler(packet), armnn::RuntimeException); // Wrong profiling state
2497  profilingStateMachine.TransitionToState(ProfilingState::Active);
2498  CHECK_NOTHROW(commandHandler(packet));
2499 
2500  auto readBuffer1 = mockBuffer1.GetReadableBuffer();
2501 
2502  const uint32_t header1Word0 = ReadUint32(readBuffer1, 0);
2503  const uint32_t header1Word1 = ReadUint32(readBuffer1, 4);
2504 
2505  CHECK(((header1Word0 >> 26) & 0x0000003F) == 0); // packet family
2506  CHECK(((header1Word0 >> 16) & 0x000003FF) == 2); // packet id
2507  CHECK(header1Word1 == 236); // data length
2508 
2509  const uint32_t bodyHeaderSizeBytes = bodyHeaderSize * sizeof(uint32_t);
2510 
2511  const uint32_t bodyHeader1Word0 = ReadUint32(readBuffer1, 8);
2512  const uint32_t bodyHeader1Word1 = ReadUint32(readBuffer1, 12);
2513  const uint32_t bodyHeader1Word2 = ReadUint32(readBuffer1, 16);
2514  const uint32_t bodyHeader1Word3 = ReadUint32(readBuffer1, 20);
2515  const uint32_t bodyHeader1Word4 = ReadUint32(readBuffer1, 24);
2516  const uint32_t bodyHeader1Word5 = ReadUint32(readBuffer1, 28);
2517  const uint16_t deviceRecordCount = armnn::numeric_cast<uint16_t>(bodyHeader1Word0 >> 16);
2518  const uint16_t counterSetRecordCount = armnn::numeric_cast<uint16_t>(bodyHeader1Word2 >> 16);
2519  const uint16_t categoryRecordCount = armnn::numeric_cast<uint16_t>(bodyHeader1Word4 >> 16);
2520  CHECK(deviceRecordCount == 1); // device_records_count
2521  CHECK(bodyHeader1Word1 == 0 + bodyHeaderSizeBytes); // device_records_pointer_table_offset
2522  CHECK(counterSetRecordCount == 1); // counter_set_count
2523  CHECK(bodyHeader1Word3 == 4 + bodyHeaderSizeBytes); // counter_set_pointer_table_offset
2524  CHECK(categoryRecordCount == 1); // categories_count
2525  CHECK(bodyHeader1Word5 == 8 + bodyHeaderSizeBytes); // categories_pointer_table_offset
2526 
2527  const uint32_t deviceRecordOffset = ReadUint32(readBuffer1, 32);
2528  CHECK(deviceRecordOffset == 12);
2529 
2530  const uint32_t counterSetRecordOffset = ReadUint32(readBuffer1, 36);
2531  CHECK(counterSetRecordOffset == 28);
2532 
2533  const uint32_t categoryRecordOffset = ReadUint32(readBuffer1, 40);
2534  CHECK(categoryRecordOffset == 48);
2535 
2536  auto readBuffer2 = mockBuffer2.GetReadableBuffer();
2537 
2538  const uint32_t header2Word0 = ReadUint32(readBuffer2, 0);
2539  const uint32_t header2Word1 = ReadUint32(readBuffer2, 4);
2540 
2541  // Timeline message directory packet
2542  CHECK(((header2Word0 >> 26) & 0x0000003F) == 1); // packet family
2543  CHECK(((header2Word0 >> 16) & 0x000003FF) == 0); // packet id
2544  CHECK(header2Word1 == 443); // data length
2545 }
2546 
2547 TEST_CASE("CheckProfilingServiceGoodConnectionAcknowledgedPacket")
2548 {
2549  unsigned int streamMetadataPacketsize = GetStreamMetaDataPacketSize();
2550 
2551  // Reset the profiling service to the uninitialized state
2553  options.m_EnableProfiling = true;
2554  armnn::profiling::ProfilingService profilingService;
2555  profilingService.ResetExternalProfilingOptions(options, true);
2556 
2557  // Swap the profiling connection factory in the profiling service instance with our mock one
2558  SwapProfilingConnectionFactoryHelper helper(profilingService);
2559 
2560  // Bring the profiling service to the "WaitingForAck" state
2561  CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
2562  profilingService.Update(); // Initialize the counter directory
2563  CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
2564  profilingService.Update(); // Create the profiling connection
2565 
2566  // Get the mock profiling connection
2567  MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
2568  CHECK(mockProfilingConnection);
2569 
2570  // Remove the packets received so far
2571  mockProfilingConnection->Clear();
2572 
2573  CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
2574  profilingService.Update(); // Start the command handler and the send thread
2575 
2576  // Wait for the Stream Metadata packet to be sent
2577  CHECK(helper.WaitForPacketsSent(
2578  mockProfilingConnection, PacketType::StreamMetaData, streamMetadataPacketsize) >= 1);
2579 
2580  // Write a valid "Connection Acknowledged" packet into the mock profiling connection, to simulate a valid
2581  // reply from an external profiling service
2582 
2583  // Connection Acknowledged Packet header (word 0, word 1 is always zero):
2584  // 26:31 [6] packet_family: Control Packet Family, value 0b000000
2585  // 16:25 [10] packet_id: Packet identifier, value 0b0000000001
2586  // 8:15 [8] reserved: Reserved, value 0b00000000
2587  // 0:7 [8] reserved: Reserved, value 0b00000000
2588  uint32_t packetFamily = 0;
2589  uint32_t packetId = 1;
2590  uint32_t header = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
2591 
2592  // Create the Connection Acknowledged Packet
2593  arm::pipe::Packet connectionAcknowledgedPacket(header);
2594 
2595  // Write the packet to the mock profiling connection
2596  mockProfilingConnection->WritePacket(std::move(connectionAcknowledgedPacket));
2597 
2598  // Wait for the counter directory packet to ensure the ConnectionAcknowledgedCommandHandler has run.
2599  CHECK(helper.WaitForPacketsSent(mockProfilingConnection, PacketType::CounterDirectory) == 1);
2600 
2601  // The Connection Acknowledged Command Handler should have updated the profiling state accordingly
2602  CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2603 
2604  // Reset the profiling service to stop any running thread
2605  options.m_EnableProfiling = false;
2606  profilingService.ResetExternalProfilingOptions(options, true);
2607 }
2608 
2609 TEST_CASE("CheckProfilingServiceGoodRequestCounterDirectoryPacket")
2610 {
2611  // Reset the profiling service to the uninitialized state
2613  options.m_EnableProfiling = true;
2614  armnn::profiling::ProfilingService profilingService;
2615  profilingService.ResetExternalProfilingOptions(options, true);
2616 
2617  // Swap the profiling connection factory in the profiling service instance with our mock one
2618  SwapProfilingConnectionFactoryHelper helper(profilingService);
2619 
2620  // Bring the profiling service to the "Active" state
2621  CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
2622  profilingService.Update(); // Initialize the counter directory
2623  CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
2624  profilingService.Update(); // Create the profiling connection
2625  CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
2626  profilingService.Update(); // Start the command handler and the send thread
2627 
2628  // Get the mock profiling connection
2629  MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
2630  CHECK(mockProfilingConnection);
2631 
2632  // Force the profiling service to the "Active" state
2633  helper.ForceTransitionToState(ProfilingState::Active);
2634  CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2635 
2636  // Write a valid "Request Counter Directory" packet into the mock profiling connection, to simulate a valid
2637  // reply from an external profiling service
2638 
2639  // Request Counter Directory packet header (word 0, word 1 is always zero):
2640  // 26:31 [6] packet_family: Control Packet Family, value 0b000000
2641  // 16:25 [10] packet_id: Packet identifier, value 0b0000000011
2642  // 8:15 [8] reserved: Reserved, value 0b00000000
2643  // 0:7 [8] reserved: Reserved, value 0b00000000
2644  uint32_t packetFamily = 0;
2645  uint32_t packetId = 3;
2646  uint32_t header = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
2647 
2648  // Create the Request Counter Directory packet
2649  arm::pipe::Packet requestCounterDirectoryPacket(header);
2650 
2651  // Write the packet to the mock profiling connection
2652  mockProfilingConnection->WritePacket(std::move(requestCounterDirectoryPacket));
2653 
2654  // Expecting one CounterDirectory Packet of length 652
2655  // and one TimelineMessageDirectory packet of length 451
2656  CHECK(helper.WaitForPacketsSent(mockProfilingConnection, PacketType::CounterDirectory, 652) == 1);
2657  CHECK(helper.WaitForPacketsSent(mockProfilingConnection, PacketType::TimelineMessageDirectory, 451) == 1);
2658 
2659  // The Request Counter Directory Command Handler should not have updated the profiling state
2660  CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2661 
2662  // Reset the profiling service to stop any running thread
2663  options.m_EnableProfiling = false;
2664  profilingService.ResetExternalProfilingOptions(options, true);
2665 }
2666 
2667 TEST_CASE("CheckProfilingServiceBadPeriodicCounterSelectionPacketInvalidCounterUid")
2668 {
2669  // Reset the profiling service to the uninitialized state
2671  options.m_EnableProfiling = true;
2672  armnn::profiling::ProfilingService profilingService;
2673  profilingService.ResetExternalProfilingOptions(options, true);
2674 
2675  // Swap the profiling connection factory in the profiling service instance with our mock one
2676  SwapProfilingConnectionFactoryHelper helper(profilingService);
2677 
2678  // Bring the profiling service to the "Active" state
2679  CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
2680  profilingService.Update(); // Initialize the counter directory
2681  CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
2682  profilingService.Update(); // Create the profiling connection
2683  CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
2684  profilingService.Update(); // Start the command handler and the send thread
2685 
2686  // Get the mock profiling connection
2687  MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
2688  CHECK(mockProfilingConnection);
2689 
2690  // Force the profiling service to the "Active" state
2691  helper.ForceTransitionToState(ProfilingState::Active);
2692  CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2693 
2694  // Remove the packets received so far
2695  mockProfilingConnection->Clear();
2696 
2697  // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an
2698  // external profiling service
2699 
2700  // Periodic Counter Selection packet header:
2701  // 26:31 [6] packet_family: Control Packet Family, value 0b000000
2702  // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
2703  // 8:15 [8] reserved: Reserved, value 0b00000000
2704  // 0:7 [8] reserved: Reserved, value 0b00000000
2705  uint32_t packetFamily = 0;
2706  uint32_t packetId = 4;
2707  uint32_t header = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
2708 
2709  uint32_t capturePeriod = 123456; // Some capture period (microseconds)
2710 
2711  // Get the first valid counter UID
2712  const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory();
2713  const Counters& counters = counterDirectory.GetCounters();
2714  CHECK(counters.size() > 1);
2715  uint16_t counterUidA = counters.begin()->first; // First valid counter UID
2716  uint16_t counterUidB = 9999; // Second invalid counter UID
2717 
2718  uint32_t length = 8;
2719 
2720  auto data = std::make_unique<unsigned char[]>(length);
2721  WriteUint32(data.get(), 0, capturePeriod);
2722  WriteUint16(data.get(), 4, counterUidA);
2723  WriteUint16(data.get(), 6, counterUidB);
2724 
2725  // Create the Periodic Counter Selection packet
2726  // Length > 0, this will start the Period Counter Capture thread
2727  arm::pipe::Packet periodicCounterSelectionPacket(header, length, data);
2728 
2729 
2730  // Write the packet to the mock profiling connection
2731  mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
2732 
2733  // Expecting one Periodic Counter Selection packet of length 14
2734  // and at least one Periodic Counter Capture packet of length 22
2735  CHECK(helper.WaitForPacketsSent(mockProfilingConnection, PacketType::PeriodicCounterSelection, 14) == 1);
2736  CHECK(helper.WaitForPacketsSent(mockProfilingConnection, PacketType::PeriodicCounterCapture, 22) >= 1);
2737 
2738  // The Periodic Counter Selection Handler should not have updated the profiling state
2739  CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2740 
2741  // Reset the profiling service to stop any running thread
2742  options.m_EnableProfiling = false;
2743  profilingService.ResetExternalProfilingOptions(options, true);
2744 }
2745 
2746 TEST_CASE("CheckProfilingServiceGoodPeriodicCounterSelectionPacketNoCounters")
2747 {
2748  // Reset the profiling service to the uninitialized state
2750  options.m_EnableProfiling = true;
2751  armnn::profiling::ProfilingService profilingService;
2752  profilingService.ResetExternalProfilingOptions(options, true);
2753 
2754  // Swap the profiling connection factory in the profiling service instance with our mock one
2755  SwapProfilingConnectionFactoryHelper helper(profilingService);
2756 
2757  // Bring the profiling service to the "Active" state
2758  CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
2759  profilingService.Update(); // Initialize the counter directory
2760  CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
2761  profilingService.Update(); // Create the profiling connection
2762  CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
2763  profilingService.Update(); // Start the command handler and the send thread
2764 
2765  // Get the mock profiling connection
2766  MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
2767  CHECK(mockProfilingConnection);
2768 
2769  // Wait for the Stream Metadata packet the be sent
2770  // (we are not testing the connection acknowledgement here so it will be ignored by this test)
2771  helper.WaitForPacketsSent(mockProfilingConnection, PacketType::StreamMetaData);
2772 
2773  // Force the profiling service to the "Active" state
2774  helper.ForceTransitionToState(ProfilingState::Active);
2775  CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2776 
2777  // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an
2778  // external profiling service
2779 
2780  // Periodic Counter Selection packet header:
2781  // 26:31 [6] packet_family: Control Packet Family, value 0b000000
2782  // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
2783  // 8:15 [8] reserved: Reserved, value 0b00000000
2784  // 0:7 [8] reserved: Reserved, value 0b00000000
2785  uint32_t packetFamily = 0;
2786  uint32_t packetId = 4;
2787  uint32_t header = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
2788 
2789  // Create the Periodic Counter Selection packet
2790  // Length == 0, this will disable the collection of counters
2791  arm::pipe::Packet periodicCounterSelectionPacket(header);
2792 
2793  // Write the packet to the mock profiling connection
2794  mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
2795 
2796  // Wait for the Periodic Counter Selection packet of length 12 to be sent
2797  // The size of the expected Periodic Counter Selection (echos the sent one)
2798  CHECK(helper.WaitForPacketsSent(mockProfilingConnection, PacketType::PeriodicCounterSelection, 12) == 1);
2799 
2800  // The Periodic Counter Selection Handler should not have updated the profiling state
2801  CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2802 
2803  // No Periodic Counter packets are expected
2804  CHECK(helper.WaitForPacketsSent(mockProfilingConnection, PacketType::PeriodicCounterCapture, 0, 0) == 0);
2805 
2806  // Reset the profiling service to stop any running thread
2807  options.m_EnableProfiling = false;
2808  profilingService.ResetExternalProfilingOptions(options, true);
2809 }
2810 
2811 TEST_CASE("CheckProfilingServiceGoodPeriodicCounterSelectionPacketSingleCounter")
2812 {
2813  // Reset the profiling service to the uninitialized state
2815  options.m_EnableProfiling = true;
2816  armnn::profiling::ProfilingService profilingService;
2817  profilingService.ResetExternalProfilingOptions(options, true);
2818 
2819  // Swap the profiling connection factory in the profiling service instance with our mock one
2820  SwapProfilingConnectionFactoryHelper helper(profilingService);
2821 
2822  // Bring the profiling service to the "Active" state
2823  CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
2824  profilingService.Update(); // Initialize the counter directory
2825  CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
2826  profilingService.Update(); // Create the profiling connection
2827  CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
2828  profilingService.Update(); // Start the command handler and the send thread
2829 
2830  // Get the mock profiling connection
2831  MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
2832  CHECK(mockProfilingConnection);
2833 
2834  // Wait for the Stream Metadata packet to be sent
2835  // (we are not testing the connection acknowledgement here so it will be ignored by this test)
2836  helper.WaitForPacketsSent(mockProfilingConnection, PacketType::StreamMetaData);
2837 
2838  // Force the profiling service to the "Active" state
2839  helper.ForceTransitionToState(ProfilingState::Active);
2840  CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2841 
2842  // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an
2843  // external profiling service
2844 
2845  // Periodic Counter Selection packet header:
2846  // 26:31 [6] packet_family: Control Packet Family, value 0b000000
2847  // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
2848  // 8:15 [8] reserved: Reserved, value 0b00000000
2849  // 0:7 [8] reserved: Reserved, value 0b00000000
2850  uint32_t packetFamily = 0;
2851  uint32_t packetId = 4;
2852  uint32_t header = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
2853 
2854  uint32_t capturePeriod = 123456; // Some capture period (microseconds)
2855 
2856  // Get the first valid counter UID
2857  const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory();
2858  const Counters& counters = counterDirectory.GetCounters();
2859  CHECK(!counters.empty());
2860  uint16_t counterUid = counters.begin()->first; // Valid counter UID
2861 
2862  uint32_t length = 6;
2863 
2864  auto data = std::make_unique<unsigned char[]>(length);
2865  WriteUint32(data.get(), 0, capturePeriod);
2866  WriteUint16(data.get(), 4, counterUid);
2867 
2868  // Create the Periodic Counter Selection packet
2869  // Length > 0, this will start the Period Counter Capture thread
2870  arm::pipe::Packet periodicCounterSelectionPacket(header, length, data);
2871 
2872  // Write the packet to the mock profiling connection
2873  mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
2874 
2875  // Expecting one Periodic Counter Selection packet of length 14
2876  // and at least one Periodic Counter Capture packet of length 22
2877  CHECK(helper.WaitForPacketsSent(mockProfilingConnection, PacketType::PeriodicCounterSelection, 14) == 1);
2878  CHECK(helper.WaitForPacketsSent(mockProfilingConnection, PacketType::PeriodicCounterCapture, 22) >= 1);
2879 
2880  // The Periodic Counter Selection Handler should not have updated the profiling state
2881  CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2882 
2883  // Reset the profiling service to stop any running thread
2884  options.m_EnableProfiling = false;
2885  profilingService.ResetExternalProfilingOptions(options, true);
2886 }
2887 
2888 TEST_CASE("CheckProfilingServiceGoodPeriodicCounterSelectionPacketMultipleCounters")
2889 {
2890  // Reset the profiling service to the uninitialized state
2892  options.m_EnableProfiling = true;
2893  armnn::profiling::ProfilingService profilingService;
2894  profilingService.ResetExternalProfilingOptions(options, true);
2895 
2896  // Swap the profiling connection factory in the profiling service instance with our mock one
2897  SwapProfilingConnectionFactoryHelper helper(profilingService);
2898 
2899  // Bring the profiling service to the "Active" state
2900  CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
2901  profilingService.Update(); // Initialize the counter directory
2902  CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
2903  profilingService.Update(); // Create the profiling connection
2904  CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
2905  profilingService.Update(); // Start the command handler and the send thread
2906 
2907  // Get the mock profiling connection
2908  MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
2909  CHECK(mockProfilingConnection);
2910 
2911  // Wait for the Stream Metadata packet the be sent
2912  // (we are not testing the connection acknowledgement here so it will be ignored by this test)
2913  helper.WaitForPacketsSent(mockProfilingConnection, PacketType::StreamMetaData);
2914 
2915  // Force the profiling service to the "Active" state
2916  helper.ForceTransitionToState(ProfilingState::Active);
2917  CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2918 
2919  // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an
2920  // external profiling service
2921 
2922  // Periodic Counter Selection packet header:
2923  // 26:31 [6] packet_family: Control Packet Family, value 0b000000
2924  // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
2925  // 8:15 [8] reserved: Reserved, value 0b00000000
2926  // 0:7 [8] reserved: Reserved, value 0b00000000
2927  uint32_t packetFamily = 0;
2928  uint32_t packetId = 4;
2929  uint32_t header = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
2930 
2931  uint32_t capturePeriod = 123456; // Some capture period (microseconds)
2932 
2933  // Get the first valid counter UID
2934  const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory();
2935  const Counters& counters = counterDirectory.GetCounters();
2936  CHECK(counters.size() > 1);
2937  uint16_t counterUidA = counters.begin()->first; // First valid counter UID
2938  uint16_t counterUidB = (counters.begin()++)->first; // Second valid counter UID
2939 
2940  uint32_t length = 8;
2941 
2942  auto data = std::make_unique<unsigned char[]>(length);
2943  WriteUint32(data.get(), 0, capturePeriod);
2944  WriteUint16(data.get(), 4, counterUidA);
2945  WriteUint16(data.get(), 6, counterUidB);
2946 
2947  // Create the Periodic Counter Selection packet
2948  // Length > 0, this will start the Period Counter Capture thread
2949  arm::pipe::Packet periodicCounterSelectionPacket(header, length, data);
2950 
2951  // Write the packet to the mock profiling connection
2952  mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
2953 
2954  // Expecting one PeriodicCounterSelection Packet with a length of 16
2955  // And at least one PeriodicCounterCapture Packet with a length of 28
2956  CHECK(helper.WaitForPacketsSent(mockProfilingConnection, PacketType::PeriodicCounterSelection, 16) == 1);
2957  CHECK(helper.WaitForPacketsSent(mockProfilingConnection, PacketType::PeriodicCounterCapture, 28) >= 1);
2958 
2959  // The Periodic Counter Selection Handler should not have updated the profiling state
2960  CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2961 
2962  // Reset the profiling service to stop any running thread
2963  options.m_EnableProfiling = false;
2964  profilingService.ResetExternalProfilingOptions(options, true);
2965 }
2966 
2967 TEST_CASE("CheckProfilingServiceDisconnect")
2968 {
2969  // Reset the profiling service to the uninitialized state
2971  options.m_EnableProfiling = true;
2972  armnn::profiling::ProfilingService profilingService;
2973  profilingService.ResetExternalProfilingOptions(options, true);
2974 
2975  // Swap the profiling connection factory in the profiling service instance with our mock one
2976  SwapProfilingConnectionFactoryHelper helper(profilingService);
2977 
2978  // Try to disconnect the profiling service while in the "Uninitialised" state
2979  CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
2980  profilingService.Disconnect();
2981  CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); // The state should not change
2982 
2983  // Try to disconnect the profiling service while in the "NotConnected" state
2984  profilingService.Update(); // Initialize the counter directory
2985  CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
2986  profilingService.Disconnect();
2987  CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); // The state should not change
2988 
2989  // Try to disconnect the profiling service while in the "WaitingForAck" state
2990  profilingService.Update(); // Create the profiling connection
2991  CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
2992  profilingService.Disconnect();
2993  CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); // The state should not change
2994 
2995  // Try to disconnect the profiling service while in the "Active" state
2996  profilingService.Update(); // Start the command handler and the send thread
2997 
2998  // Get the mock profiling connection
2999  MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
3000  CHECK(mockProfilingConnection);
3001 
3002  // Wait for the Stream Metadata packet the be sent
3003  // (we are not testing the connection acknowledgement here so it will be ignored by this test)
3004  helper.WaitForPacketsSent(mockProfilingConnection, PacketType::StreamMetaData);
3005 
3006  // Force the profiling service to the "Active" state
3007  helper.ForceTransitionToState(ProfilingState::Active);
3008  CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
3009 
3010  // Check that the profiling connection is open
3011  CHECK(mockProfilingConnection->IsOpen());
3012 
3013  profilingService.Disconnect();
3014  CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); // The state should have changed
3015 
3016  // Check that the profiling connection has been reset
3017  mockProfilingConnection = helper.GetMockProfilingConnection();
3018  CHECK(mockProfilingConnection == nullptr);
3019 
3020  // Reset the profiling service to stop any running thread
3021  options.m_EnableProfiling = false;
3022  profilingService.ResetExternalProfilingOptions(options, true);
3023 }
3024 
3025 TEST_CASE("CheckProfilingServiceGoodPerJobCounterSelectionPacket")
3026 {
3027  // Reset the profiling service to the uninitialized state
3029  options.m_EnableProfiling = true;
3030  armnn::profiling::ProfilingService profilingService;
3031  profilingService.ResetExternalProfilingOptions(options, true);
3032 
3033  // Swap the profiling connection factory in the profiling service instance with our mock one
3034  SwapProfilingConnectionFactoryHelper helper(profilingService);
3035 
3036  // Bring the profiling service to the "Active" state
3037  CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3038  profilingService.Update(); // Initialize the counter directory
3039  CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
3040  profilingService.Update(); // Create the profiling connection
3041  CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
3042  profilingService.Update(); // Start the command handler and the send thread
3043 
3044  // Get the mock profiling connection
3045  MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
3046  CHECK(mockProfilingConnection);
3047 
3048  // Wait for the Stream Metadata packet the be sent
3049  // (we are not testing the connection acknowledgement here so it will be ignored by this test)
3050  helper.WaitForPacketsSent(mockProfilingConnection, PacketType::StreamMetaData);
3051 
3052  // Force the profiling service to the "Active" state
3053  helper.ForceTransitionToState(ProfilingState::Active);
3054  CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
3055 
3056  // Write a "Per-Job Counter Selection" packet into the mock profiling connection, to simulate an input from an
3057  // external profiling service
3058 
3059  // Per-Job Counter Selection packet header:
3060  // 26:31 [6] packet_family: Control Packet Family, value 0b000000
3061  // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
3062  // 8:15 [8] reserved: Reserved, value 0b00000000
3063  // 0:7 [8] reserved: Reserved, value 0b00000000
3064  uint32_t packetFamily = 0;
3065  uint32_t packetId = 5;
3066  uint32_t header = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
3067 
3068  // Create the Per-Job Counter Selection packet
3069  // Length == 0, this will disable the collection of counters
3070  arm::pipe::Packet periodicCounterSelectionPacket(header);
3071 
3072  // Write the packet to the mock profiling connection
3073  mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
3074 
3075  // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that
3076  // the Per-Job Counter Selection packet gets processed by the profiling service
3077  std::this_thread::sleep_for(std::chrono::milliseconds(5));
3078 
3079  // The Per-Job Counter Selection Command Handler should not have updated the profiling state
3080  CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
3081 
3082  // The Per-Job Counter Selection packets are dropped silently, so there should be no reply coming
3083  // from the profiling service
3084  const auto StreamMetaDataSize = static_cast<unsigned long>(
3085  helper.WaitForPacketsSent(mockProfilingConnection, PacketType::StreamMetaData, 0, 0));
3086  CHECK(StreamMetaDataSize == mockProfilingConnection->GetWrittenDataSize());
3087 
3088  // Reset the profiling service to stop any running thread
3089  options.m_EnableProfiling = false;
3090  profilingService.ResetExternalProfilingOptions(options, true);
3091 }
3092 
3093 TEST_CASE("CheckConfigureProfilingServiceOn")
3094 {
3096  options.m_EnableProfiling = true;
3097  armnn::profiling::ProfilingService profilingService;
3098  CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3099  profilingService.ConfigureProfilingService(options);
3100  // should get as far as NOT_CONNECTED
3101  CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
3102  // Reset the profiling service to stop any running thread
3103  options.m_EnableProfiling = false;
3104  profilingService.ResetExternalProfilingOptions(options, true);
3105 }
3106 
3107 TEST_CASE("CheckConfigureProfilingServiceOff")
3108 {
3110  armnn::profiling::ProfilingService profilingService;
3111  CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3112  profilingService.ConfigureProfilingService(options);
3113  // should not move from Uninitialised
3114  CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3115  // Reset the profiling service to stop any running thread
3116  options.m_EnableProfiling = false;
3117  profilingService.ResetExternalProfilingOptions(options, true);
3118 }
3119 
3120 TEST_CASE("CheckProfilingServiceEnabled")
3121 {
3122  // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
3125  options.m_EnableProfiling = true;
3126  armnn::profiling::ProfilingService profilingService;
3127  profilingService.ResetExternalProfilingOptions(options, true);
3128  CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3129  profilingService.Update();
3130  CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
3131 
3132  // Redirect the output to a local stream so that we can parse the warning message
3133  std::stringstream ss;
3134  StreamRedirector streamRedirector(std::cout, ss.rdbuf());
3135  profilingService.Update();
3136 
3137  // Reset the profiling service to stop any running thread
3138  options.m_EnableProfiling = false;
3139  profilingService.ResetExternalProfilingOptions(options, true);
3140 
3141  streamRedirector.CancelRedirect();
3142 
3143  // Check that the expected error has occurred and logged to the standard output
3144  if (ss.str().find("Cannot connect to stream socket: Connection refused") == std::string::npos)
3145  {
3146  std::cout << ss.str();
3147  FAIL("Expected string not found.");
3148  }
3149 }
3150 
3151 TEST_CASE("CheckProfilingServiceEnabledRuntime")
3152 {
3153  // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
3156  armnn::profiling::ProfilingService profilingService;
3157  profilingService.ResetExternalProfilingOptions(options, true);
3158  CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3159  profilingService.Update();
3160  CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3161  options.m_EnableProfiling = true;
3162  profilingService.ResetExternalProfilingOptions(options);
3163  CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3164  profilingService.Update();
3165  CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
3166 
3167  // Redirect the output to a local stream so that we can parse the warning message
3168  std::stringstream ss;
3169  StreamRedirector streamRedirector(std::cout, ss.rdbuf());
3170  profilingService.Update();
3171 
3172  // Reset the profiling service to stop any running thread
3173  options.m_EnableProfiling = false;
3174  profilingService.ResetExternalProfilingOptions(options, true);
3175 
3176  streamRedirector.CancelRedirect();
3177 
3178  // Check that the expected error has occurred and logged to the standard output
3179  if (ss.str().find("Cannot connect to stream socket: Connection refused") == std::string::npos)
3180  {
3181  std::cout << ss.str();
3182  FAIL("Expected string not found.");
3183  }
3184 }
3185 
3186 TEST_CASE("CheckProfilingServiceBadConnectionAcknowledgedPacket")
3187 {
3188  // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
3190 
3191 
3192  // Redirect the standard output to a local stream so that we can parse the warning message
3193  std::stringstream ss;
3194  StreamRedirector streamRedirector(std::cout, ss.rdbuf());
3195 
3196  // Reset the profiling service to the uninitialized state
3198  options.m_EnableProfiling = true;
3199  armnn::profiling::ProfilingService profilingService;
3200  profilingService.ResetExternalProfilingOptions(options, true);
3201 
3202  // Swap the profiling connection factory in the profiling service instance with our mock one
3203  SwapProfilingConnectionFactoryHelper helper(profilingService);
3204 
3205  // Bring the profiling service to the "WaitingForAck" state
3206  CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3207  profilingService.Update(); // Initialize the counter directory
3208  CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
3209  profilingService.Update(); // Create the profiling connection
3210 
3211  // Get the mock profiling connection
3212  MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
3213  CHECK(mockProfilingConnection);
3214 
3215  CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
3216 
3217  // Connection Acknowledged Packet header (word 0, word 1 is always zero):
3218  // 26:31 [6] packet_family: Control Packet Family, value 0b000000
3219  // 16:25 [10] packet_id: Packet identifier, value 0b0000000001
3220  // 8:15 [8] reserved: Reserved, value 0b00000000
3221  // 0:7 [8] reserved: Reserved, value 0b00000000
3222  uint32_t packetFamily = 0;
3223  uint32_t packetId = 37; // Wrong packet id!!!
3224  uint32_t header = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
3225 
3226  // Create the Connection Acknowledged Packet
3227  arm::pipe::Packet connectionAcknowledgedPacket(header);
3228  // Write an invalid "Connection Acknowledged" packet into the mock profiling connection, to simulate an invalid
3229  // reply from an external profiling service
3230  mockProfilingConnection->WritePacket(std::move(connectionAcknowledgedPacket));
3231 
3232  // Start the command thread
3233  profilingService.Update();
3234 
3235  // Wait for the command thread to join
3236  options.m_EnableProfiling = false;
3237  profilingService.ResetExternalProfilingOptions(options, true);
3238 
3239  streamRedirector.CancelRedirect();
3240 
3241  // Check that the expected error has occurred and logged to the standard output
3242  if (ss.str().find("Functor with requested PacketId=37 and Version=4194304 does not exist") == std::string::npos)
3243  {
3244  std::cout << ss.str();
3245  FAIL("Expected string not found.");
3246  }
3247 }
3248 
3249 TEST_CASE("CheckProfilingServiceBadRequestCounterDirectoryPacket")
3250 {
3251  // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
3253 
3254  // Redirect the standard output to a local stream so that we can parse the warning message
3255  std::stringstream ss;
3256  StreamRedirector streamRedirector(std::cout, ss.rdbuf());
3257 
3258  // Reset the profiling service to the uninitialized state
3260  options.m_EnableProfiling = true;
3261  armnn::profiling::ProfilingService profilingService;
3262  profilingService.ResetExternalProfilingOptions(options, true);
3263 
3264  // Swap the profiling connection factory in the profiling service instance with our mock one
3265  SwapProfilingConnectionFactoryHelper helper(profilingService);
3266 
3267  // Bring the profiling service to the "Active" state
3268  CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3269  helper.ForceTransitionToState(ProfilingState::NotConnected);
3270  CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
3271  profilingService.Update(); // Create the profiling connection
3272  CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
3273 
3274  // Get the mock profiling connection
3275  MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
3276  CHECK(mockProfilingConnection);
3277 
3278  // Write a valid "Request Counter Directory" packet into the mock profiling connection, to simulate a valid
3279  // reply from an external profiling service
3280 
3281  // Request Counter Directory packet header (word 0, word 1 is always zero):
3282  // 26:31 [6] packet_family: Control Packet Family, value 0b000000
3283  // 16:25 [10] packet_id: Packet identifier, value 0b0000000011
3284  // 8:15 [8] reserved: Reserved, value 0b00000000
3285  // 0:7 [8] reserved: Reserved, value 0b00000000
3286  uint32_t packetFamily = 0;
3287  uint32_t packetId = 123; // Wrong packet id!!!
3288  uint32_t header = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
3289 
3290  // Create the Request Counter Directory packet
3291  arm::pipe::Packet requestCounterDirectoryPacket(header);
3292 
3293  // Write the packet to the mock profiling connection
3294  mockProfilingConnection->WritePacket(std::move(requestCounterDirectoryPacket));
3295 
3296  // Start the command handler and the send thread
3297  profilingService.Update();
3298 
3299  // Reset the profiling service to stop and join any running thread
3300  options.m_EnableProfiling = false;
3301  profilingService.ResetExternalProfilingOptions(options, true);
3302 
3303  streamRedirector.CancelRedirect();
3304 
3305  // Check that the expected error has occurred and logged to the standard output
3306  if (ss.str().find("Functor with requested PacketId=123 and Version=4194304 does not exist") == std::string::npos)
3307  {
3308  std::cout << ss.str();
3309  FAIL("Expected string not found.");
3310  }
3311 }
3312 
3313 TEST_CASE("CheckProfilingServiceBadPeriodicCounterSelectionPacket")
3314 {
3315  // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
3317 
3318  // Redirect the standard output to a local stream so that we can parse the warning message
3319  std::stringstream ss;
3320  StreamRedirector streamRedirector(std::cout, ss.rdbuf());
3321 
3322  // Reset the profiling service to the uninitialized state
3324  options.m_EnableProfiling = true;
3325  armnn::profiling::ProfilingService profilingService;
3326  profilingService.ResetExternalProfilingOptions(options, true);
3327 
3328  // Swap the profiling connection factory in the profiling service instance with our mock one
3329  SwapProfilingConnectionFactoryHelper helper(profilingService);
3330 
3331  // Bring the profiling service to the "Active" state
3332  CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3333  profilingService.Update(); // Initialize the counter directory
3334  CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
3335  profilingService.Update(); // Create the profiling connection
3336  CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
3337  profilingService.Update(); // Start the command handler and the send thread
3338 
3339  // Get the mock profiling connection
3340  MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
3341  CHECK(mockProfilingConnection);
3342 
3343  // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an
3344  // external profiling service
3345 
3346  // Periodic Counter Selection packet header:
3347  // 26:31 [6] packet_family: Control Packet Family, value 0b000000
3348  // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
3349  // 8:15 [8] reserved: Reserved, value 0b00000000
3350  // 0:7 [8] reserved: Reserved, value 0b00000000
3351  uint32_t packetFamily = 0;
3352  uint32_t packetId = 999; // Wrong packet id!!!
3353  uint32_t header = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
3354 
3355  // Create the Periodic Counter Selection packet
3356  // Length == 0, this will disable the collection of counters
3357  arm::pipe::Packet periodicCounterSelectionPacket(header);
3358 
3359  // Write the packet to the mock profiling connection
3360  mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
3361  profilingService.Update();
3362 
3363  // Reset the profiling service to stop any running thread
3364  options.m_EnableProfiling = false;
3365  profilingService.ResetExternalProfilingOptions(options, true);
3366 
3367  // Check that the expected error has occurred and logged to the standard output
3368  streamRedirector.CancelRedirect();
3369 
3370  // Check that the expected error has occurred and logged to the standard output
3371  if (ss.str().find("Functor with requested PacketId=999 and Version=4194304 does not exist") == std::string::npos)
3372  {
3373  std::cout << ss.str();
3374  FAIL("Expected string not found.");
3375  }
3376 }
3377 
3378 TEST_CASE("CheckCounterIdMap")
3379 {
3380  CounterIdMap counterIdMap;
3381  CHECK_THROWS_AS(counterIdMap.GetBackendId(0), armnn::Exception);
3382  CHECK_THROWS_AS(counterIdMap.GetGlobalId(0, armnn::profiling::BACKEND_ID), armnn::Exception);
3383 
3384  uint16_t globalCounterIds = 0;
3385 
3386  armnn::BackendId cpuRefId(armnn::Compute::CpuRef);
3387  armnn::BackendId cpuAccId(armnn::Compute::CpuAcc);
3388 
3389  std::vector<uint16_t> cpuRefCounters = {0, 1, 2, 3};
3390  std::vector<uint16_t> cpuAccCounters = {0, 1};
3391 
3392  for (uint16_t backendCounterId : cpuRefCounters)
3393  {
3394  counterIdMap.RegisterMapping(globalCounterIds, backendCounterId, cpuRefId);
3395  ++globalCounterIds;
3396  }
3397  for (uint16_t backendCounterId : cpuAccCounters)
3398  {
3399  counterIdMap.RegisterMapping(globalCounterIds, backendCounterId, cpuAccId);
3400  ++globalCounterIds;
3401  }
3402 
3403  CHECK(counterIdMap.GetBackendId(0) == (std::pair<uint16_t, armnn::BackendId>(0, cpuRefId)));
3404  CHECK(counterIdMap.GetBackendId(1) == (std::pair<uint16_t, armnn::BackendId>(1, cpuRefId)));
3405  CHECK(counterIdMap.GetBackendId(2) == (std::pair<uint16_t, armnn::BackendId>(2, cpuRefId)));
3406  CHECK(counterIdMap.GetBackendId(3) == (std::pair<uint16_t, armnn::BackendId>(3, cpuRefId)));
3407  CHECK(counterIdMap.GetBackendId(4) == (std::pair<uint16_t, armnn::BackendId>(0, cpuAccId)));
3408  CHECK(counterIdMap.GetBackendId(5) == (std::pair<uint16_t, armnn::BackendId>(1, cpuAccId)));
3409 
3410  CHECK(counterIdMap.GetGlobalId(0, cpuRefId) == 0);
3411  CHECK(counterIdMap.GetGlobalId(1, cpuRefId) == 1);
3412  CHECK(counterIdMap.GetGlobalId(2, cpuRefId) == 2);
3413  CHECK(counterIdMap.GetGlobalId(3, cpuRefId) == 3);
3414  CHECK(counterIdMap.GetGlobalId(0, cpuAccId) == 4);
3415  CHECK(counterIdMap.GetGlobalId(1, cpuAccId) == 5);
3416 }
3417 
3418 TEST_CASE("CheckRegisterBackendCounters")
3419 {
3420  uint16_t globalCounterIds = armnn::profiling::INFERENCES_RUN;
3421  armnn::BackendId cpuRefId(armnn::Compute::CpuRef);
3422 
3423  // Reset the profiling service to the uninitialized state
3425  options.m_EnableProfiling = true;
3426  ProfilingService profilingService;
3427  profilingService.ResetExternalProfilingOptions(options, true);
3428 
3429  RegisterBackendCounters registerBackendCounters(globalCounterIds, cpuRefId, profilingService);
3430 
3431 
3432 
3433  CHECK(profilingService.GetCounterDirectory().GetCategories().empty());
3434  registerBackendCounters.RegisterCategory("categoryOne");
3435  auto categoryOnePtr = profilingService.GetCounterDirectory().GetCategory("categoryOne");
3436  CHECK(categoryOnePtr);
3437 
3438  CHECK(profilingService.GetCounterDirectory().GetDevices().empty());
3439  globalCounterIds = registerBackendCounters.RegisterDevice("deviceOne");
3440  auto deviceOnePtr = profilingService.GetCounterDirectory().GetDevice(globalCounterIds);
3441  CHECK(deviceOnePtr);
3442  CHECK(deviceOnePtr->m_Name == "deviceOne");
3443 
3444  CHECK(profilingService.GetCounterDirectory().GetCounterSets().empty());
3445  globalCounterIds = registerBackendCounters.RegisterCounterSet("counterSetOne");
3446  auto counterSetOnePtr = profilingService.GetCounterDirectory().GetCounterSet(globalCounterIds);
3447  CHECK(counterSetOnePtr);
3448  CHECK(counterSetOnePtr->m_Name == "counterSetOne");
3449 
3450  uint16_t newGlobalCounterId = registerBackendCounters.RegisterCounter(0,
3451  "categoryOne",
3452  0,
3453  0,
3454  1.f,
3455  "CounterOne",
3456  "first test counter");
3457  CHECK((newGlobalCounterId = armnn::profiling::INFERENCES_RUN + 1));
3458  uint16_t mappedGlobalId = profilingService.GetCounterMappings().GetGlobalId(0, cpuRefId);
3459  CHECK(mappedGlobalId == newGlobalCounterId);
3460  auto backendMapping = profilingService.GetCounterMappings().GetBackendId(newGlobalCounterId);
3461  CHECK(backendMapping.first == 0);
3462  CHECK(backendMapping.second == cpuRefId);
3463 
3464  // Reset the profiling service to stop any running thread
3465  options.m_EnableProfiling = false;
3466  profilingService.ResetExternalProfilingOptions(options, true);
3467 }
3468 
3469 TEST_CASE("CheckCounterStatusQuery")
3470 {
3472  options.m_ProfilingOptions.m_EnableProfiling = true;
3473 
3474  // Reset the profiling service to the uninitialized state
3475  ProfilingService profilingService;
3476  profilingService.ResetExternalProfilingOptions(options.m_ProfilingOptions, true);
3477 
3478  const armnn::BackendId cpuRefId(armnn::Compute::CpuRef);
3479  const armnn::BackendId cpuAccId(armnn::Compute::CpuAcc);
3480 
3481  // Create BackendProfiling for each backend
3482  BackendProfiling backendProfilingCpuRef(options, profilingService, cpuRefId);
3483  BackendProfiling backendProfilingCpuAcc(options, profilingService, cpuAccId);
3484 
3485  uint16_t initialNumGlobalCounterIds = armnn::profiling::INFERENCES_RUN;
3486 
3487  // Create RegisterBackendCounters for CpuRef
3488  RegisterBackendCounters registerBackendCountersCpuRef(initialNumGlobalCounterIds, cpuRefId, profilingService);
3489 
3490  // Create 'testCategory' in CounterDirectory (backend agnostic)
3491  CHECK(profilingService.GetCounterDirectory().GetCategories().empty());
3492  registerBackendCountersCpuRef.RegisterCategory("testCategory");
3493  auto categoryOnePtr = profilingService.GetCounterDirectory().GetCategory("testCategory");
3494  CHECK(categoryOnePtr);
3495 
3496  // Counters:
3497  // Global | Local | Backend
3498  // 5 | 0 | CpuRef
3499  // 6 | 1 | CpuRef
3500  // 7 | 1 | CpuAcc
3501 
3502  std::vector<uint16_t> cpuRefCounters = {0, 1};
3503  std::vector<uint16_t> cpuAccCounters = {0};
3504 
3505  // Register the backend counters for CpuRef and validate GetGlobalId and GetBackendId
3506  uint16_t currentNumGlobalCounterIds = registerBackendCountersCpuRef.RegisterCounter(
3507  0, "testCategory", 0, 0, 1.f, "CpuRefCounter0", "Zeroth CpuRef Counter");
3508  CHECK(currentNumGlobalCounterIds == initialNumGlobalCounterIds + 1);
3509  uint16_t mappedGlobalId = profilingService.GetCounterMappings().GetGlobalId(0, cpuRefId);
3510  CHECK(mappedGlobalId == currentNumGlobalCounterIds);
3511  auto backendMapping = profilingService.GetCounterMappings().GetBackendId(currentNumGlobalCounterIds);
3512  CHECK(backendMapping.first == 0);
3513  CHECK(backendMapping.second == cpuRefId);
3514 
3515  currentNumGlobalCounterIds = registerBackendCountersCpuRef.RegisterCounter(
3516  1, "testCategory", 0, 0, 1.f, "CpuRefCounter1", "First CpuRef Counter");
3517  CHECK(currentNumGlobalCounterIds == initialNumGlobalCounterIds + 2);
3518  mappedGlobalId = profilingService.GetCounterMappings().GetGlobalId(1, cpuRefId);
3519  CHECK(mappedGlobalId == currentNumGlobalCounterIds);
3520  backendMapping = profilingService.GetCounterMappings().GetBackendId(currentNumGlobalCounterIds);
3521  CHECK(backendMapping.first == 1);
3522  CHECK(backendMapping.second == cpuRefId);
3523 
3524  // Create RegisterBackendCounters for CpuAcc
3525  RegisterBackendCounters registerBackendCountersCpuAcc(currentNumGlobalCounterIds, cpuAccId, profilingService);
3526 
3527  // Register the backend counter for CpuAcc and validate GetGlobalId and GetBackendId
3528  currentNumGlobalCounterIds = registerBackendCountersCpuAcc.RegisterCounter(
3529  0, "testCategory", 0, 0, 1.f, "CpuAccCounter0", "Zeroth CpuAcc Counter");
3530  CHECK(currentNumGlobalCounterIds == initialNumGlobalCounterIds + 3);
3531  mappedGlobalId = profilingService.GetCounterMappings().GetGlobalId(0, cpuAccId);
3532  CHECK(mappedGlobalId == currentNumGlobalCounterIds);
3533  backendMapping = profilingService.GetCounterMappings().GetBackendId(currentNumGlobalCounterIds);
3534  CHECK(backendMapping.first == 0);
3535  CHECK(backendMapping.second == cpuAccId);
3536 
3537  // Create vectors for active counters
3538  const std::vector<uint16_t> activeGlobalCounterIds = {5}; // CpuRef(0) activated
3539  const std::vector<uint16_t> newActiveGlobalCounterIds = {6, 7}; // CpuRef(0) and CpuAcc(1) activated
3540 
3541  const uint32_t capturePeriod = 200;
3542  const uint32_t newCapturePeriod = 100;
3543 
3544  // Set capture period and active counters in CaptureData
3545  profilingService.SetCaptureData(capturePeriod, activeGlobalCounterIds, {});
3546 
3547  // Get vector of active counters for CpuRef and CpuAcc backends
3548  std::vector<CounterStatus> cpuRefCounterStatus = backendProfilingCpuRef.GetActiveCounters();
3549  std::vector<CounterStatus> cpuAccCounterStatus = backendProfilingCpuAcc.GetActiveCounters();
3550  CHECK_EQ(cpuRefCounterStatus.size(), 1);
3551  CHECK_EQ(cpuAccCounterStatus.size(), 0);
3552 
3553  // Check active CpuRef counter
3554  CHECK_EQ(cpuRefCounterStatus[0].m_GlobalCounterId, activeGlobalCounterIds[0]);
3555  CHECK_EQ(cpuRefCounterStatus[0].m_BackendCounterId, cpuRefCounters[0]);
3556  CHECK_EQ(cpuRefCounterStatus[0].m_SamplingRateInMicroseconds, capturePeriod);
3557  CHECK_EQ(cpuRefCounterStatus[0].m_Enabled, true);
3558 
3559  // Check inactive CpuRef counter
3560  CounterStatus inactiveCpuRefCounter = backendProfilingCpuRef.GetCounterStatus(cpuRefCounters[1]);
3561  CHECK_EQ(inactiveCpuRefCounter.m_GlobalCounterId, 6);
3562  CHECK_EQ(inactiveCpuRefCounter.m_BackendCounterId, cpuRefCounters[1]);
3563  CHECK_EQ(inactiveCpuRefCounter.m_SamplingRateInMicroseconds, 0);
3564  CHECK_EQ(inactiveCpuRefCounter.m_Enabled, false);
3565 
3566  // Check inactive CpuAcc counter
3567  CounterStatus inactiveCpuAccCounter = backendProfilingCpuAcc.GetCounterStatus(cpuAccCounters[0]);
3568  CHECK_EQ(inactiveCpuAccCounter.m_GlobalCounterId, 7);
3569  CHECK_EQ(inactiveCpuAccCounter.m_BackendCounterId, cpuAccCounters[0]);
3570  CHECK_EQ(inactiveCpuAccCounter.m_SamplingRateInMicroseconds, 0);
3571  CHECK_EQ(inactiveCpuAccCounter.m_Enabled, false);
3572 
3573  // Set new capture period and new active counters in CaptureData
3574  profilingService.SetCaptureData(newCapturePeriod, newActiveGlobalCounterIds, {});
3575 
3576  // Get vector of active counters for CpuRef and CpuAcc backends
3577  cpuRefCounterStatus = backendProfilingCpuRef.GetActiveCounters();
3578  cpuAccCounterStatus = backendProfilingCpuAcc.GetActiveCounters();
3579  CHECK_EQ(cpuRefCounterStatus.size(), 1);
3580  CHECK_EQ(cpuAccCounterStatus.size(), 1);
3581 
3582  // Check active CpuRef counter
3583  CHECK_EQ(cpuRefCounterStatus[0].m_GlobalCounterId, newActiveGlobalCounterIds[0]);
3584  CHECK_EQ(cpuRefCounterStatus[0].m_BackendCounterId, cpuRefCounters[1]);
3585  CHECK_EQ(cpuRefCounterStatus[0].m_SamplingRateInMicroseconds, newCapturePeriod);
3586  CHECK_EQ(cpuRefCounterStatus[0].m_Enabled, true);
3587 
3588  // Check active CpuAcc counter
3589  CHECK_EQ(cpuAccCounterStatus[0].m_GlobalCounterId, newActiveGlobalCounterIds[1]);
3590  CHECK_EQ(cpuAccCounterStatus[0].m_BackendCounterId, cpuAccCounters[0]);
3591  CHECK_EQ(cpuAccCounterStatus[0].m_SamplingRateInMicroseconds, newCapturePeriod);
3592  CHECK_EQ(cpuAccCounterStatus[0].m_Enabled, true);
3593 
3594  // Check inactive CpuRef counter
3595  inactiveCpuRefCounter = backendProfilingCpuRef.GetCounterStatus(cpuRefCounters[0]);
3596  CHECK_EQ(inactiveCpuRefCounter.m_GlobalCounterId, 5);
3597  CHECK_EQ(inactiveCpuRefCounter.m_BackendCounterId, cpuRefCounters[0]);
3598  CHECK_EQ(inactiveCpuRefCounter.m_SamplingRateInMicroseconds, 0);
3599  CHECK_EQ(inactiveCpuRefCounter.m_Enabled, false);
3600 
3601  // Reset the profiling service to stop any running thread
3602  options.m_ProfilingOptions.m_EnableProfiling = false;
3603  profilingService.ResetExternalProfilingOptions(options.m_ProfilingOptions, true);
3604 }
3605 
3606 TEST_CASE("CheckRegisterCounters")
3607 {
3609  options.m_ProfilingOptions.m_EnableProfiling = true;
3610  MockBufferManager mockBuffer(1024);
3611 
3612  CaptureData captureData;
3613  MockProfilingService mockProfilingService(mockBuffer, options.m_ProfilingOptions.m_EnableProfiling, captureData);
3614  armnn::BackendId cpuRefId(armnn::Compute::CpuRef);
3615 
3616  mockProfilingService.RegisterMapping(6, 0, cpuRefId);
3617  mockProfilingService.RegisterMapping(7, 1, cpuRefId);
3618  mockProfilingService.RegisterMapping(8, 2, cpuRefId);
3619 
3620  armnn::profiling::BackendProfiling backendProfiling(options,
3621  mockProfilingService,
3622  cpuRefId);
3623 
3624  armnn::profiling::Timestamp timestamp;
3625  timestamp.timestamp = 1000998;
3626  timestamp.counterValues.emplace_back(0, 700);
3627  timestamp.counterValues.emplace_back(2, 93);
3628  std::vector<armnn::profiling::Timestamp> timestamps;
3629  timestamps.push_back(timestamp);
3630  backendProfiling.ReportCounters(timestamps);
3631 
3632  auto readBuffer = mockBuffer.GetReadableBuffer();
3633 
3634  uint32_t headerWord0 = ReadUint32(readBuffer, 0);
3635  uint32_t headerWord1 = ReadUint32(readBuffer, 4);
3636  uint64_t readTimestamp = ReadUint64(readBuffer, 8);
3637 
3638  CHECK(((headerWord0 >> 26) & 0x0000003F) == 3); // packet family
3639  CHECK(((headerWord0 >> 19) & 0x0000007F) == 0); // packet class
3640  CHECK(((headerWord0 >> 16) & 0x00000007) == 0); // packet type
3641  CHECK(headerWord1 == 20); // data length
3642  CHECK(1000998 == readTimestamp); // capture period
3643 
3644  uint32_t offset = 16;
3645  // Check Counter Index
3646  uint16_t readIndex = ReadUint16(readBuffer, offset);
3647  CHECK(6 == readIndex);
3648 
3649  // Check Counter Value
3650  offset += 2;
3651  uint32_t readValue = ReadUint32(readBuffer, offset);
3652  CHECK(700 == readValue);
3653 
3654  // Check Counter Index
3655  offset += 4;
3656  readIndex = ReadUint16(readBuffer, offset);
3657  CHECK(8 == readIndex);
3658 
3659  // Check Counter Value
3660  offset += 2;
3661  readValue = ReadUint32(readBuffer, offset);
3662  CHECK(93 == readValue);
3663 }
3664 
3665 TEST_CASE("CheckFileFormat") {
3666  // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
3668 
3669  // Create profiling options.
3671  options.m_EnableProfiling = true;
3672  // Check the default value set to binary
3673  CHECK(options.m_FileFormat == "binary");
3674 
3675  // Change file format to an unsupported value
3676  options.m_FileFormat = "json";
3677  // Enable the profiling service
3678  armnn::profiling::ProfilingService profilingService;
3679  profilingService.ResetExternalProfilingOptions(options, true);
3680  // Start the command handler and the send thread
3681  profilingService.Update();
3682  CHECK(profilingService.GetCurrentState()==ProfilingState::NotConnected);
3683 
3684  // Redirect the output to a local stream so that we can parse the warning message
3685  std::stringstream ss;
3686  StreamRedirector streamRedirector(std::cout, ss.rdbuf());
3687 
3688  // When Update is called and the current state is ProfilingState::NotConnected
3689  // an exception will be raised from GetProfilingConnection and displayed as warning in the output local stream
3690  profilingService.Update();
3691 
3692  streamRedirector.CancelRedirect();
3693 
3694  // Check that the expected error has occurred and logged to the standard output
3695  if (ss.str().find("Unsupported profiling file format, only binary is supported") == std::string::npos)
3696  {
3697  std::cout << ss.str();
3698  FAIL("Expected string not found.");
3699  }
3700 }
3701 
3702 }
profiling::ProfilingService & GetProfilingService(armnn::RuntimeImpl *runtime)
Definition: TestUtils.cpp:35
const Category * RegisterCategory(const std::string &categoryName) override
const Counter * RegisterCounter(const BackendId &backendId, const uint16_t uid, const std::string &parentCategoryName, uint16_t counterClass, uint16_t interpolation, double multiplier, const std::string &name, const std::string &description, const Optional< std::string > &units=EmptyOptional(), const Optional< uint16_t > &numberOfCores=EmptyOptional(), const Optional< uint16_t > &deviceUid=EmptyOptional(), const Optional< uint16_t > &counterSetUid=EmptyOptional()) override
uint32_t GetAbsoluteCounterValue(uint16_t counterUid) const override
uint16_t GetGlobalId(uint16_t backendCounterId, const armnn::BackendId &backendId) const override
CPU Execution: Reference C++ kernels.
ProfilingState GetCurrentState() const
void WriteUint16(const IPacketBufferPtr &packetBuffer, unsigned int offset, uint16_t value)
const std::vector< uint16_t > & GetCounterIds() const
Definition: Holder.cpp:49
void WriteUint32(const IPacketBufferPtr &packetBuffer, unsigned int offset, uint32_t value)
uint64_t ReadUint64(const IPacketBufferPtr &packetBuffer, unsigned int offset)
std::unordered_map< uint16_t, CounterPtr > Counters
virtual const Category * GetCategory(const std::string &name) const =0
uint32_t GetStreamMetaDataPacketSize()
const Device * GetDevice(uint16_t uid) const override
virtual uint16_t GetCounterCount() const =0
uint16_t ReadUint16(const IPacketBufferPtr &packetBuffer, unsigned int offset)
uint16_t GetCounterSetCount() const override
virtual const Device * GetDevice(uint16_t uid) const =0
Copyright (c) 2021 ARM Limited and Contributors.
void IgnoreUnused(Ts &&...)
virtual const CounterSets & GetCounterSets() const =0
void RegisterMapping(uint16_t globalCounterId, uint16_t backendCounterId, const armnn::BackendId &backendId) override
uint32_t IncrementCounterValue(uint16_t counterUid) override
virtual const Categories & GetCategories() const =0
uint16_t GetNextUid(bool peekOnly)
uint16_t GetDeviceCount() const override
uint32_t GetDeltaCounterValue(uint16_t counterUid) override
uint32_t SubtractCounterValue(uint16_t counterUid, uint32_t value) override
virtual const Devices & GetDevices() const =0
void SetCapturePeriod(uint32_t capturePeriod)
Definition: Holder.cpp:29
uint32_t GetCapturePeriod() const
Definition: Holder.cpp:44
virtual const CounterSet * GetCounterSet(uint16_t uid) const =0
CaptureData GetCaptureData() const
Definition: Holder.cpp:54
void SetCaptureData(uint32_t capturePeriod, const std::vector< uint16_t > &counterIds, const std::set< BackendId > &activeBackends)
void SetCaptureData(uint32_t capturePeriod, const std::vector< uint16_t > &counterIds, const std::set< armnn::BackendId > &activeBackends)
Definition: Holder.cpp:74
std::vector< uint16_t > GetNextCounterUids(uint16_t firstUid, uint16_t cores)
const std::pair< uint16_t, armnn::BackendId > & GetBackendId(uint16_t globalCounterId) const override
const ICounterMappings & GetCounterMappings() const override
void ResetExternalProfilingOptions(const ExternalProfilingOptions &options, bool resetProfilingService=false)
uint32_t ConstructHeader(uint32_t packetFamily, uint32_t packetId)
uint32_t AddCounterValue(uint16_t counterUid, uint32_t value) override
uint16_t GetCategoryCount() const override
constexpr unsigned int LOWEST_CAPTURE_PERIOD
The lowest performance data capture interval we support is 10 milliseconds.
Definition: Types.hpp:21
std::vector< CounterValue > counterValues
IPacketBufferPtr GetReadableBuffer() override
uint32_t ReadUint32(const IPacketBufferPtr &packetBuffer, unsigned int offset)
std::vector< uint16_t > m_Counters
void SetCounterValue(uint16_t counterUid, uint32_t value) override
const Counter * GetCounter(uint16_t uid) const override
const Category * GetCategory(const std::string &name) const override
EmptyOptional is used to initialize the Optional class in case we want to have default value for an O...
Definition: Optional.hpp:32
const Device * RegisterDevice(const std::string &deviceName, uint16_t cores=0, const Optional< std::string > &parentCategoryName=EmptyOptional()) override
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
CPU Execution: NEON: ArmCompute.
std::enable_if_t< std::is_unsigned< Source >::value &&std::is_unsigned< Dest >::value, Dest > numeric_cast(Source source)
Definition: NumericCast.hpp:35
void SetCounterIds(const std::vector< uint16_t > &counterIds)
Definition: Holder.cpp:34
bool WritePacket(const unsigned char *buffer, uint32_t length) override
uint16_t GetCounterCount() const override
void TransitionToState(ProfilingState newState)
virtual const Counters & GetCounters() const =0
const CounterSet * RegisterCounterSet(const std::string &counterSetName, uint16_t count=0, const Optional< std::string > &parentCategoryName=EmptyOptional()) override
virtual const std::pair< uint16_t, armnn::BackendId > & GetBackendId(uint16_t globalCounterId) const =0
const CounterSet * GetCounterSet(uint16_t uid) const override
ExternalProfilingOptions m_ProfilingOptions
Definition: IRuntime.hpp:160
const ICounterDirectory & GetCounterDirectory() const
ProfilingState ConfigureProfilingService(const ExternalProfilingOptions &options, bool resetProfilingService=false)
virtual uint16_t GetGlobalId(uint16_t backendCounterId, const armnn::BackendId &backendId) const =0