ArmNN
 20.02
ProfilingTests.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
#include "ProfilingMocks.hpp"

#include <armnn/Logging.hpp>
#include <armnn/utility/IgnoreUnused.hpp>

#include <IProfilingConnection.hpp>
#include <ProfilingService.hpp>

#include <boost/polymorphic_cast.hpp>
#include <boost/test/unit_test.hpp>

#include <atomic>
#include <chrono>
#include <cstdint>
#include <memory>
#include <ostream>
#include <streambuf>
#include <thread>
21 
22 namespace armnn
23 {
24 
25 namespace profiling
26 {
27 
29 {
30 public:
32  {
33  // Set the new log level
34  armnn::ConfigureLogging(true, true, severity);
35  }
37  {
38  // The default log level for unit tests is "Fatal"
40  }
41 };
42 
/// Scope guard that temporarily redirects an output stream into a different stream buffer.
///
/// The stream's original buffer is saved on construction and put back either by an explicit call to
/// CancelRedirect() or, at the latest, by the destructor. Neither the stream nor either buffer is owned:
/// the caller must keep them alive for as long as this object exists.
struct StreamRedirector
{
public:
    /// @param stream          Stream whose output is redirected (e.g. std::cout).
    /// @param newStreamBuffer Buffer that receives the stream's output until the redirect is cancelled.
    StreamRedirector(std::ostream& stream, std::streambuf* newStreamBuffer)
        : m_Stream(stream)
        , m_BackupBuffer(m_Stream.rdbuf(newStreamBuffer))
        , m_Redirected(true)
    {}

    ~StreamRedirector() { CancelRedirect(); }

    // Copies would independently restore the same backup buffer.
    StreamRedirector(const StreamRedirector&) = delete;
    StreamRedirector& operator=(const StreamRedirector&) = delete;

    /// Restores the stream's original buffer. Only the first call has an effect.
    void CancelRedirect()
    {
        // Track the redirect with a flag rather than a null backup pointer: a stream may legitimately
        // start with no buffer, and that state must be restored too.
        if (m_Redirected)
        {
            m_Stream.rdbuf(m_BackupBuffer);
            m_BackupBuffer = nullptr;
            m_Redirected = false;
        }
    }

private:
    std::ostream& m_Stream;
    std::streambuf* m_BackupBuffer;
    bool m_Redirected;
};
67 
69 {
70 public:
71  TestProfilingConnectionBase() = default;
72  ~TestProfilingConnectionBase() = default;
73 
74  bool IsOpen() const override { return true; }
75 
76  void Close() override {}
77 
78  bool WritePacket(const unsigned char* buffer, uint32_t length) override
79  {
80  IgnoreUnused(buffer, length);
81 
82  return false;
83  }
84 
85  Packet ReadPacket(uint32_t timeout) override
86  {
87  // First time we're called return a connection ack packet. After that always timeout.
88  if (m_FirstCall)
89  {
90  m_FirstCall = false;
91  // Return connection acknowledged packet
92  return Packet(65536);
93  }
94  else
95  {
96  std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
97  throw armnn::TimeoutException("Simulate a timeout error\n");
98  }
99  }
100 
101  bool m_FirstCall = true;
102 };
103 
105 {
106 public:
108  : m_ReadRequests(0)
109  {}
110 
111  Packet ReadPacket(uint32_t timeout) override
112  {
113  // Return connection acknowledged packet after three timeouts
114  if (m_ReadRequests % 3 == 0)
115  {
116  std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
117  ++m_ReadRequests;
118  throw armnn::TimeoutException("Simulate a timeout error\n");
119  }
120 
121  return Packet(65536);
122  }
123 
125  {
126  return m_ReadRequests.load();
127  }
128 
129 private:
130  std::atomic<int> m_ReadRequests;
131 };
132 
134 {
135 public:
137  : m_ReadRequests(0)
138  {}
139 
140  Packet ReadPacket(uint32_t timeout) override
141  {
142  IgnoreUnused(timeout);
143  ++m_ReadRequests;
144  throw armnn::Exception("Simulate a non-timeout error");
145  }
146 
148  {
149  return m_ReadRequests.load();
150  }
151 
152 private:
153  std::atomic<int> m_ReadRequests;
154 };
155 
157 {
158 public:
159  Packet ReadPacket(uint32_t timeout) override
160  {
161  IgnoreUnused(timeout);
162  // Connection Acknowledged Packet header (word 0, word 1 is always zero):
163  // 26:31 [6] packet_family: Control Packet Family, value 0b000000
164  // 16:25 [10] packet_id: Packet identifier, value 0b0000000001
165  // 8:15 [8] reserved: Reserved, value 0b00000000
166  // 0:7 [8] reserved: Reserved, value 0b00000000
167  uint32_t packetFamily = 0;
168  uint32_t packetId = 37; // Wrong packet id!!!
169  uint32_t header = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
170 
171  return Packet(header);
172  }
173 };
174 
176 {
177 public:
179 
180  int GetCount() { return m_Count; }
181 
182  void operator()(const Packet& packet) override
183  {
184  IgnoreUnused(packet);
185  m_Count++;
186  }
187 
188 private:
189  int m_Count = 0;
190 };
191 
193 {
194  using TestFunctorA::TestFunctorA;
195 };
196 
198 {
199  using TestFunctorA::TestFunctorA;
200 };
201 
203 {
204 public:
205  using MockProfilingConnectionFactoryPtr = std::unique_ptr<MockProfilingConnectionFactory>;
206 
208  : ProfilingService()
209  , m_MockProfilingConnectionFactory(new MockProfilingConnectionFactory())
210  , m_BackupProfilingConnectionFactory(nullptr)
211  {
212  BOOST_CHECK(m_MockProfilingConnectionFactory);
213  SwapProfilingConnectionFactory(ProfilingService::Instance(),
214  m_MockProfilingConnectionFactory.get(),
215  m_BackupProfilingConnectionFactory);
216  BOOST_CHECK(m_BackupProfilingConnectionFactory);
217  }
219  {
220  BOOST_CHECK(m_BackupProfilingConnectionFactory);
221  IProfilingConnectionFactory* temp = nullptr;
222  SwapProfilingConnectionFactory(ProfilingService::Instance(),
223  m_BackupProfilingConnectionFactory,
224  temp);
225  }
226 
228  {
229  IProfilingConnection* profilingConnection = GetProfilingConnection(ProfilingService::Instance());
230  return boost::polymorphic_downcast<MockProfilingConnection*>(profilingConnection);
231  }
232 
234  {
235  TransitionToState(ProfilingService::Instance(), newState);
236  }
237 
238  long WaitForPacketsSent(MockProfilingConnection* mockProfilingConnection,
240  uint32_t length = 0,
241  uint32_t timeout = 1000)
242  {
243  long packetCount = mockProfilingConnection->CheckForPacket({packetType, length});
244  // The first packet we receive may not be the one we are looking for, so keep looping until till we find it,
245  // or until WaitForPacketsSent times out
246  while(packetCount == 0 && timeout != 0)
247  {
248  std::chrono::steady_clock::time_point start = std::chrono::steady_clock::now();
249  // Wait for a notification from the send thread
251 
252  std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
253 
254  // We need to make sure the timeout does not reset each time we call WaitForPacketsSent
255  uint32_t elapsedTime = static_cast<uint32_t>(
256  std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count());
257 
258  packetCount = mockProfilingConnection->CheckForPacket({packetType, length});
259 
260  if (elapsedTime > timeout)
261  {
262  break;
263  }
264 
265  timeout -= elapsedTime;
266  }
267  return packetCount;
268  }
269 
270 private:
271  MockProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory;
272  IProfilingConnectionFactory* m_BackupProfilingConnectionFactory;
273 };
274 
275 } // namespace profiling
276 
277 } // namespace armnn
Packet ReadPacket(uint32_t timeout) override
LogLevelSwapper(armnn::LogSeverity severity)
static ProfilingService & Instance()
long WaitForPacketsSent(MockProfilingConnection *mockProfilingConnection, MockProfilingConnection::PacketType packetType, uint32_t length=0, uint32_t timeout=1000)
void operator()(const Packet &packet) override
StreamRedirector(std::ostream &stream, std::streambuf *newStreamBuffer)
void ConfigureLogging(bool printToStandardOutput, bool printToDebugOutput, LogSeverity severity)
Configures the logging behaviour of the ARMNN library.
Definition: Utils.cpp:10
Copyright (c) 2020 ARM Limited.
void IgnoreUnused(Ts &&...)
BOOST_CHECK(profilingService.GetCurrentState()==ProfilingState::WaitingForAck)
bool WaitForPacketSent(ProfilingService &instance, uint32_t timeout=1000)
CommandHandlerFunctor(uint32_t familyId, uint32_t packetId, uint32_t version)
bool WritePacket(const unsigned char *buffer, uint32_t length) override
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
long CheckForPacket(const std::pair< PacketType, uint32_t > packetInfo)
LogSeverity
Definition: Utils.hpp:12
std::unique_ptr< MockProfilingConnectionFactory > MockProfilingConnectionFactoryPtr