ArmNN
 21.02
ProfilingService.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2019 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
9 #include "BufferManager.hpp"
10 #include "CommandHandler.hpp"
12 #include "CounterDirectory.hpp"
13 #include "CounterIdMap.hpp"
15 #include "ICounterRegistry.hpp"
16 #include "ICounterValues.hpp"
18 #include "IProfilingService.hpp"
19 #include "IReportStructure.hpp"
27 #include "SendCounterPacket.hpp"
28 #include "SendThread.hpp"
29 #include "SendTimelinePacket.hpp"
31 #include "INotifyBackends.hpp"
33 
34 #include <list>
35 
36 namespace armnn
37 {
38 
39 namespace profiling
40 {
41 // Static constants describing ArmNN's counter UIDs
42 static const uint16_t NETWORK_LOADS = 0;
43 static const uint16_t NETWORK_UNLOADS = 1;
44 static const uint16_t REGISTERED_BACKENDS = 2;
45 static const uint16_t UNREGISTERED_BACKENDS = 3;
46 static const uint16_t INFERENCES_RUN = 4;
47 static const uint16_t MAX_ARMNN_COUNTER = INFERENCES_RUN; // Highest of the ArmNN counter UIDs listed above (0-4)
48 
50 {
51 public:
53  using IProfilingConnectionFactoryPtr = std::unique_ptr<IProfilingConnectionFactory>;
54  using IProfilingConnectionPtr = std::unique_ptr<IProfilingConnection>;
55  using CounterIndices = std::vector<std::atomic<uint32_t>*>;
56  using CounterValues = std::list<std::atomic<uint32_t>>;
57  using BackendProfilingContext = std::unordered_map<BackendId,
58  std::shared_ptr<armnn::profiling::IBackendProfilingContext>>;
59 
61  : m_Options()
62  , m_TimelineReporting(false)
63  , m_CounterDirectory()
64  , m_ProfilingConnectionFactory(new ProfilingConnectionFactory())
65  , m_ProfilingConnection()
66  , m_StateMachine()
67  , m_CounterIndex()
68  , m_CounterValues()
69  , m_CommandHandlerRegistry()
70  , m_PacketVersionResolver()
71  , m_CommandHandler(1000,
72  false,
73  m_CommandHandlerRegistry,
74  m_PacketVersionResolver)
75  , m_BufferManager()
76  , m_SendCounterPacket(m_BufferManager)
77  , m_SendThread(m_StateMachine, m_BufferManager, m_SendCounterPacket)
78  , m_SendTimelinePacket(m_BufferManager)
79  , m_PeriodicCounterCapture(m_Holder, m_SendCounterPacket, *this, m_CounterIdMap, m_BackendProfilingContexts)
80  , m_ConnectionAcknowledgedCommandHandler(0,
81  1,
82  m_PacketVersionResolver.ResolvePacketVersion(0, 1).GetEncodedValue(),
83  m_CounterDirectory,
84  m_SendCounterPacket,
85  m_SendTimelinePacket,
86  m_StateMachine,
87  *this,
88  m_BackendProfilingContexts)
89  , m_RequestCounterDirectoryCommandHandler(0,
90  3,
91  m_PacketVersionResolver.ResolvePacketVersion(0, 3).GetEncodedValue(),
92  m_CounterDirectory,
93  m_SendCounterPacket,
94  m_SendTimelinePacket,
95  m_StateMachine)
96  , m_PeriodicCounterSelectionCommandHandler(0,
97  4,
98  m_PacketVersionResolver.ResolvePacketVersion(0, 4).GetEncodedValue(),
99  m_BackendProfilingContexts,
100  m_CounterIdMap,
101  m_Holder,
102  MAX_ARMNN_COUNTER,
103  m_PeriodicCounterCapture,
104  *this,
105  m_SendCounterPacket,
106  m_StateMachine)
107  , m_PerJobCounterSelectionCommandHandler(0,
108  5,
109  m_PacketVersionResolver.ResolvePacketVersion(0, 5).GetEncodedValue(),
110  m_StateMachine)
111  , m_ActivateTimelineReportingCommandHandler(0,
112  6,
113  m_PacketVersionResolver.ResolvePacketVersion(0, 6)
114  .GetEncodedValue(),
115  m_SendTimelinePacket,
116  m_StateMachine,
117  reportStructure,
118  m_TimelineReporting,
119  *this)
120  , m_DeactivateTimelineReportingCommandHandler(0,
121  7,
122  m_PacketVersionResolver.ResolvePacketVersion(0, 7)
123  .GetEncodedValue(),
124  m_TimelineReporting,
125  m_StateMachine,
126  *this)
127  , m_TimelinePacketWriterFactory(m_BufferManager)
128  , m_MaxGlobalCounterId(armnn::profiling::INFERENCES_RUN)
129  , m_ServiceActive(false)
130  {
131  // Register the "Connection Acknowledged" command handler
132  m_CommandHandlerRegistry.RegisterFunctor(&m_ConnectionAcknowledgedCommandHandler);
133 
134  // Register the "Request Counter Directory" command handler
135  m_CommandHandlerRegistry.RegisterFunctor(&m_RequestCounterDirectoryCommandHandler);
136 
137  // Register the "Periodic Counter Selection" command handler
138  m_CommandHandlerRegistry.RegisterFunctor(&m_PeriodicCounterSelectionCommandHandler);
139 
140  // Register the "Per-Job Counter Selection" command handler
141  m_CommandHandlerRegistry.RegisterFunctor(&m_PerJobCounterSelectionCommandHandler);
142 
143  m_CommandHandlerRegistry.RegisterFunctor(&m_ActivateTimelineReportingCommandHandler);
144 
145  m_CommandHandlerRegistry.RegisterFunctor(&m_DeactivateTimelineReportingCommandHandler);
146  }
147 
149 
150  // Resets the profiling options, optionally clears the profiling service entirely
151  void ResetExternalProfilingOptions(const ExternalProfilingOptions& options, bool resetProfilingService = false);
153  bool resetProfilingService = false);
154 
155 
156  // Updates the profiling service, making it transition to a new state if necessary
157  void Update();
158 
159  // Disconnects the profiling service from the external server
160  void Disconnect();
161 
162  // Store a profiling context returned from a backend that supports profiling.
163  void AddBackendProfilingContext(const BackendId backendId,
164  std::shared_ptr<armnn::profiling::IBackendProfilingContext> profilingContext);
165 
166  // Enable the recording of timeline events and entities
167  void NotifyBackendsForTimelineReporting() override;
168 
169  const ICounterDirectory& GetCounterDirectory() const;
172  bool IsCounterRegistered(uint16_t counterUid) const override;
173  uint32_t GetAbsoluteCounterValue(uint16_t counterUid) const override;
174  uint32_t GetDeltaCounterValue(uint16_t counterUid) override;
175  uint16_t GetCounterCount() const override;
176  // counter global/backend mapping functions
177  const ICounterMappings& GetCounterMappings() const override;
179 
180  // Getters for the profiling service state
181  bool IsProfilingEnabled() const override;
182 
183  CaptureData GetCaptureData() override;
184  void SetCaptureData(uint32_t capturePeriod,
185  const std::vector<uint16_t>& counterIds,
186  const std::set<BackendId>& activeBackends);
187 
188  // Setters for the profiling service state
189  void SetCounterValue(uint16_t counterUid, uint32_t value) override;
190  uint32_t AddCounterValue(uint16_t counterUid, uint32_t value) override;
191  uint32_t SubtractCounterValue(uint16_t counterUid, uint32_t value) override;
192  uint32_t IncrementCounterValue(uint16_t counterUid) override;
193 
194  // IProfilingGuidGenerator functions
195  /// Return the next random Guid in the sequence
196  ProfilingDynamicGuid NextGuid() override;
197  /// Create a ProfilingStaticGuid based on a hash of the string
198  ProfilingStaticGuid GenerateStaticId(const std::string& str) override;
199 
200 
201  std::unique_ptr<ISendTimelinePacket> GetSendTimelinePacket() const override;
202 
204  {
205  return m_SendCounterPacket;
206  }
207 
209 
210  static ProfilingStaticGuid GetStaticId(const std::string& str);
211 
212  void ResetGuidGenerator();
213 
215  {
216  return m_TimelineReporting;
217  }
218 
219  void AddLocalPacketHandler(ILocalPacketHandlerSharedPtr localPacketHandler);
220 
221  void NotifyProfilingServiceActive() override; // IProfilingServiceStatus
222  void WaitForProfilingServiceActivation(unsigned int timeout) override; // IProfilingServiceStatus
223 
224 private:
225  // Copy/move constructors and copy/move assignment operators are deleted (the destructor is not)
226  ProfilingService(const ProfilingService&) = delete;
228  ProfilingService& operator=(const ProfilingService&) = delete;
229  ProfilingService& operator=(ProfilingService&&) = delete;
230 
231  // Initialization/reset functions
232  void Initialize();
233  void InitializeCounterValue(uint16_t counterUid);
234  void Reset();
235  void Stop();
236 
237  // Helper function
238  void CheckCounterUid(uint16_t counterUid) const;
239 
240  // Profiling service components
241  ExternalProfilingOptions m_Options;
242  std::atomic<bool> m_TimelineReporting;
243  CounterDirectory m_CounterDirectory;
244  CounterIdMap m_CounterIdMap;
245  IProfilingConnectionFactoryPtr m_ProfilingConnectionFactory;
246  IProfilingConnectionPtr m_ProfilingConnection;
247  ProfilingStateMachine m_StateMachine;
248  CounterIndices m_CounterIndex;
249  CounterValues m_CounterValues;
250  arm::pipe::CommandHandlerRegistry m_CommandHandlerRegistry;
251  arm::pipe::PacketVersionResolver m_PacketVersionResolver;
252  CommandHandler m_CommandHandler;
253  BufferManager m_BufferManager;
254  SendCounterPacket m_SendCounterPacket;
255  SendThread m_SendThread;
256  SendTimelinePacket m_SendTimelinePacket;
257 
258  Holder m_Holder;
259 
260  PeriodicCounterCapture m_PeriodicCounterCapture;
261 
262  ConnectionAcknowledgedCommandHandler m_ConnectionAcknowledgedCommandHandler;
263  RequestCounterDirectoryCommandHandler m_RequestCounterDirectoryCommandHandler;
264  PeriodicCounterSelectionCommandHandler m_PeriodicCounterSelectionCommandHandler;
265  PerJobCounterSelectionCommandHandler m_PerJobCounterSelectionCommandHandler;
266  ActivateTimelineReportingCommandHandler m_ActivateTimelineReportingCommandHandler;
267  DeactivateTimelineReportingCommandHandler m_DeactivateTimelineReportingCommandHandler;
268 
269  TimelinePacketWriterFactory m_TimelinePacketWriterFactory;
270  BackendProfilingContext m_BackendProfilingContexts;
271  uint16_t m_MaxGlobalCounterId;
272 
273  static ProfilingGuidGenerator m_GuidGenerator;
274 
275  // Signalling to let external actors know when service is active or not
276  std::mutex m_ServiceActiveMutex;
277  std::condition_variable m_ServiceActiveConditionVariable;
278  bool m_ServiceActive;
279 
280 protected:
281 
282  // Protected methods for testing
286  {
287  ARMNN_ASSERT(instance.m_ProfilingConnectionFactory);
288  ARMNN_ASSERT(other);
289 
290  backup = instance.m_ProfilingConnectionFactory.release();
291  instance.m_ProfilingConnectionFactory.reset(other);
292  }
294  {
295  return instance.m_ProfilingConnection.get();
296  }
298  {
299  instance.m_StateMachine.TransitionToState(newState);
300  }
301  bool WaitForPacketSent(ProfilingService& instance, uint32_t timeout = 1000)
302  {
303  return instance.m_SendThread.WaitForPacketSent(timeout);
304  }
305 
307  {
308  return instance.m_BufferManager;
309  }
310 };
311 
312 } // namespace profiling
313 
314 } // namespace armnn
bool IsCounterRegistered(uint16_t counterUid) const override
void WaitForProfilingServiceActivation(unsigned int timeout) override
std::shared_ptr< ILocalPacketHandler > ILocalPacketHandlerSharedPtr
uint32_t GetAbsoluteCounterValue(uint16_t counterUid) const override
std::list< std::atomic< uint32_t > > CounterValues
std::unordered_map< BackendId, std::shared_ptr< armnn::profiling::IBackendProfilingContext > > BackendProfilingContext
ProfilingState GetCurrentState() const
DataLayout::NCHW false
Strongly typed guids to distinguish between those generated at runtime, and those that are statically defined.
Definition: Types.hpp:335
Copyright (c) 2021 ARM Limited and Contributors.
ProfilingDynamicGuid NextGuid() override
Return the next random Guid in the sequence.
uint32_t IncrementCounterValue(uint16_t counterUid) override
std::unique_ptr< IProfilingConnection > IProfilingConnectionPtr
void AddLocalPacketHandler(ILocalPacketHandlerSharedPtr localPacketHandler)
bool WaitForPacketSent(ProfilingService &instance, uint32_t timeout=1000)
bool WaitForPacketSent(uint32_t timeout)
Definition: SendThread.cpp:260
uint32_t GetDeltaCounterValue(uint16_t counterUid) override
uint32_t SubtractCounterValue(uint16_t counterUid, uint32_t value) override
std::vector< std::atomic< uint32_t > * > CounterIndices
void SwapProfilingConnectionFactory(ProfilingService &instance, IProfilingConnectionFactory *other, IProfilingConnectionFactory *&backup)
static ProfilingStaticGuid GetStaticId(const std::string &str)
void SetCaptureData(uint32_t capturePeriod, const std::vector< uint16_t > &counterIds, const std::set< BackendId > &activeBackends)
std::unique_ptr< ISendTimelinePacket > GetSendTimelinePacket() const override
const ICounterMappings & GetCounterMappings() const override
void ResetExternalProfilingOptions(const ExternalProfilingOptions &options, bool resetProfilingService=false)
#define ARMNN_ASSERT(COND)
Definition: Assert.hpp:14
uint32_t AddCounterValue(uint16_t counterUid, uint32_t value) override
IRegisterCounterMapping & GetCounterMappingRegistry()
BufferManager & GetBufferManager(ProfilingService &instance)
ProfilingService(Optional< IReportStructure &> reportStructure=EmptyOptional())
void SetCounterValue(uint16_t counterUid, uint32_t value) override
EmptyOptional is used to initialize the Optional class in case we want to have a default value for an Optional in a function declaration.
Definition: Optional.hpp:32
bool IsProfilingEnabled() const override
IProfilingConnection * GetProfilingConnection(ProfilingService &instance)
static ProfilingDynamicGuid GetNextGuid()
void TransitionToState(ProfilingState newState)
std::unique_ptr< IProfilingConnectionFactory > IProfilingConnectionFactoryPtr
void TransitionToState(ProfilingService &instance, ProfilingState newState)
ISendCounterPacket & GetSendCounterPacket() override
void AddBackendProfilingContext(const BackendId backendId, std::shared_ptr< armnn::profiling::IBackendProfilingContext > profilingContext)
const ICounterDirectory & GetCounterDirectory() const
ProfilingState ConfigureProfilingService(const ExternalProfilingOptions &options, bool resetProfilingService=false)
uint16_t GetCounterCount() const override
ProfilingStaticGuid GenerateStaticId(const std::string &str) override
Create a ProfilingStaticGuid based on a hash of the string.