ArmNN
 20.05
PeriodicCounterSelectionCommandHandler.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
7 #include "ProfilingUtils.hpp"
8 
9 #include <armnn/Types.hpp>
10 #include <boost/numeric/conversion/cast.hpp>
11 #include <boost/format.hpp>
12 
13 #include <vector>
14 
15 namespace armnn
16 {
17 
18 namespace profiling
19 {
20 
21 void PeriodicCounterSelectionCommandHandler::ParseData(const Packet& packet, CaptureData& captureData)
22 {
23  std::vector<uint16_t> counterIds;
24  uint32_t sizeOfUint32 = boost::numeric_cast<uint32_t>(sizeof(uint32_t));
25  uint32_t sizeOfUint16 = boost::numeric_cast<uint32_t>(sizeof(uint16_t));
26  uint32_t offset = 0;
27 
28  if (packet.GetLength() < 4)
29  {
30  // Insufficient packet size
31  return;
32  }
33 
34  // Parse the capture period
35  uint32_t capturePeriod = ReadUint32(packet.GetData(), offset);
36 
37  // Set the capture period
38  captureData.SetCapturePeriod(capturePeriod);
39 
40  // Parse the counter ids
41  unsigned int counters = (packet.GetLength() - 4) / 2;
42  if (counters > 0)
43  {
44  counterIds.reserve(counters);
45  offset += sizeOfUint32;
46  for (unsigned int i = 0; i < counters; ++i)
47  {
48  // Parse the counter id
49  uint16_t counterId = ReadUint16(packet.GetData(), offset);
50  counterIds.emplace_back(counterId);
51  offset += sizeOfUint16;
52  }
53  }
54 
55  // Set the counter ids
56  captureData.SetCounterIds(counterIds);
57 }
58 
// Handles an incoming Periodic Counter Selection packet (expected packet family 0, id 4).
// In the accepting state it:
// - parses the capture period and counter ids;
// - keeps only registered counters;
// - splits them into Arm NN counters (id <= m_MaxArmCounterId) and backend counters (id > m_MaxArmCounterId);
// - updates the backends and m_CaptureDataHolder;
// - echoes the selection back;
// - starts or stops m_PeriodicCounterCapture.
// Throws RuntimeException when invoked in a rejected or unknown profiling state, and
// armnn::InvalidArgumentException when the packet family/id is not 0/4.
// NOTE(review): this listing is a Doxygen extraction. The function signature (doc line 59) and the
// lines at doc 64-66 and 70 (presumably the switch case labels) are missing and must be restored
// from the original file.
60 {
61  ProfilingState currentState = m_StateMachine.GetCurrentState();
62  switch (currentState)
63  {
// States rejected with an exception (their case labels are missing from this listing)
67  throw RuntimeException(boost::str(boost::format("Periodic Counter Selection Command Handler invoked while in "
68  "an wrong state: %1%")
69  % GetProfilingStateName(currentState)));
// State in which the packet is processed (its case label is missing from this listing)
71  {
72  // Process the packet
73  if (!(packet.GetPacketFamily() == 0u && packet.GetPacketId() == 4u))
74  {
75  throw armnn::InvalidArgumentException(boost::str(boost::format("Expected Packet family = 0, id = 4 but "
76  "received family = %1%, id = %2%")
77  % packet.GetPacketFamily()
78  % packet.GetPacketId()));
79  }
80 
81  // Parse the packet to get the capture period and counter UIDs
82  CaptureData captureData;
83  ParseData(packet, captureData);
84 
85  // Get the capture data
86  uint32_t capturePeriod = captureData.GetCapturePeriod();
87  // Validate that the capture period is within the acceptable range.
// A non-zero period below LOWEST_CAPTURE_PERIOD is raised to it; 0 is kept and stops the capture thread below
88  if (capturePeriod > 0 && capturePeriod < LOWEST_CAPTURE_PERIOD)
89  {
90  capturePeriod = LOWEST_CAPTURE_PERIOD;
91  }
92  const std::vector<uint16_t>& counterIds = captureData.GetCounterIds();
93 
94  // Check whether the selected counter UIDs are valid
95  std::vector<uint16_t> validCounterIds;
96  for (uint16_t counterId : counterIds)
97  {
98  // Check whether the counter is registered
99  if (!m_ReadCounterValues.IsCounterRegistered(counterId))
100  {
101  // Invalid counter UID, ignore it and continue
102  continue;
103  }
104  // The counter is valid
105  validCounterIds.emplace_back(counterId);
106  }
107 
// Ascending sort places every id <= m_MaxArmCounterId before every id > m_MaxArmCounterId,
// so backendIdStart splits the vector into [Arm NN ids | backend ids]
108  std::sort(validCounterIds.begin(), validCounterIds.end());
109 
110  auto backendIdStart = std::find_if(validCounterIds.begin(), validCounterIds.end(), [&](uint16_t& counterId)
111  {
112  return counterId > m_MaxArmCounterId;
113  });
114 
115  std::set<armnn::BackendId> activeBackends;
116  std::set<uint16_t> backendCounterIds = std::set<uint16_t>(backendIdStart, validCounterIds.end());
117 
// With backend counters already recorded, pass only the delta against the previous selection;
// otherwise every selected backend counter is treated as new
118  if (m_BackendCounterMap.size() != 0)
119  {
120  std::set<uint16_t> newCounterIds;
121  std::set<uint16_t> unusedCounterIds;
122 
123  // Get any backend counter ids that is in backendCounterIds but not in m_PrevBackendCounterIds
124  std::set_difference(backendCounterIds.begin(), backendCounterIds.end(),
125  m_PrevBackendCounterIds.begin(), m_PrevBackendCounterIds.end(),
126  std::inserter(newCounterIds, newCounterIds.begin()));
127 
128  // Get any backend counter ids that is in m_PrevBackendCounterIds but not in backendCounterIds
129  std::set_difference(m_PrevBackendCounterIds.begin(), m_PrevBackendCounterIds.end(),
130  backendCounterIds.begin(), backendCounterIds.end(),
131  std::inserter(unusedCounterIds, unusedCounterIds.begin()));
132 
133  activeBackends = ProcessBackendCounterIds(capturePeriod, newCounterIds, unusedCounterIds);
134  }
135  else
136  {
137  activeBackends = ProcessBackendCounterIds(capturePeriod, backendCounterIds, {});
138  }
139 
140  // save the new backend counter ids for next time
141  m_PrevBackendCounterIds = backendCounterIds;
142 
143  // Set the capture data with only the valid armnn counter UIDs
144  m_CaptureDataHolder.SetCaptureData(capturePeriod, {validCounterIds.begin(), backendIdStart}, activeBackends);
145 
146  // Echo back the Periodic Counter Selection packet to the Counter Stream Buffer
// The echo carries all valid ids, Arm NN and backend alike
147  m_SendCounterPacket.SendPeriodicCounterSelectionPacket(capturePeriod, validCounterIds);
148 
149  if (capturePeriod == 0 || validCounterIds.empty())
150  {
151  // No data capture stop the thread
152  m_PeriodicCounterCapture.Stop();
153  }
154  else
155  {
156  // Start the Period Counter Capture thread (if not running already)
157  m_PeriodicCounterCapture.Start();
158  }
159 
160  break;
161  }
162  default:
163  throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1%")
164  % static_cast<int>(currentState)));
165  }
166 }
167 
168 std::set<armnn::BackendId> PeriodicCounterSelectionCommandHandler::ProcessBackendCounterIds(
169  const uint32_t capturePeriod,
170  const std::set<uint16_t> newCounterIds,
171  const std::set<uint16_t> unusedCounterIds)
172 {
173  std::set<armnn::BackendId> changedBackends;
174  std::set<armnn::BackendId> activeBackends = m_CaptureDataHolder.GetCaptureData().GetActiveBackends();
175 
176  for (uint16_t counterId : newCounterIds)
177  {
178  auto backendId = m_CounterIdMap.GetBackendId(counterId);
179  m_BackendCounterMap[backendId.second].emplace_back(backendId.first);
180  changedBackends.insert(backendId.second);
181  }
182  // Add any new backends to active backends
183  activeBackends.insert(changedBackends.begin(), changedBackends.end());
184 
185  for (uint16_t counterId : unusedCounterIds)
186  {
187  auto backendId = m_CounterIdMap.GetBackendId(counterId);
188  std::vector<uint16_t>& backendCounters = m_BackendCounterMap[backendId.second];
189 
190  backendCounters.erase(std::remove(backendCounters.begin(), backendCounters.end(), backendId.first));
191 
192  if(backendCounters.size() == 0)
193  {
194  // If a backend has no counters associated with it we remove it from active backends and
195  // send a capture period of zero with an empty vector, this will deactivate all the backends counters
196  activeBackends.erase(backendId.second);
197  ActivateBackedCounters(backendId.second, 0, {});
198  }
199  else
200  {
201  changedBackends.insert(backendId.second);
202  }
203  }
204 
205  // If the capture period remains the same we only need to update the backends who's counters have changed
206  if(capturePeriod == m_PrevCapturePeriod)
207  {
208  for (auto backend : changedBackends)
209  {
210  ActivateBackedCounters(backend, capturePeriod, m_BackendCounterMap[backend]);
211  }
212  }
213  // Otherwise update all the backends with the new capture period and any new/unused counters
214  else
215  {
216  for (auto backend : m_BackendCounterMap)
217  {
218  ActivateBackedCounters(backend.first, capturePeriod, backend.second);
219  }
220  if(capturePeriod == 0)
221  {
222  activeBackends = {};
223  }
224  m_PrevCapturePeriod = capturePeriod;
225  }
226 
227  return activeBackends;
228 }
229 
230 } // namespace profiling
231 
232 } // namespace armnn
const std::set< armnn::BackendId > & GetActiveBackends() const
Definition: Holder.cpp:39
const std::vector< uint16_t > & GetCounterIds() const
Definition: Holder.cpp:49
virtual void SendPeriodicCounterSelectionPacket(uint32_t capturePeriod, const std::vector< uint16_t > &selectedCounterIds)=0
Create and write a PeriodicCounterSelectionPacket from the parameters to the buffer.
uint16_t ReadUint16(const IPacketBufferPtr &packetBuffer, unsigned int offset)
Copyright (c) 2020 ARM Limited.
uint32_t GetCapturePeriod() const
Definition: Holder.cpp:44
CaptureData GetCaptureData() const
Definition: Holder.cpp:54
void SetCaptureData(uint32_t capturePeriod, const std::vector< uint16_t > &counterIds, const std::set< armnn::BackendId > &activeBackends)
Definition: Holder.cpp:74
std::enable_if_t< std::is_unsigned< Source >::value &&std::is_unsigned< Dest >::value, Dest > numeric_cast(Source source)
Definition: NumericCast.hpp:33
constexpr unsigned int LOWEST_CAPTURE_PERIOD
The lowest performance data capture interval we support is 10 milliseconds.
Definition: Types.hpp:21
virtual bool IsCounterRegistered(uint16_t counterUid) const =0
uint32_t ReadUint32(const IPacketBufferPtr &packetBuffer, unsigned int offset)
virtual const std::pair< uint16_t, armnn::BackendId > & GetBackendId(uint16_t globalCounterId) const =0
constexpr char const * GetProfilingStateName(ProfilingState state)