ArmNN
 20.02
ProfilingService.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ProfilingService.hpp"
7 
8 #include <armnn/BackendId.hpp>
9 #include <armnn/Logging.hpp>
10 #include <common/include/SocketConnectionException.hpp>
11 
12 #include <boost/format.hpp>
13 
14 namespace armnn
15 {
16 
17 namespace profiling
18 {
19 
21  bool resetProfilingService)
22 {
// Stores the given external profiling options and, when resetProfilingService is true, resets the whole
// service through Reset(): stops it, clears all counter data and resets the profiling state machine.
23  // Update the profiling options
24  m_Options = options;
25 
26  // Check if the profiling service needs to be reset
27  if (resetProfilingService)
28  {
29  // Reset the profiling service
30  Reset();
31  }
32 }
33 
35 {
// Reports the m_EnableProfiling flag of the options most recently stored by ResetExternalProfilingOptions().
36  return m_Options.m_EnableProfiling;
37 }
38 
41  bool resetProfilingService)
42 {
// Applies the given options via ResetExternalProfilingOptions() and then drives the state machine:
// - profiling enabled: calls Update() one to three times to try to reach WaitingForAck and send the
//   metadata packet, depending on the current state;
// - profiling disabled: stops the service unless it is in one of the states handled by the first case group.
// Returns the profiling state after these steps.
43  ResetExternalProfilingOptions(options, resetProfilingService);
44  ProfilingState currentState = m_StateMachine.GetCurrentState();
45  if (options.m_EnableProfiling)
46  {
47  switch (currentState)
48  {
// NOTE(review): the case label (original line 49) is missing from this listing; the first comment below
// suggests this branch handles ProfilingState::Uninitialised — confirm.
50  Update(); // should transition to NotConnected
51  Update(); // will either stay in NotConnected because there is no server
52  // or will enter WaitingForAck.
53  currentState = m_StateMachine.GetCurrentState();
54  if (currentState == ProfilingState::WaitingForAck)
55  {
56  Update(); // poke it again to send out the metadata packet
57  }
58  currentState = m_StateMachine.GetCurrentState();
59  return currentState;
// NOTE(review): the case label (original line 60) is missing; presumably ProfilingState::NotConnected.
61  Update(); // will either stay in NotConnected because there is no server
62  // or will enter WaitingForAck
63  currentState = m_StateMachine.GetCurrentState();
64  if (currentState == ProfilingState::WaitingForAck)
65  {
66  Update(); // poke it again to send out the metadata packet
67  }
68  currentState = m_StateMachine.GetCurrentState();
69  return currentState;
70  default:
// Any other state is left as it is.
71  return currentState;
72  }
73  }
74  else
75  {
76  // Make sure profiling is shutdown
77  switch (currentState)
78  {
// NOTE(review): the case labels (original lines 79-80) are missing; presumably the states in which nothing
// is running (Uninitialised / NotConnected), which need no shutdown.
81  return currentState;
82  default:
// Stop() tears down threads and the connection; report the state it leaves behind.
83  Stop();
84  return m_StateMachine.GetCurrentState();
85  }
86  }
87 }
88 
90 {
// Advances the profiling state machine by one step. Does nothing when profiling is disabled in the stored
// options; otherwise performs the work for the current state (see the per-branch notes below) and throws
// for a state it does not recognise.
91  if (!m_Options.m_EnableProfiling)
92  {
93  // Don't run if profiling is disabled
94  return;
95  }
96 
97  ProfilingState currentState = m_StateMachine.GetCurrentState();
98  switch (currentState)
99  {
// NOTE(review): the case label (original line 100) is missing from this listing; the Initialize() call
// suggests this branch handles ProfilingState::Uninitialised — confirm.
101 
102  // Initialize the profiling service
103  Initialize();
104 
105  // Move to the next state
// NOTE(review): the transition statement (original line 106) is missing from this listing.
107  break;
// NOTE(review): the case label (original line 108) is missing; presumably ProfilingState::NotConnected.
// This branch tears down any running threads and the previous connection, asks the connection factory for a
// new connection, and moves to WaitingForAck on success or NotConnected on failure.
109  // Stop the command thread (if running)
110  m_CommandHandler.Stop();
111 
112  // Stop the send thread (if running)
113  m_SendThread.Stop(false);
114 
115  // Stop the periodic counter capture thread (if running)
116  m_PeriodicCounterCapture.Stop();
117 
118  // Reset any existing profiling connection
119  m_ProfilingConnection.reset();
120 
121  try
122  {
123  // Setup the profiling connection
124  BOOST_ASSERT(m_ProfilingConnectionFactory);
125  m_ProfilingConnection = m_ProfilingConnectionFactory->GetProfilingConnection(m_Options);
126  }
// Only the two exception types below are logged and swallowed (leaving m_ProfilingConnection empty); any
// other exception propagates to the caller.
// NOTE(review): if armnnProfiling::SocketConnectionException derives from armnn::Exception, the second
// handler is unreachable because the first one always matches first — check the exception hierarchy.
127  catch (const Exception& e)
128  {
129  ARMNN_LOG(warning) << "An error has occurred when creating the profiling connection: "
130  << e.what();
131  }
132  catch (const armnnProfiling::SocketConnectionException& e)
133  {
134  ARMNN_LOG(warning) << "An error has occurred when creating the profiling connection ["
135  << e.what() << "] on socket [" << e.GetSocketFd() << "].";
136  }
137 
138  // Move to the next state
139  m_StateMachine.TransitionToState(m_ProfilingConnection
140  ? ProfilingState::WaitingForAck // Profiling connection obtained, wait for ack
141  : ProfilingState::NotConnected); // Profiling connection failed, stay in the
142  // "NotConnected" state
143  break;
// NOTE(review): the case label (original line 144) is missing; the send-thread comment below indicates this
// branch runs in the WaitingForAck state. It requires a connection obtained by the previous branch.
145  BOOST_ASSERT(m_ProfilingConnection);
146 
147  // Start the command thread
148  m_CommandHandler.Start(*m_ProfilingConnection);
149 
150  // Start the send thread, while in "WaitingForAck" state it'll send out a "Stream MetaData" packet waiting for
151  // a valid "Connection Acknowledged" packet confirming the connection
152  m_SendThread.Start(*m_ProfilingConnection);
153 
154  // The connection acknowledged command handler will automatically transition the state to "Active" once a
155  // valid "Connection Acknowledged" packet has been received
156 
157  break;
// NOTE(review): the case label (original line 158) is missing; presumably ProfilingState::Active. This branch
// intentionally does nothing.
159 
160  // The periodic counter capture thread is started by the Periodic Counter Selection command handler upon
161  // request by an external profiling service
162 
163  break;
164  default:
// NOTE(review): "%1" is not a complete Boost.Format positional directive (those are written "%1%", as in
// CheckCounterUid). boost::format is likely to throw boost::io::bad_format_string here instead of the
// intended RuntimeException — verify and change the literal to "%1%".
165  throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1")
166  % static_cast<int>(currentState)));
167  }
168 }
169 
171 {
// Disconnects from the external profiling service. In the states handled by the first case group this is a
// no-op; in the states handled by the second group the whole service is stopped through Stop(). Any other
// state is reported as an error.
172  ProfilingState currentState = m_StateMachine.GetCurrentState();
173  switch (currentState)
174  {
// NOTE(review): the case labels for the no-op branch (original lines 175-177) and for the Stop() branch
// (original line 179) are missing from this listing.
178  return; // NOP
180  // Stop the profiling threads and close the profiling connection (if running)
181  Stop();
182 
183  break;
184  default:
// NOTE(review): "%1" is not a complete Boost.Format positional directive ("%1%"); boost::format is likely to
// throw boost::io::bad_format_string here instead of the intended RuntimeException — verify.
185  throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1")
186  % static_cast<int>(currentState)));
187  }
188 }
189 
190 // Store a profiling context returned from a backend that supports profiling, and register its counters
192  std::shared_ptr<armnn::profiling::IBackendProfilingContext> profilingContext)
193 {
// The context is asked to register its counters starting after m_MaxGlobalCounterId; the value it returns
// becomes the new maximum, so the next backend's counters are numbered after these.
// NOTE(review): the null check is a BOOST_ASSERT only, so a null context is dereferenced in builds where
// asserts are disabled.
194  BOOST_ASSERT(profilingContext != nullptr);
195  // Register the backend counters
196  m_MaxGlobalCounterId = profilingContext->RegisterCounters(m_MaxGlobalCounterId);
// NOTE(review): if m_BackendProfilingContexts is a map keyed by BackendId, emplace() does not replace an
// existing entry, so a second context for the same backend would be dropped after its counters were
// registered — confirm the intended behaviour.
197  m_BackendProfilingContexts.emplace(backendId, std::move(profilingContext));
198 }
200 {
// Read-only access to the service's counter directory (m_CounterDirectory).
201  return m_CounterDirectory;
202 }
203 
205 {
// Exposes the same m_CounterDirectory as the accessor above, through a different accessor whose signature line
// is not visible in this listing (presumably the mutable registry view — confirm).
206  return m_CounterDirectory;
207 }
208 
210 {
// Current state of the profiling state machine.
211  return m_StateMachine.GetCurrentState();
212 }
213 
215 {
// Number of counters currently held by the counter directory.
216  return m_CounterDirectory.GetCounterCount();
217 }
218 
219 bool ProfilingService::IsCounterRegistered(uint16_t counterUid) const
220 {
221  return counterUid < m_CounterIndex.size();
222 }
223 
224 uint32_t ProfilingService::GetCounterValue(uint16_t counterUid) const
225 {
226  CheckCounterUid(counterUid);
227  std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
228  BOOST_ASSERT(counterValuePtr);
229  return counterValuePtr->load(std::memory_order::memory_order_relaxed);
230 }
231 
233 {
// Read-only view of the counter ID mappings (m_CounterIdMap).
234  return m_CounterIdMap;
235 }
236 
238 {
// Same m_CounterIdMap object, exposed for registering new mappings.
239  return m_CounterIdMap;
240 }
241 
243 {
// Returns the capture configuration currently stored in m_Holder.
244  return m_Holder.GetCaptureData();
245 }
246 
247 void ProfilingService::SetCaptureData(uint32_t capturePeriod,
248  const std::vector<uint16_t>& counterIds,
249  const std::set<BackendId>& activeBackends)
250 {
251  m_Holder.SetCaptureData(capturePeriod, counterIds, activeBackends);
252 }
253 
254 void ProfilingService::SetCounterValue(uint16_t counterUid, uint32_t value)
255 {
256  CheckCounterUid(counterUid);
257  std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
258  BOOST_ASSERT(counterValuePtr);
259  counterValuePtr->store(value, std::memory_order::memory_order_relaxed);
260 }
261 
262 uint32_t ProfilingService::AddCounterValue(uint16_t counterUid, uint32_t value)
263 {
264  CheckCounterUid(counterUid);
265  std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
266  BOOST_ASSERT(counterValuePtr);
267  return counterValuePtr->fetch_add(value, std::memory_order::memory_order_relaxed);
268 }
269 
270 uint32_t ProfilingService::SubtractCounterValue(uint16_t counterUid, uint32_t value)
271 {
272  CheckCounterUid(counterUid);
273  std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
274  BOOST_ASSERT(counterValuePtr);
275  return counterValuePtr->fetch_sub(value, std::memory_order::memory_order_relaxed);
276 }
277 
278 uint32_t ProfilingService::IncrementCounterValue(uint16_t counterUid)
279 {
280  CheckCounterUid(counterUid);
281  std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
282  BOOST_ASSERT(counterValuePtr);
283  return counterValuePtr->operator++(std::memory_order::memory_order_relaxed);
284 }
285 
287 {
// Delegates to the GUID generator: returns the next dynamic GUID in its sequence.
288  return m_GuidGenerator.NextGuid();
289 }
290 
292 {
// Delegates to the GUID generator: creates a static GUID derived from a hash of str.
293  return m_GuidGenerator.GenerateStaticId(str);
294 }
295 
296 std::unique_ptr<ISendTimelinePacket> ProfilingService::GetSendTimelinePacket() const
297 {
298  return m_TimelinePacketWriterFactory.GetSendTimelinePacket();
299 }
300 
301 void ProfilingService::Initialize()
302 {
303  // Register a category for the basic runtime counters
304  if (!m_CounterDirectory.IsCategoryRegistered("ArmNN_Runtime"))
305  {
306  m_CounterDirectory.RegisterCategory("ArmNN_Runtime");
307  }
308 
309  // Register a counter for the number of Network loads
310  if (!m_CounterDirectory.IsCounterRegistered("Network loads"))
311  {
312  const Counter* loadedNetworksCounter =
313  m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
314  armnn::profiling::NETWORK_LOADS,
315  "ArmNN_Runtime",
316  0,
317  0,
318  1.f,
319  "Network loads",
320  "The number of networks loaded at runtime",
321  std::string("networks"));
322  BOOST_ASSERT(loadedNetworksCounter);
323  InitializeCounterValue(loadedNetworksCounter->m_Uid);
324  }
325  // Register a counter for the number of unloaded networks
326  if (!m_CounterDirectory.IsCounterRegistered("Network unloads"))
327  {
328  const Counter* unloadedNetworksCounter =
329  m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
330  armnn::profiling::NETWORK_UNLOADS,
331  "ArmNN_Runtime",
332  0,
333  0,
334  1.f,
335  "Network unloads",
336  "The number of networks unloaded at runtime",
337  std::string("networks"));
338  BOOST_ASSERT(unloadedNetworksCounter);
339  InitializeCounterValue(unloadedNetworksCounter->m_Uid);
340  }
341  // Register a counter for the number of registered backends
342  if (!m_CounterDirectory.IsCounterRegistered("Backends registered"))
343  {
344  const Counter* registeredBackendsCounter =
345  m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
346  armnn::profiling::REGISTERED_BACKENDS,
347  "ArmNN_Runtime",
348  0,
349  0,
350  1.f,
351  "Backends registered",
352  "The number of registered backends",
353  std::string("backends"));
354  BOOST_ASSERT(registeredBackendsCounter);
355  InitializeCounterValue(registeredBackendsCounter->m_Uid);
356  }
357  // Register a counter for the number of registered backends
358  if (!m_CounterDirectory.IsCounterRegistered("Backends unregistered"))
359  {
360  const Counter* unregisteredBackendsCounter =
361  m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
362  armnn::profiling::UNREGISTERED_BACKENDS,
363  "ArmNN_Runtime",
364  0,
365  0,
366  1.f,
367  "Backends unregistered",
368  "The number of unregistered backends",
369  std::string("backends"));
370  BOOST_ASSERT(unregisteredBackendsCounter);
371  InitializeCounterValue(unregisteredBackendsCounter->m_Uid);
372  }
373  // Register a counter for the number of inferences run
374  if (!m_CounterDirectory.IsCounterRegistered("Inferences run"))
375  {
376  const Counter* inferencesRunCounter =
377  m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
378  armnn::profiling::INFERENCES_RUN,
379  "ArmNN_Runtime",
380  0,
381  0,
382  1.f,
383  "Inferences run",
384  "The number of inferences run",
385  std::string("inferences"));
386  BOOST_ASSERT(inferencesRunCounter);
387  InitializeCounterValue(inferencesRunCounter->m_Uid);
388  }
389 }
390 
391 void ProfilingService::InitializeCounterValue(uint16_t counterUid)
392 {
393  // Increase the size of the counter index if necessary
394  if (counterUid >= m_CounterIndex.size())
395  {
396  m_CounterIndex.resize(boost::numeric_cast<size_t>(counterUid) + 1);
397  }
398 
399  // Create a new atomic counter and add it to the list
400  m_CounterValues.emplace_back(0);
401 
402  // Register the new counter to the counter index for quick access
403  std::atomic<uint32_t>* counterValuePtr = &(m_CounterValues.back());
404  m_CounterIndex.at(counterUid) = counterValuePtr;
405 }
406 
407 void ProfilingService::Reset()
408 {
// Returns the service to a pristine state: stops it, discards every counter value, the counter directory,
// the counter ID mappings and buffered data, resets the state machine, and forgets all backend profiling
// contexts. The statement order matters: Stop() must run before the data the threads use is cleared.
409  // Stop the profiling service...
410  Stop();
411 
412  // ...then delete all the counter data and configuration...
413  m_CounterIndex.clear();
414  m_CounterValues.clear();
415  m_CounterDirectory.Clear();
416  m_CounterIdMap.Reset();
417  m_BufferManager.Reset();
418 
419  // ...then reset the profiling state machine and drop all per-backend profiling state
420  m_StateMachine.Reset();
421  m_BackendProfilingContexts.clear();
// Backend counter UIDs are allocated after this value (see AddBackendProfilingContext); INFERENCES_RUN is
// presumably the highest UID reserved for the built-in runtime counters registered by Initialize() — confirm.
422  m_MaxGlobalCounterId = armnn::profiling::INFERENCES_RUN;
423 }
424 
425 void ProfilingService::Stop()
426 {
427  // The order in which we reset/stop the components is not trivial!
428  // First stop the producing threads
429  // Command Handler first as it is responsible for launching then Periodic Counter capture thread
430  m_CommandHandler.Stop();
431  m_PeriodicCounterCapture.Stop();
432  // The the consuming thread
433  m_SendThread.Stop(false);
434 
435  // ...then close and destroy the profiling connection...
436  if (m_ProfilingConnection != nullptr && m_ProfilingConnection->IsOpen())
437  {
438  m_ProfilingConnection->Close();
439  }
440  m_ProfilingConnection.reset();
441 
442  // ...then move to the "NotConnected" state
444 }
445 
446 inline void ProfilingService::CheckCounterUid(uint16_t counterUid) const
447 {
448  if (!IsCounterRegistered(counterUid))
449  {
450  throw InvalidArgumentException(boost::str(boost::format("Counter UID %1% is not registered") % counterUid));
451  }
452 }
453 
455 {
// Stop threads and close the connection in the destructor body, i.e. before any member is destroyed, so no
// thread is still running against members being torn down.
456  Stop();
457 }
458 
459 } // namespace profiling
460 
461 } // namespace armnn
bool IsCounterRegistered(uint16_t counterUid) const override
bool IsCategoryRegistered(const std::string &categoryName) const
const Category * RegisterCategory(const std::string &categoryName) override
const Counter * RegisterCounter(const BackendId &backendId, const uint16_t uid, const std::string &parentCategoryName, uint16_t counterClass, uint16_t interpolation, double multiplier, const std::string &name, const std::string &description, const Optional< std::string > &units=EmptyOptional(), const Optional< uint16_t > &numberOfCores=EmptyOptional(), const Optional< uint16_t > &deviceUid=EmptyOptional(), const Optional< uint16_t > &counterSetUid=EmptyOptional()) override
ProfilingState GetCurrentState() const
void Start(IProfilingConnection &profilingConnection)
uint32_t GetCounterValue(uint16_t counterUid) const override
Strongly typed guids to distinguish between those generated at runtime, and those that are statically...
Definition: Types.hpp:294
virtual const char * what() const noexcept override
Definition: Exceptions.cpp:32
#define ARMNN_LOG(severity)
Definition: Logging.hpp:163
Copyright (c) 2020 ARM Limited.
ProfilingStaticGuid GenerateStaticId(const std::string &str) override
Create a ProfilingStaticGuid based on a hash of the string.
ProfilingDynamicGuid NextGuid() override
Return the next random Guid in the sequence.
uint32_t IncrementCounterValue(uint16_t counterUid) override
uint32_t SubtractCounterValue(uint16_t counterUid, uint32_t value) override
CaptureData GetCaptureData() const
Definition: Holder.cpp:54
void SetCaptureData(uint32_t capturePeriod, const std::vector< uint16_t > &counterIds, const std::set< BackendId > &activeBackends)
void SetCaptureData(uint32_t capturePeriod, const std::vector< uint16_t > &counterIds, const std::set< armnn::BackendId > &activeBackends)
Definition: Holder.cpp:74
std::unique_ptr< ISendTimelinePacket > GetSendTimelinePacket() const override
const ICounterMappings & GetCounterMappings() const override
void ResetExternalProfilingOptions(const ExternalProfilingOptions &options, bool resetProfilingService=false)
uint32_t AddCounterValue(uint16_t counterUid, uint32_t value) override
ProfilingDynamicGuid NextGuid() override
Return the next random Guid in the sequence.
IRegisterCounterMapping & GetCounterMappingRegistry()
void SetCounterValue(uint16_t counterUid, uint32_t value) override
bool IsProfilingEnabled() const override
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
bool IsCounterRegistered(uint16_t counterUid) const
uint16_t GetCounterCount() const override
void TransitionToState(ProfilingState newState)
void AddBackendProfilingContext(const BackendId backendId, std::shared_ptr< armnn::profiling::IBackendProfilingContext > profilingContext)
std::unique_ptr< ISendTimelinePacket > GetSendTimelinePacket() const
armnn::Runtime::CreationOptions::ExternalProfilingOptions options
const ICounterDirectory & GetCounterDirectory() const
ProfilingState ConfigureProfilingService(const ExternalProfilingOptions &options, bool resetProfilingService=false)
void Stop(bool rethrowSendThreadExceptions=true) override
Stop the thread.
Definition: SendThread.cpp:81
uint16_t GetCounterCount() const override
void Start(IProfilingConnection &profilingConnection) override
Start the thread.
Definition: SendThread.cpp:51
ProfilingStaticGuid GenerateStaticId(const std::string &str) override
Create a ProfilingStaticGuid based on a hash of the string.