ArmNN  NotReleased
ProfilingService.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ProfilingService.hpp"
7 
8 #include <armnn/BackendId.hpp>
9 #include <armnn/Logging.hpp>
10 
11 #include <boost/format.hpp>
12 
13 namespace armnn
14 {
15 
16 namespace profiling
17 {
18 
20  bool resetProfilingService)
21 {
22  // Update the profiling options
23  m_Options = options;
24 
25  // Check if the profiling service needs to be reset
26  if (resetProfilingService)
27  {
28  // Reset the profiling service
29  Reset();
30  }
31 }
32 
34 {
35  return m_Options.m_EnableProfiling;
36 }
37 
40  bool resetProfilingService)
41 {
42  ResetExternalProfilingOptions(options, resetProfilingService);
43  ProfilingState currentState = m_StateMachine.GetCurrentState();
44  if (options.m_EnableProfiling)
45  {
46  switch (currentState)
47  {
49  Update(); // should transition to NotConnected
50  Update(); // will either stay in NotConnected because there is no server
51  // or will enter WaitingForAck.
52  currentState = m_StateMachine.GetCurrentState();
53  if (currentState == ProfilingState::WaitingForAck)
54  {
55  Update(); // poke it again to send out the metadata packet
56  }
57  currentState = m_StateMachine.GetCurrentState();
58  return currentState;
60  Update(); // will either stay in NotConnected because there is no server
61  // or will enter WaitingForAck
62  currentState = m_StateMachine.GetCurrentState();
63  if (currentState == ProfilingState::WaitingForAck)
64  {
65  Update(); // poke it again to send out the metadata packet
66  }
67  currentState = m_StateMachine.GetCurrentState();
68  return currentState;
69  default:
70  return currentState;
71  }
72  }
73  else
74  {
75  // Make sure profiling is shutdown
76  switch (currentState)
77  {
80  return currentState;
81  default:
82  Stop();
83  return m_StateMachine.GetCurrentState();
84  }
85  }
86 }
87 
89 {
90  if (!m_Options.m_EnableProfiling)
91  {
92  // Don't run if profiling is disabled
93  return;
94  }
95 
96  ProfilingState currentState = m_StateMachine.GetCurrentState();
97  switch (currentState)
98  {
100 
101  // Initialize the profiling service
102  Initialize();
103 
104  // Move to the next state
106  break;
108  // Stop the command thread (if running)
109  m_CommandHandler.Stop();
110 
111  // Stop the send thread (if running)
112  m_SendThread.Stop(false);
113 
114  // Stop the periodic counter capture thread (if running)
115  m_PeriodicCounterCapture.Stop();
116 
117  // Reset any existing profiling connection
118  m_ProfilingConnection.reset();
119 
120  try
121  {
122  // Setup the profiling connection
123  BOOST_ASSERT(m_ProfilingConnectionFactory);
124  m_ProfilingConnection = m_ProfilingConnectionFactory->GetProfilingConnection(m_Options);
125  }
126  catch (const Exception& e)
127  {
128  ARMNN_LOG(warning) << "An error has occurred when creating the profiling connection: "
129  << e.what();
130  }
131 
132  // Move to the next state
133  m_StateMachine.TransitionToState(m_ProfilingConnection
134  ? ProfilingState::WaitingForAck // Profiling connection obtained, wait for ack
135  : ProfilingState::NotConnected); // Profiling connection failed, stay in the
136  // "NotConnected" state
137  break;
139  BOOST_ASSERT(m_ProfilingConnection);
140 
141  // Start the command thread
142  m_CommandHandler.Start(*m_ProfilingConnection);
143 
144  // Start the send thread, while in "WaitingForAck" state it'll send out a "Stream MetaData" packet waiting for
145  // a valid "Connection Acknowledged" packet confirming the connection
146  m_SendThread.Start(*m_ProfilingConnection);
147 
148  // The connection acknowledged command handler will automatically transition the state to "Active" once a
149  // valid "Connection Acknowledged" packet has been received
150 
151  break;
153 
154  // The period counter capture thread is started by the Periodic Counter Selection command handler upon
155  // request by an external profiling service
156 
157  break;
158  default:
159  throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1")
160  % static_cast<int>(currentState)));
161  }
162 }
163 
165 {
166  ProfilingState currentState = m_StateMachine.GetCurrentState();
167  switch (currentState)
168  {
172  return; // NOP
174  // Stop the command thread (if running)
175  Stop();
176 
177  break;
178  default:
179  throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1")
180  % static_cast<int>(currentState)));
181  }
182 }
183 
184 // Store a profiling context returned from a backend that support profiling, and register its counters
186  std::shared_ptr<armnn::profiling::IBackendProfilingContext> profilingContext)
187 {
188  BOOST_ASSERT(profilingContext != nullptr);
189  // Register the backend counters
190  m_MaxGlobalCounterId = profilingContext->RegisterCounters(m_MaxGlobalCounterId);
191  m_BackendProfilingContexts.emplace(backendId, std::move(profilingContext));
192 }
194 {
195  return m_CounterDirectory;
196 }
197 
199 {
200  return m_CounterDirectory;
201 }
202 
204 {
205  return m_StateMachine.GetCurrentState();
206 }
207 
209 {
210  return m_CounterDirectory.GetCounterCount();
211 }
212 
213 bool ProfilingService::IsCounterRegistered(uint16_t counterUid) const
214 {
215  return counterUid < m_CounterIndex.size();
216 }
217 
218 uint32_t ProfilingService::GetCounterValue(uint16_t counterUid) const
219 {
220  CheckCounterUid(counterUid);
221  std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
222  BOOST_ASSERT(counterValuePtr);
223  return counterValuePtr->load(std::memory_order::memory_order_relaxed);
224 }
225 
227 {
228  return m_CounterIdMap;
229 }
230 
232 {
233  return m_CounterIdMap;
234 }
235 
237 {
238  return m_Holder.GetCaptureData();
239 }
240 
241 void ProfilingService::SetCaptureData(uint32_t capturePeriod,
242  const std::vector<uint16_t>& counterIds,
243  const std::set<BackendId>& activeBackends)
244 {
245  m_Holder.SetCaptureData(capturePeriod, counterIds, activeBackends);
246 }
247 
248 void ProfilingService::SetCounterValue(uint16_t counterUid, uint32_t value)
249 {
250  CheckCounterUid(counterUid);
251  std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
252  BOOST_ASSERT(counterValuePtr);
253  counterValuePtr->store(value, std::memory_order::memory_order_relaxed);
254 }
255 
256 uint32_t ProfilingService::AddCounterValue(uint16_t counterUid, uint32_t value)
257 {
258  CheckCounterUid(counterUid);
259  std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
260  BOOST_ASSERT(counterValuePtr);
261  return counterValuePtr->fetch_add(value, std::memory_order::memory_order_relaxed);
262 }
263 
264 uint32_t ProfilingService::SubtractCounterValue(uint16_t counterUid, uint32_t value)
265 {
266  CheckCounterUid(counterUid);
267  std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
268  BOOST_ASSERT(counterValuePtr);
269  return counterValuePtr->fetch_sub(value, std::memory_order::memory_order_relaxed);
270 }
271 
272 uint32_t ProfilingService::IncrementCounterValue(uint16_t counterUid)
273 {
274  CheckCounterUid(counterUid);
275  std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
276  BOOST_ASSERT(counterValuePtr);
277  return counterValuePtr->operator++(std::memory_order::memory_order_relaxed);
278 }
279 
281 {
282  return m_GuidGenerator.NextGuid();
283 }
284 
286 {
287  return m_GuidGenerator.GenerateStaticId(str);
288 }
289 
290 std::unique_ptr<ISendTimelinePacket> ProfilingService::GetSendTimelinePacket() const
291 {
292  return m_TimelinePacketWriterFactory.GetSendTimelinePacket();
293 }
294 
295 void ProfilingService::Initialize()
296 {
297  // Register a category for the basic runtime counters
298  if (!m_CounterDirectory.IsCategoryRegistered("ArmNN_Runtime"))
299  {
300  m_CounterDirectory.RegisterCategory("ArmNN_Runtime");
301  }
302 
303  // Register a counter for the number of Network loads
304  if (!m_CounterDirectory.IsCounterRegistered("Network loads"))
305  {
306  const Counter* loadedNetworksCounter =
307  m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
308  armnn::profiling::NETWORK_LOADS,
309  "ArmNN_Runtime",
310  0,
311  0,
312  1.f,
313  "Network loads",
314  "The number of networks loaded at runtime",
315  std::string("networks"));
316  BOOST_ASSERT(loadedNetworksCounter);
317  InitializeCounterValue(loadedNetworksCounter->m_Uid);
318  }
319  // Register a counter for the number of unloaded networks
320  if (!m_CounterDirectory.IsCounterRegistered("Network unloads"))
321  {
322  const Counter* unloadedNetworksCounter =
323  m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
324  armnn::profiling::NETWORK_UNLOADS,
325  "ArmNN_Runtime",
326  0,
327  0,
328  1.f,
329  "Network unloads",
330  "The number of networks unloaded at runtime",
331  std::string("networks"));
332  BOOST_ASSERT(unloadedNetworksCounter);
333  InitializeCounterValue(unloadedNetworksCounter->m_Uid);
334  }
335  // Register a counter for the number of registered backends
336  if (!m_CounterDirectory.IsCounterRegistered("Backends registered"))
337  {
338  const Counter* registeredBackendsCounter =
339  m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
340  armnn::profiling::REGISTERED_BACKENDS,
341  "ArmNN_Runtime",
342  0,
343  0,
344  1.f,
345  "Backends registered",
346  "The number of registered backends",
347  std::string("backends"));
348  BOOST_ASSERT(registeredBackendsCounter);
349  InitializeCounterValue(registeredBackendsCounter->m_Uid);
350  }
351  // Register a counter for the number of registered backends
352  if (!m_CounterDirectory.IsCounterRegistered("Backends unregistered"))
353  {
354  const Counter* unregisteredBackendsCounter =
355  m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
356  armnn::profiling::UNREGISTERED_BACKENDS,
357  "ArmNN_Runtime",
358  0,
359  0,
360  1.f,
361  "Backends unregistered",
362  "The number of unregistered backends",
363  std::string("backends"));
364  BOOST_ASSERT(unregisteredBackendsCounter);
365  InitializeCounterValue(unregisteredBackendsCounter->m_Uid);
366  }
367  // Register a counter for the number of inferences run
368  if (!m_CounterDirectory.IsCounterRegistered("Inferences run"))
369  {
370  const Counter* inferencesRunCounter =
371  m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
372  armnn::profiling::INFERENCES_RUN,
373  "ArmNN_Runtime",
374  0,
375  0,
376  1.f,
377  "Inferences run",
378  "The number of inferences run",
379  std::string("inferences"));
380  BOOST_ASSERT(inferencesRunCounter);
381  InitializeCounterValue(inferencesRunCounter->m_Uid);
382  }
383 }
384 
385 void ProfilingService::InitializeCounterValue(uint16_t counterUid)
386 {
387  // Increase the size of the counter index if necessary
388  if (counterUid >= m_CounterIndex.size())
389  {
390  m_CounterIndex.resize(boost::numeric_cast<size_t>(counterUid) + 1);
391  }
392 
393  // Create a new atomic counter and add it to the list
394  m_CounterValues.emplace_back(0);
395 
396  // Register the new counter to the counter index for quick access
397  std::atomic<uint32_t>* counterValuePtr = &(m_CounterValues.back());
398  m_CounterIndex.at(counterUid) = counterValuePtr;
399 }
400 
401 void ProfilingService::Reset()
402 {
403  // Stop the profiling service...
404  Stop();
405 
406  // ...then delete all the counter data and configuration...
407  m_CounterIndex.clear();
408  m_CounterValues.clear();
409  m_CounterDirectory.Clear();
410  m_CounterIdMap.Reset();
411  m_BufferManager.Reset();
412 
413  // ...finally reset the profiling state machine
414  m_StateMachine.Reset();
415  m_BackendProfilingContexts.clear();
416  m_MaxGlobalCounterId = armnn::profiling::INFERENCES_RUN;
417 }
418 
419 void ProfilingService::Stop()
420 {
421  // The order in which we reset/stop the components is not trivial!
422  // First stop the producing threads
423  // Command Handler first as it is responsible for launching then Periodic Counter capture thread
424  m_CommandHandler.Stop();
425  m_PeriodicCounterCapture.Stop();
426  // The the consuming thread
427  m_SendThread.Stop(false);
428 
429  // ...then close and destroy the profiling connection...
430  if (m_ProfilingConnection != nullptr && m_ProfilingConnection->IsOpen())
431  {
432  m_ProfilingConnection->Close();
433  }
434  m_ProfilingConnection.reset();
435 
436  // ...then move to the "NotConnected" state
438 }
439 
440 inline void ProfilingService::CheckCounterUid(uint16_t counterUid) const
441 {
442  if (!IsCounterRegistered(counterUid))
443  {
444  throw InvalidArgumentException(boost::str(boost::format("Counter UID %1% is not registered") % counterUid));
445  }
446 }
447 
449 {
450  Stop();
451 }
452 
453 } // namespace profiling
454 
455 } // namespace armnn
void Start(IProfilingConnection &profilingConnection) override
Start the thread.
Definition: SendThread.cpp:52
uint32_t GetCounterValue(uint16_t counterUid) const override
uint32_t SubtractCounterValue(uint16_t counterUid, uint32_t value) override
uint32_t AddCounterValue(uint16_t counterUid, uint32_t value) override
const Category * RegisterCategory(const std::string &categoryName, const Optional< uint16_t > &deviceUid=EmptyOptional(), const Optional< uint16_t > &counterSetUid=EmptyOptional()) override
const ICounterMappings & GetCounterMappings() const override
ProfilingStaticGuid GenerateStaticId(const std::string &str) override
Create a ProfilingStaticGuid based on a hash of the string.
CaptureData GetCaptureData() const
Definition: Holder.cpp:54
void SetCaptureData(uint32_t capturePeriod, const std::vector< uint16_t > &counterIds, const std::set< armnn::BackendId > &activeBackends)
Definition: Holder.cpp:74
bool IsCategoryRegistered(const std::string &categoryName) const
IRegisterCounterMapping & GetCounterMappingRegistry()
ProfilingStaticGuid GenerateStaticId(const std::string &str) override
Create a ProfilingStaticGuid based on a hash of the string.
std::unique_ptr< ISendTimelinePacket > GetSendTimelinePacket() const
void TransitionToState(ProfilingState newState)
const Counter * RegisterCounter(const BackendId &backendId, const uint16_t uid, const std::string &parentCategoryName, uint16_t counterClass, uint16_t interpolation, double multiplier, const std::string &name, const std::string &description, const Optional< std::string > &units=EmptyOptional(), const Optional< uint16_t > &numberOfCores=EmptyOptional(), const Optional< uint16_t > &deviceUid=EmptyOptional(), const Optional< uint16_t > &counterSetUid=EmptyOptional()) override
ProfilingDynamicGuid NextGuid() override
Return the next random Guid in the sequence.
void Stop(bool rethrowSendThreadExceptions=true) override
Stop the thread.
Definition: SendThread.cpp:82
#define ARMNN_LOG(severity)
Definition: Logging.hpp:163
ProfilingDynamicGuid NextGuid() override
Return the next random Guid in the sequence.
void SetCounterValue(uint16_t counterUid, uint32_t value) override
void AddBackendProfilingContext(const BackendId backendId, std::shared_ptr< armnn::profiling::IBackendProfilingContext > profilingContext)
const ICounterDirectory & GetCounterDirectory() const
bool IsCounterRegistered(uint16_t counterUid) const override
bool IsProfilingEnabled() const override
virtual const char * what() const noexcept override
Definition: Exceptions.cpp:32
void Start(IProfilingConnection &profilingConnection)
uint16_t GetCounterCount() const override
void ResetExternalProfilingOptions(const ExternalProfilingOptions &options, bool resetProfilingService=false)
uint32_t IncrementCounterValue(uint16_t counterUid) override
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
Strongly typed guids to distinguish between those generated at runtime, and those that are statically...
Definition: Types.hpp:291
uint16_t GetCounterCount() const override
ProfilingState GetCurrentState() const
std::unique_ptr< ISendTimelinePacket > GetSendTimelinePacket() const override
ProfilingState ConfigureProfilingService(const ExternalProfilingOptions &options, bool resetProfilingService=false)
armnn::Runtime::CreationOptions::ExternalProfilingOptions options
bool IsCounterRegistered(uint16_t counterUid) const
void SetCaptureData(uint32_t capturePeriod, const std::vector< uint16_t > &counterIds, const std::set< BackendId > &activeBackends)