diff options
Diffstat (limited to 'profiling/client/src/PeriodicCounterSelectionCommandHandler.cpp')
-rw-r--r-- | profiling/client/src/PeriodicCounterSelectionCommandHandler.cpp | 234 |
1 files changed, 234 insertions, 0 deletions
diff --git a/profiling/client/src/PeriodicCounterSelectionCommandHandler.cpp b/profiling/client/src/PeriodicCounterSelectionCommandHandler.cpp new file mode 100644 index 0000000000..06f2c6588b --- /dev/null +++ b/profiling/client/src/PeriodicCounterSelectionCommandHandler.cpp @@ -0,0 +1,234 @@ +// +// Copyright © 2019 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "PeriodicCounterSelectionCommandHandler.hpp" +#include "ProfilingUtils.hpp" + +#include <client/include/ProfilingOptions.hpp> + +#include <common/include/NumericCast.hpp> + +#include <fmt/format.h> + +#include <vector> + +namespace arm +{ + +namespace pipe +{ + +void PeriodicCounterSelectionCommandHandler::ParseData(const arm::pipe::Packet& packet, CaptureData& captureData) +{ + std::vector<uint16_t> counterIds; + uint32_t sizeOfUint32 = arm::pipe::numeric_cast<uint32_t>(sizeof(uint32_t)); + uint32_t sizeOfUint16 = arm::pipe::numeric_cast<uint32_t>(sizeof(uint16_t)); + uint32_t offset = 0; + + if (packet.GetLength() < 4) + { + // Insufficient packet size + return; + } + + // Parse the capture period + uint32_t capturePeriod = ReadUint32(packet.GetData(), offset); + + // Set the capture period + captureData.SetCapturePeriod(capturePeriod); + + // Parse the counter ids + unsigned int counters = (packet.GetLength() - 4) / 2; + if (counters > 0) + { + counterIds.reserve(counters); + offset += sizeOfUint32; + for (unsigned int i = 0; i < counters; ++i) + { + // Parse the counter id + uint16_t counterId = ReadUint16(packet.GetData(), offset); + counterIds.emplace_back(counterId); + offset += sizeOfUint16; + } + } + + // Set the counter ids + captureData.SetCounterIds(counterIds); +} + +void PeriodicCounterSelectionCommandHandler::operator()(const arm::pipe::Packet& packet) +{ + ProfilingState currentState = m_StateMachine.GetCurrentState(); + switch (currentState) + { + case ProfilingState::Uninitialised: + case ProfilingState::NotConnected: + case ProfilingState::WaitingForAck: + throw arm::pipe::ProfilingException(fmt::format("Periodic Counter Selection Command Handler invoked while in " + "an wrong state: {}", + GetProfilingStateName(currentState))); + case ProfilingState::Active: + { + // Process the packet + if (!(packet.GetPacketFamily() == 0u && packet.GetPacketId() == 4u)) + { + throw arm::pipe::InvalidArgumentException(fmt::format("Expected Packet family = 0, id = 4 but " + "received family = {}, id = {}", + packet.GetPacketFamily(), + packet.GetPacketId())); + } + + // Parse the packet to get the capture period and counter UIDs + CaptureData captureData; + ParseData(packet, captureData); + + // Get the capture data + uint32_t capturePeriod = captureData.GetCapturePeriod(); + // Validate that the capture period is within the acceptable range. + if (capturePeriod > 0 && capturePeriod < arm::pipe::LOWEST_CAPTURE_PERIOD) + { + capturePeriod = arm::pipe::LOWEST_CAPTURE_PERIOD; + } + const std::vector<uint16_t>& counterIds = captureData.GetCounterIds(); + + // Check whether the selected counter UIDs are valid + std::vector<uint16_t> validCounterIds; + for (uint16_t counterId : counterIds) + { + // Check whether the counter is registered + if (!m_ReadCounterValues.IsCounterRegistered(counterId)) + { + // Invalid counter UID, ignore it and continue + continue; + } + // The counter is valid + validCounterIds.emplace_back(counterId); + } + + std::sort(validCounterIds.begin(), validCounterIds.end()); + + auto backendIdStart = std::find_if(validCounterIds.begin(), validCounterIds.end(), [&](uint16_t& counterId) + { + return counterId > m_MaxArmCounterId; + }); + + std::set<std::string> activeBackends; + std::set<uint16_t> backendCounterIds = std::set<uint16_t>(backendIdStart, validCounterIds.end()); + + if (m_BackendCounterMap.size() != 0) + { + std::set<uint16_t> newCounterIds; + std::set<uint16_t> unusedCounterIds; + + // Get any backend counter ids that is in backendCounterIds but not in m_PrevBackendCounterIds + std::set_difference(backendCounterIds.begin(), backendCounterIds.end(), + m_PrevBackendCounterIds.begin(), m_PrevBackendCounterIds.end(), + std::inserter(newCounterIds, newCounterIds.begin())); + + // Get any backend counter ids that is in m_PrevBackendCounterIds but not in backendCounterIds + std::set_difference(m_PrevBackendCounterIds.begin(), m_PrevBackendCounterIds.end(), + backendCounterIds.begin(), backendCounterIds.end(), + std::inserter(unusedCounterIds, unusedCounterIds.begin())); + + activeBackends = ProcessBackendCounterIds(capturePeriod, newCounterIds, unusedCounterIds); + } + else + { + activeBackends = ProcessBackendCounterIds(capturePeriod, backendCounterIds, {}); + } + + // save the new backend counter ids for next time + m_PrevBackendCounterIds = backendCounterIds; + + // Set the capture data with only the valid armnn counter UIDs + m_CaptureDataHolder.SetCaptureData(capturePeriod, {validCounterIds.begin(), backendIdStart}, activeBackends); + + // Echo back the Periodic Counter Selection packet to the Counter Stream Buffer + m_SendCounterPacket.SendPeriodicCounterSelectionPacket(capturePeriod, validCounterIds); + + if (capturePeriod == 0 || validCounterIds.empty()) + { + // No data capture stop the thread + m_PeriodicCounterCapture.Stop(); + } + else + { + // Start the Period Counter Capture thread (if not running already) + m_PeriodicCounterCapture.Start(); + } + + break; + } + default: + throw arm::pipe::ProfilingException(fmt::format("Unknown profiling service state: {}", + static_cast<int>(currentState))); + } +} + +std::set<std::string> PeriodicCounterSelectionCommandHandler::ProcessBackendCounterIds( + const uint32_t capturePeriod, + const std::set<uint16_t> newCounterIds, + const std::set<uint16_t> unusedCounterIds) +{ + std::set<std::string> changedBackends; + std::set<std::string> activeBackends = m_CaptureDataHolder.GetCaptureData().GetActiveBackends(); + + for (uint16_t counterId : newCounterIds) + { + auto backendId = m_CounterIdMap.GetBackendId(counterId); + m_BackendCounterMap[backendId.second].emplace_back(backendId.first); + changedBackends.insert(backendId.second); + } + // Add any new backends to active backends + activeBackends.insert(changedBackends.begin(), changedBackends.end()); + + for (uint16_t counterId : unusedCounterIds) + { + auto backendId = m_CounterIdMap.GetBackendId(counterId); + std::vector<uint16_t>& backendCounters = m_BackendCounterMap[backendId.second]; + + backendCounters.erase(std::remove(backendCounters.begin(), backendCounters.end(), backendId.first)); + + if(backendCounters.size() == 0) + { + // If a backend has no counters associated with it we remove it from active backends and + // send a capture period of zero with an empty vector, this will deactivate all the backends counters + activeBackends.erase(backendId.second); + ActivateBackendCounters(backendId.second, 0, {}); + } + else + { + changedBackends.insert(backendId.second); + } + } + + // If the capture period remains the same we only need to update the backends who's counters have changed + if(capturePeriod == m_PrevCapturePeriod) + { + for (auto backend : changedBackends) + { + ActivateBackendCounters(backend, capturePeriod, m_BackendCounterMap[backend]); + } + } + // Otherwise update all the backends with the new capture period and any new/unused counters + else + { + for (auto backend : m_BackendCounterMap) + { + ActivateBackendCounters(backend.first, capturePeriod, backend.second); + } + if(capturePeriod == 0) + { + activeBackends = {}; + } + m_PrevCapturePeriod = capturePeriod; + } + + return activeBackends; +} + +} // namespace pipe + +} // namespace arm |