From a84edee4702c112a6e004b1987acc11144e2d6dd Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Wed, 2 Oct 2019 12:50:57 +0100 Subject: IVGCVSW-3937 Initial ServiceProfiling refactoring * Made the ServiceProfiling class a singleton * Registered basic category and counters * Code refactoring * Updated unit tests accordingly Signed-off-by: Matteo Martincigh Change-Id: I648a6202eead2a3016aac14d905511bd945a90cb --- src/profiling/CounterDirectory.cpp | 108 ++++++++++----- src/profiling/CounterDirectory.hpp | 24 ++-- src/profiling/CounterValues.hpp | 26 ++-- src/profiling/ProfilingService.cpp | 201 ++++++++++++++++++---------- src/profiling/ProfilingService.hpp | 63 ++++++--- src/profiling/ProfilingStateMachine.cpp | 136 +++++++++---------- src/profiling/ProfilingStateMachine.hpp | 22 +-- src/profiling/SendCounterPacket.cpp | 2 +- src/profiling/SocketProfilingConnection.cpp | 18 +-- src/profiling/test/ProfilingTests.cpp | 84 ++++++------ 10 files changed, 411 insertions(+), 273 deletions(-) diff --git a/src/profiling/CounterDirectory.cpp b/src/profiling/CounterDirectory.cpp index cef3d6a76d..979b8046be 100644 --- a/src/profiling/CounterDirectory.cpp +++ b/src/profiling/CounterDirectory.cpp @@ -29,7 +29,7 @@ const Category* CounterDirectory::RegisterCategory(const std::string& categoryNa } // Check that the given category is not already registered - if (CheckIfCategoryIsRegistered(categoryName)) + if (IsCategoryRegistered(categoryName)) { throw InvalidArgumentException( boost::str(boost::format("Trying to register a category already registered (\"%1%\")") @@ -41,7 +41,7 @@ const Category* CounterDirectory::RegisterCategory(const std::string& categoryNa if (deviceUidValue > 0) { // Check that the (optional) device is already registered - if (!CheckIfDeviceIsRegistered(deviceUidValue)) + if (!IsDeviceRegistered(deviceUidValue)) { throw InvalidArgumentException( boost::str(boost::format("Trying to connect a category (\"%1%\") to a device that is " @@ -56,7 +56,7 @@ const Category* CounterDirectory::RegisterCategory(const std::string& categoryNa if (counterSetUidValue > 0) { // Check that the (optional) counter set is already registered - if (!CheckIfCounterSetIsRegistered(counterSetUidValue)) + if (!IsCounterSetRegistered(counterSetUidValue)) { throw InvalidArgumentException( boost::str(boost::format("Trying to connect a category (name: \"%1%\") to a counter set " @@ -92,7 +92,7 @@ const Device* CounterDirectory::RegisterDevice(const std::string& deviceName, } // Check that a device with the given name is not already registered - if (CheckIfDeviceIsRegistered(deviceName)) + if (IsDeviceRegistered(deviceName)) { throw InvalidArgumentException( boost::str(boost::format("Trying to register a device already registered (\"%1%\")") @@ -188,7 +188,7 @@ const CounterSet* CounterDirectory::RegisterCounterSet(const std::string& counte } // Check that a counter set with the given name is not already registered - if (CheckIfCounterSetIsRegistered(counterSetName)) + if (IsCounterSetRegistered(counterSetName)) { throw InvalidArgumentException( boost::str(boost::format("Trying to register a counter set already registered (\"%1%\")") @@ -365,7 +365,7 @@ const Counter* CounterDirectory::RegisterCounter(const std::string& parentCatego if (counterSetUidValue > 0) { // Check that the (optional) counter set is already registered - if (!CheckIfCounterSetIsRegistered(counterSetUidValue)) + if (!IsCounterSetRegistered(counterSetUidValue)) { throw InvalidArgumentException( boost::str(boost::format("Trying to connect a counter to a counter set that is " @@ -476,6 +476,64 @@ const Counter* CounterDirectory::GetCounter(uint16_t counterUid) const return counter; } +bool CounterDirectory::IsCategoryRegistered(const std::string& categoryName) const +{ + auto it = FindCategory(categoryName); + + return it != m_Categories.end(); +} + +bool CounterDirectory::IsDeviceRegistered(uint16_t deviceUid) const +{ + auto it = FindDevice(deviceUid); + + return it != m_Devices.end(); +} + +bool CounterDirectory::IsDeviceRegistered(const std::string& deviceName) const +{ + auto it = FindDevice(deviceName); + + return it != m_Devices.end(); +} + +bool CounterDirectory::IsCounterSetRegistered(uint16_t counterSetUid) const +{ + auto it = FindCounterSet(counterSetUid); + + return it != m_CounterSets.end(); +} + +bool CounterDirectory::IsCounterSetRegistered(const std::string& counterSetName) const +{ + auto it = FindCounterSet(counterSetName); + + return it != m_CounterSets.end(); +} + +bool CounterDirectory::IsCounterRegistered(uint16_t counterUid) const +{ + auto it = FindCounter(counterUid); + + return it != m_Counters.end(); +} + +bool CounterDirectory::IsCounterRegistered(const std::string& counterName) const +{ + auto it = FindCounter(counterName); + + return it != m_Counters.end(); +} + +void CounterDirectory::Clear() +{ + // Clear all the counter directory contents + m_Categories.clear(); + m_Devices.clear(); + m_CounterSets.clear(); + m_Counters.clear(); +} + CategoriesIt CounterDirectory::FindCategory(const std::string& categoryName) const { return std::find_if(m_Categories.begin(), m_Categories.end(), [&categoryName](const CategoryPtr& category) @@ -523,39 +581,15 @@ CountersIt CounterDirectory::FindCounter(uint16_t counterUid) const return m_Counters.find(counterUid); } -bool CounterDirectory::CheckIfCategoryIsRegistered(const std::string& categoryName) const -{ - auto it = FindCategory(categoryName); - - return it != m_Categories.end(); -} - -bool CounterDirectory::CheckIfDeviceIsRegistered(uint16_t deviceUid) const -{ - auto it = FindDevice(deviceUid); - - return it != m_Devices.end(); -} - -bool CounterDirectory::CheckIfDeviceIsRegistered(const std::string& deviceName) const +CountersIt CounterDirectory::FindCounter(const std::string& counterName) const { - auto it = FindDevice(deviceName); - - return it != m_Devices.end(); -} - -bool CounterDirectory::CheckIfCounterSetIsRegistered(uint16_t counterSetUid) const -{ - auto it = FindCounterSet(counterSetUid); - - return it != m_CounterSets.end(); -} - -bool CounterDirectory::CheckIfCounterSetIsRegistered(const std::string& counterSetName) const -{ - auto it = FindCounterSet(counterSetName); + return std::find_if(m_Counters.begin(), m_Counters.end(), [&counterName](const auto& pair) + { + BOOST_ASSERT(pair.second); + BOOST_ASSERT(pair.second->m_Uid == pair.first); - return it != m_CounterSets.end(); + return pair.second->m_Name == counterName; + }); } uint16_t CounterDirectory::GetNumberOfCores(const Optional& numberOfCores, diff --git a/src/profiling/CounterDirectory.hpp b/src/profiling/CounterDirectory.hpp index a756a9a7bd..bff5cfef98 100644 --- a/src/profiling/CounterDirectory.hpp +++ b/src/profiling/CounterDirectory.hpp @@ -66,6 +66,18 @@ public: const CounterSet* GetCounterSet(uint16_t uid) const override; const Counter* GetCounter(uint16_t uid) const override; + // Queries for profiling objects + bool IsCategoryRegistered(const std::string& categoryName) const; + bool IsDeviceRegistered(uint16_t deviceUid) const; + bool IsDeviceRegistered(const std::string& deviceName) const; + bool IsCounterSetRegistered(uint16_t counterSetUid) const; + bool IsCounterSetRegistered(const std::string& counterSetName) const; + bool IsCounterRegistered(uint16_t counterUid) const; + bool IsCounterRegistered(const std::string& counterName) const; + + // Clears all the counter directory contents + void Clear(); + private: // The profiling collections owned by the counter directory Categories m_Categories; @@ -80,14 +92,10 @@ private: CounterSetsIt FindCounterSet(uint16_t counterSetUid) const; CounterSetsIt FindCounterSet(const std::string& counterSetName) const; CountersIt FindCounter(uint16_t counterUid) const; - bool CheckIfCategoryIsRegistered(const std::string& categoryName) const; - bool CheckIfDeviceIsRegistered(uint16_t deviceUid) const; - bool CheckIfDeviceIsRegistered(const std::string& deviceName) const; - bool CheckIfCounterSetIsRegistered(uint16_t counterSetUid) const; - bool CheckIfCounterSetIsRegistered(const std::string& counterSetName) const; - uint16_t GetNumberOfCores(const Optional& numberOfCores, - uint16_t deviceUid, - const CategoryPtr& parentCategory); + CountersIt FindCounter(const std::string& counterName) const; + uint16_t GetNumberOfCores(const Optional& numberOfCores, + uint16_t deviceUid, + const CategoryPtr& parentCategory); }; } // namespace profiling diff --git a/src/profiling/CounterValues.hpp b/src/profiling/CounterValues.hpp index 75ecad9961..9c06ff0a7d 100644 --- a/src/profiling/CounterValues.hpp +++ b/src/profiling/CounterValues.hpp @@ -15,24 +15,30 @@ namespace profiling class IReadCounterValues { public: - virtual uint16_t GetCounterCount() const = 0; - virtual void GetCounterValue(uint16_t index, uint32_t& value) const = 0; virtual ~IReadCounterValues() {} + + virtual uint16_t GetCounterCount() const = 0; + virtual uint32_t GetCounterValue(uint16_t counterUid) const = 0; }; -class IWriteCounterValues : public IReadCounterValues +class IWriteCounterValues { public: - virtual void SetCounterValue(uint16_t index, uint32_t value) = 0; - virtual void AddCounterValue(uint16_t index, uint32_t value) = 0; - virtual void SubtractCounterValue(uint16_t index, uint32_t value) = 0; - virtual void IncrementCounterValue(uint16_t index) = 0; - virtual void DecrementCounterValue(uint16_t index) = 0; virtual ~IWriteCounterValues() {} + + virtual void SetCounterValue(uint16_t counterUid, uint32_t value) = 0; + virtual uint32_t AddCounterValue(uint16_t counterUid, uint32_t value) = 0; + virtual uint32_t SubtractCounterValue(uint16_t counterUid, uint32_t value) = 0; + virtual uint32_t IncrementCounterValue(uint16_t counterUid) = 0; + virtual uint32_t DecrementCounterValue(uint16_t counterUid) = 0; +}; + +class IReadWriteCounterValues : public IReadCounterValues, public IWriteCounterValues +{ +public: + virtual ~IReadWriteCounterValues() {} }; } // namespace profiling } // namespace armnn - - diff --git a/src/profiling/ProfilingService.cpp b/src/profiling/ProfilingService.cpp index 786bfae12e..2da0f79da2 100644 --- a/src/profiling/ProfilingService.cpp +++ b/src/profiling/ProfilingService.cpp @@ -5,72 +5,53 @@ #include "ProfilingService.hpp" +#include +#include + namespace armnn { namespace profiling { -ProfilingService::ProfilingService(const Runtime::CreationOptions::ExternalProfilingOptions& options) - : m_Options(options) +void ProfilingService::ResetExternalProfilingOptions(const ExternalProfilingOptions& options, + bool resetProfilingService) { - Initialise(); -} + // Update the profiling options + m_Options = options; -void ProfilingService::Initialise() -{ - if (m_Options.m_EnableProfiling == true) + if (resetProfilingService) { - // Setup provisional Counter Directory example - this should only be created if profiling is enabled - // Setup provisional Counter meta example - const std::string categoryName = "Category"; - - m_CounterDirectory.RegisterCategory(categoryName); - m_CounterDirectory.RegisterDevice("device name", 0, categoryName); - m_CounterDirectory.RegisterCounterSet("counterSet_name", 2, categoryName); - - m_CounterDirectory.RegisterCounter(categoryName, - 0, - 1, - 123.45f, - "counter name 1", - "counter description"); - - m_CounterDirectory.RegisterCounter(categoryName, - 0, - 1, - 123.45f, - "counter name 2", - "counter description"); - - for (unsigned short i = 0; i < m_CounterDirectory.GetCounterCount(); ++i) - { - m_CounterIdToValue[i] = 0; - } - - // For now until CounterDirectory setup is implemented, change m_State once everything initialised - m_State.TransitionToState(ProfilingState::NotConnected); + // Reset the profiling service + m_CounterDirectory.Clear(); + m_ProfilingConnection.reset(); + m_StateMachine.Reset(); + m_CounterIndex.clear(); + m_CounterValues.clear(); } + + // Re-initialize the profiling service + Initialize(); } void ProfilingService::Run() { - if (m_State.GetCurrentState() == ProfilingState::NotConnected) + if (m_StateMachine.GetCurrentState() == ProfilingState::Uninitialised) + { + Initialize(); + } + else if (m_StateMachine.GetCurrentState() == ProfilingState::NotConnected) { try { - m_Factory.GetProfilingConnection(m_Options); - m_State.TransitionToState(ProfilingState::WaitingForAck); + m_ProfilingConnectionFactory.GetProfilingConnection(m_Options); + m_StateMachine.TransitionToState(ProfilingState::WaitingForAck); } catch (const armnn::Exception& e) { std::cerr << e.what() << std::endl; } } - else if (m_State.GetCurrentState() == ProfilingState::Uninitialised && m_Options.m_EnableProfiling == true) - { - Initialise(); - } } const ICounterDirectory& ProfilingService::GetCounterDirectory() const @@ -78,71 +59,143 @@ const ICounterDirectory& ProfilingService::GetCounterDirectory() const return m_CounterDirectory; } -void ProfilingService::SetCounterValue(uint16_t counterIndex, uint32_t value) +ProfilingState ProfilingService::GetCurrentState() const { - CheckIndexSize(counterIndex); - m_CounterIdToValue.at(counterIndex).store(value, std::memory_order::memory_order_relaxed); + return m_StateMachine.GetCurrentState(); } -void ProfilingService::GetCounterValue(uint16_t counterIndex, uint32_t& value) const +uint16_t ProfilingService::GetCounterCount() const { - CheckIndexSize(counterIndex); - value = m_CounterIdToValue.at(counterIndex).load(std::memory_order::memory_order_relaxed); + return m_CounterDirectory.GetCounterCount(); } -void ProfilingService::AddCounterValue(uint16_t counterIndex, uint32_t value) +uint32_t ProfilingService::GetCounterValue(uint16_t counterUid) const { - CheckIndexSize(counterIndex); - m_CounterIdToValue.at(counterIndex).fetch_add(value, std::memory_order::memory_order_relaxed); + BOOST_ASSERT(counterUid < m_CounterIndex.size()); + std::atomic* counterValuePtr = m_CounterIndex.at(counterUid); + BOOST_ASSERT(counterValuePtr); + return counterValuePtr->load(std::memory_order::memory_order_relaxed); } -void ProfilingService::SubtractCounterValue(uint16_t counterIndex, uint32_t value) +void ProfilingService::SetCounterValue(uint16_t counterUid, uint32_t value) { - CheckIndexSize(counterIndex); - m_CounterIdToValue.at(counterIndex).fetch_sub(value, std::memory_order::memory_order_relaxed); + BOOST_ASSERT(counterUid < m_CounterIndex.size()); + std::atomic* counterValuePtr = m_CounterIndex.at(counterUid); + BOOST_ASSERT(counterValuePtr); + counterValuePtr->store(value, std::memory_order::memory_order_relaxed); } -void ProfilingService::IncrementCounterValue(uint16_t counterIndex) +uint32_t ProfilingService::AddCounterValue(uint16_t counterUid, uint32_t value) { - CheckIndexSize(counterIndex); - m_CounterIdToValue.at(counterIndex).operator++(std::memory_order::memory_order_relaxed); + BOOST_ASSERT(counterUid < m_CounterIndex.size()); + std::atomic* counterValuePtr = m_CounterIndex.at(counterUid); + BOOST_ASSERT(counterValuePtr); + return counterValuePtr->fetch_add(value, std::memory_order::memory_order_relaxed); } -void ProfilingService::DecrementCounterValue(uint16_t counterIndex) +uint32_t ProfilingService::SubtractCounterValue(uint16_t counterUid, uint32_t value) { - CheckIndexSize(counterIndex); - m_CounterIdToValue.at(counterIndex).operator--(std::memory_order::memory_order_relaxed); + BOOST_ASSERT(counterUid < m_CounterIndex.size()); + std::atomic* counterValuePtr = m_CounterIndex.at(counterUid); + BOOST_ASSERT(counterValuePtr); + return counterValuePtr->fetch_sub(value, std::memory_order::memory_order_relaxed); } -uint16_t ProfilingService::GetCounterCount() const +uint32_t ProfilingService::IncrementCounterValue(uint16_t counterUid) { - return m_CounterDirectory.GetCounterCount(); + BOOST_ASSERT(counterUid < m_CounterIndex.size()); + std::atomic* counterValuePtr = m_CounterIndex.at(counterUid); + BOOST_ASSERT(counterValuePtr); + return counterValuePtr->operator++(std::memory_order::memory_order_relaxed); } -ProfilingState ProfilingService::GetCurrentState() const +uint32_t ProfilingService::DecrementCounterValue(uint16_t counterUid) { - return m_State.GetCurrentState(); + BOOST_ASSERT(counterUid < m_CounterIndex.size()); + std::atomic* counterValuePtr = m_CounterIndex.at(counterUid); + BOOST_ASSERT(counterValuePtr); + return counterValuePtr->operator--(std::memory_order::memory_order_relaxed); } -void ProfilingService::ResetExternalProfilingOptions(const Runtime::CreationOptions::ExternalProfilingOptions& options) +void ProfilingService::Initialize() { - if(!m_Options.m_EnableProfiling) + if (!m_Options.m_EnableProfiling) { - m_Options = options; - Initialise(); + // Skip the initialization if profiling is disabled return; } - m_Options = options; + + // Register a category for the basic runtime counters + if (!m_CounterDirectory.IsCategoryRegistered("ArmNN_Runtime")) + { + m_CounterDirectory.RegisterCategory("ArmNN_Runtime"); + } + + // Register a counter for the number of loaded networks + if (!m_CounterDirectory.IsCounterRegistered("Loaded networks")) + { + const Counter* loadedNetworksCounter = + m_CounterDirectory.RegisterCounter("ArmNN_Runtime", + 0, + 0, + 1.f, + "Loaded networks", + "The number of networks loaded at runtime", + std::string("networks")); + BOOST_ASSERT(loadedNetworksCounter); + InitializeCounterValue(loadedNetworksCounter->m_Uid); + } + + // Register a counter for the number of registered backends + if (!m_CounterDirectory.IsCounterRegistered("Registered backends")) + { + const Counter* registeredBackendsCounter = + m_CounterDirectory.RegisterCounter("ArmNN_Runtime", + 0, + 0, + 1.f, + "Registered backends", + "The number of registered backends", + std::string("backends")); + BOOST_ASSERT(registeredBackendsCounter); + InitializeCounterValue(registeredBackendsCounter->m_Uid); + } + + // Register a counter for the number of inferences run + if (!m_CounterDirectory.IsCounterRegistered("Inferences run")) + { + const Counter* inferencesRunCounter = + m_CounterDirectory.RegisterCounter("ArmNN_Runtime", + 0, + 0, + 1.f, + "Inferences run", + "The number of inferences run", + std::string("inferences")); + BOOST_ASSERT(inferencesRunCounter); + InitializeCounterValue(inferencesRunCounter->m_Uid); + } + + // Initialization is done, update the profiling service state + m_StateMachine.TransitionToState(ProfilingState::NotConnected); } -inline void ProfilingService::CheckIndexSize(uint16_t counterIndex) const +void ProfilingService::InitializeCounterValue(uint16_t counterUid) { - if (counterIndex >= m_CounterDirectory.GetCounterCount()) + // Increase the size of the counter index if necessary + if (counterUid >= m_CounterIndex.size()) { - throw InvalidArgumentException("Counter index is out of range"); + m_CounterIndex.resize(boost::numeric_cast(counterUid) + 1); } + + // Create a new atomic counter and add it to the list + m_CounterValues.emplace_back(0); + + // Register the new counter to the counter index for quick access + std::atomic* counterValuePtr = &(m_CounterValues.back()); + m_CounterIndex.at(counterUid) = counterValuePtr; } } // namespace profiling -} // namespace armnn \ No newline at end of file +} // namespace armnn diff --git a/src/profiling/ProfilingService.hpp b/src/profiling/ProfilingService.hpp index 6d617978e5..36d95e0b5e 100644 --- a/src/profiling/ProfilingService.hpp +++ b/src/profiling/ProfilingService.hpp @@ -16,38 +16,63 @@ namespace armnn namespace profiling { -class ProfilingService : IWriteCounterValues +class ProfilingService final : public IReadWriteCounterValues { public: - ProfilingService(const Runtime::CreationOptions::ExternalProfilingOptions& options); - ~ProfilingService() = default; + using ExternalProfilingOptions = Runtime::CreationOptions::ExternalProfilingOptions; + using IProfilingConnectionPtr = std::unique_ptr; + using CounterIndices = std::vector*>; + using CounterValues = std::list>; + + // Getter for the singleton instance + static ProfilingService& Instance() + { + static ProfilingService instance; + return instance; + } + + // Resets the profiling options, optionally clears the profiling service entirely + void ResetExternalProfilingOptions(const ExternalProfilingOptions& options, bool resetProfilingService = false); + // Runs the profiling service void Run(); + // Getters for the profiling service state const ICounterDirectory& GetCounterDirectory() const; ProfilingState GetCurrentState() const; - void ResetExternalProfilingOptions(const Runtime::CreationOptions::ExternalProfilingOptions& options); + uint16_t GetCounterCount() const override; + uint32_t GetCounterValue(uint16_t counterUid) const override; - uint16_t GetCounterCount() const; - void GetCounterValue(uint16_t index, uint32_t& value) const; - void SetCounterValue(uint16_t index, uint32_t value); - void AddCounterValue(uint16_t index, uint32_t value); - void SubtractCounterValue(uint16_t index, uint32_t value); - void IncrementCounterValue(uint16_t index); - void DecrementCounterValue(uint16_t index); + // Setters for the profiling service state + void SetCounterValue(uint16_t counterUid, uint32_t value) override; + uint32_t AddCounterValue(uint16_t counterUid, uint32_t value) override; + uint32_t SubtractCounterValue(uint16_t counterUid, uint32_t value) override; + uint32_t IncrementCounterValue(uint16_t counterUid) override; + uint32_t DecrementCounterValue(uint16_t counterUid) override; private: - void Initialise(); - void CheckIndexSize(uint16_t counterIndex) const; + // Default/copy/move constructors/destructors and copy/move assignment operators are kept private + ProfilingService() = default; + ProfilingService(const ProfilingService&) = delete; + ProfilingService(ProfilingService&&) = delete; + ProfilingService& operator=(const ProfilingService&) = delete; + ProfilingService& operator=(ProfilingService&&) = delete; + ~ProfilingService() = default; - CounterDirectory m_CounterDirectory; - ProfilingConnectionFactory m_Factory; - Runtime::CreationOptions::ExternalProfilingOptions m_Options; - ProfilingStateMachine m_State; + // Initialization functions + void Initialize(); + void InitializeCounterValue(uint16_t counterUid); - std::unordered_map> m_CounterIdToValue; + // Profiling service state variables + ExternalProfilingOptions m_Options; + CounterDirectory m_CounterDirectory; + ProfilingConnectionFactory m_ProfilingConnectionFactory; + IProfilingConnectionPtr m_ProfilingConnection; + ProfilingStateMachine m_StateMachine; + CounterIndices m_CounterIndex; + CounterValues m_CounterValues; }; } // namespace profiling -} // namespace armnn \ No newline at end of file +} // namespace armnn diff --git a/src/profiling/ProfilingStateMachine.cpp b/src/profiling/ProfilingStateMachine.cpp index 682e1b8894..5af5bfbed0 100644 --- a/src/profiling/ProfilingStateMachine.cpp +++ b/src/profiling/ProfilingStateMachine.cpp @@ -7,87 +7,89 @@ #include +#include + namespace armnn { namespace profiling { +namespace +{ + +void ThrowStateTransitionException(ProfilingState expectedState, ProfilingState newState) +{ + std::stringstream ss; + ss << "Cannot transition from state [" << GetProfilingStateName(expectedState) << "] " + << "to state [" << GetProfilingStateName(newState) << "]"; + throw armnn::RuntimeException(ss.str()); +} + +} // Anonymous namespace + ProfilingState ProfilingStateMachine::GetCurrentState() const { - return m_State; + return m_State.load(); } void ProfilingStateMachine::TransitionToState(ProfilingState newState) { - switch (newState) - { - case ProfilingState::Uninitialised: - { - ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed); - do { - if (!IsOneOfStates(expectedState, ProfilingState::Uninitialised)) - { - throw armnn::Exception(std::string("Cannot transition from state [") - + GetProfilingStateName(expectedState) - +"] to [" + GetProfilingStateName(newState) + "]"); - } - } while (!m_State.compare_exchange_strong(expectedState, newState, - std::memory_order::memory_order_relaxed)); - - break; - } - case ProfilingState::NotConnected: - { - ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed); - do { - if (!IsOneOfStates(expectedState, ProfilingState::Uninitialised, ProfilingState::NotConnected, - ProfilingState::Active)) - { - throw armnn::Exception(std::string("Cannot transition from state [") - + GetProfilingStateName(expectedState) - +"] to [" + GetProfilingStateName(newState) + "]"); - } - } while (!m_State.compare_exchange_strong(expectedState, newState, - std::memory_order::memory_order_relaxed)); - - break; - } - case ProfilingState::WaitingForAck: - { - ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed); - do { - if (!IsOneOfStates(expectedState, ProfilingState::NotConnected, ProfilingState::WaitingForAck)) - { - throw armnn::Exception(std::string("Cannot transition from state [") - + GetProfilingStateName(expectedState) - +"] to [" + GetProfilingStateName(newState) + "]"); - } - } while (!m_State.compare_exchange_strong(expectedState, newState, - std::memory_order::memory_order_relaxed)); + ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed); - break; - } - case ProfilingState::Active: - { - ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed); - do { - if (!IsOneOfStates(expectedState, ProfilingState::WaitingForAck, ProfilingState::Active)) - { - throw armnn::Exception(std::string("Cannot transition from state [") - + GetProfilingStateName(expectedState) - +"] to [" + GetProfilingStateName(newState) + "]"); - } - } while (!m_State.compare_exchange_strong(expectedState, newState, - std::memory_order::memory_order_relaxed)); + switch (newState) + { + case ProfilingState::Uninitialised: + do + { + if (!IsOneOfStates(expectedState, ProfilingState::Uninitialised)) + { + ThrowStateTransitionException(expectedState, newState); + } + } + while (!m_State.compare_exchange_strong(expectedState, newState, std::memory_order::memory_order_relaxed)); + break; + case ProfilingState::NotConnected: + do + { + if (!IsOneOfStates(expectedState, ProfilingState::Uninitialised, ProfilingState::NotConnected, + ProfilingState::Active)) + { + ThrowStateTransitionException(expectedState, newState); + } + } + while (!m_State.compare_exchange_strong(expectedState, newState, std::memory_order::memory_order_relaxed)); + break; + case ProfilingState::WaitingForAck: + do + { + if (!IsOneOfStates(expectedState, ProfilingState::NotConnected, ProfilingState::WaitingForAck)) + { + ThrowStateTransitionException(expectedState, newState); + } + } + while (!m_State.compare_exchange_strong(expectedState, newState, std::memory_order::memory_order_relaxed)); + break; + case ProfilingState::Active: + do + { + if (!IsOneOfStates(expectedState, ProfilingState::WaitingForAck, ProfilingState::Active)) + { + ThrowStateTransitionException(expectedState, newState); + } + } + while (!m_State.compare_exchange_strong(expectedState, newState, std::memory_order::memory_order_relaxed)); + break; + default: + break; + } +} - break; - } - default: - break; - } +void ProfilingStateMachine::Reset() +{ + m_State.store(ProfilingState::Uninitialised); } -} //namespace profiling +} // namespace profiling -} //namespace armnn \ No newline at end of file +} // namespace armnn diff --git a/src/profiling/ProfilingStateMachine.hpp b/src/profiling/ProfilingStateMachine.hpp index 66f8b2cd17..d070744b1b 100644 --- a/src/profiling/ProfilingStateMachine.hpp +++ b/src/profiling/ProfilingStateMachine.hpp @@ -2,6 +2,7 @@ // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // + #pragma once #include @@ -23,11 +24,12 @@ enum class ProfilingState class ProfilingStateMachine { public: - ProfilingStateMachine(): m_State(ProfilingState::Uninitialised) {}; - ProfilingStateMachine(ProfilingState state): m_State(state) {}; + ProfilingStateMachine() : m_State(ProfilingState::Uninitialised) {} + ProfilingStateMachine(ProfilingState state) : m_State(state) {} ProfilingState GetCurrentState() const; void TransitionToState(ProfilingState newState); + void Reset(); bool IsOneOfStates(ProfilingState state1) { @@ -53,17 +55,17 @@ private: constexpr char const* GetProfilingStateName(ProfilingState state) { - switch(state) + switch (state) { - case ProfilingState::Uninitialised: return "Uninitialised"; - case ProfilingState::NotConnected: return "NotConnected"; - case ProfilingState::WaitingForAck: return "WaitingForAck"; - case ProfilingState::Active: return "Active"; - default: return "Unknown"; + case ProfilingState::Uninitialised: return "Uninitialised"; + case ProfilingState::NotConnected: return "NotConnected"; + case ProfilingState::WaitingForAck: return "WaitingForAck"; + case ProfilingState::Active: return "Active"; + default: return "Unknown"; } } -} //namespace profiling +} // namespace profiling -} //namespace armnn +} // namespace armnn diff --git a/src/profiling/SendCounterPacket.cpp b/src/profiling/SendCounterPacket.cpp index 9aafa2ccc8..7f3696a940 100644 --- a/src/profiling/SendCounterPacket.cpp +++ b/src/profiling/SendCounterPacket.cpp @@ -903,7 +903,7 @@ void SendCounterPacket::SetReadyToRead() void SendCounterPacket::Start() { - // Check is the send thread is already running + // Check if the send thread is already running if (m_IsRunning.load()) { // The send thread is already running diff --git a/src/profiling/SocketProfilingConnection.cpp b/src/profiling/SocketProfilingConnection.cpp index 45b7f9dcfe..6955f70a48 100644 --- a/src/profiling/SocketProfilingConnection.cpp +++ b/src/profiling/SocketProfilingConnection.cpp @@ -23,7 +23,7 @@ SocketProfilingConnection::SocketProfilingConnection() m_Socket[0].fd = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0); if (m_Socket[0].fd == -1) { - throw armnn::RuntimeException(std::string("Socket construction failed: ") + strerror(errno)); + throw RuntimeException(std::string("Socket construction failed: ") + strerror(errno)); } // Connect to the named unix domain socket. @@ -35,7 +35,7 @@ SocketProfilingConnection::SocketProfilingConnection() if (0 != connect(m_Socket[0].fd, reinterpret_cast(&server), sizeof(sockaddr_un))) { close(m_Socket[0].fd); - throw armnn::RuntimeException(std::string("Cannot connect to stream socket: ") + strerror(errno)); + throw RuntimeException(std::string("Cannot connect to stream socket: ") + strerror(errno)); } // Our socket will only be interested in polling reads. @@ -46,7 +46,7 @@ SocketProfilingConnection::SocketProfilingConnection() if (0 != fcntl(m_Socket[0].fd, F_SETFL, currentFlags | O_NONBLOCK)) { close(m_Socket[0].fd); - throw armnn::RuntimeException(std::string("Failed to set socket as non blocking: ") + strerror(errno)); + throw RuntimeException(std::string("Failed to set socket as non blocking: ") + strerror(errno)); } } @@ -59,7 +59,7 @@ void SocketProfilingConnection::Close() { if (close(m_Socket[0].fd) != 0) { - throw armnn::RuntimeException(std::string("Cannot close stream socket: ") + strerror(errno)); + throw RuntimeException(std::string("Cannot close stream socket: ") + strerror(errno)); } memset(m_Socket, 0, sizeof(m_Socket)); @@ -83,17 +83,17 @@ Packet SocketProfilingConnection::ReadPacket(uint32_t timeout) switch (pollResult) { case -1: // Error - throw armnn::RuntimeException(std::string("Read failure from socket: ") + strerror(errno)); + throw RuntimeException(std::string("Read failure from socket: ") + strerror(errno)); case 0: // Timeout - throw armnn::RuntimeException("Timeout while reading from socket"); + throw RuntimeException("Timeout while reading from socket"); default: // Normal poll return but it could still contain an error signal // Check if the socket reported an error if (m_Socket[0].revents & (POLLNVAL | POLLERR | POLLHUP)) { - throw armnn::Exception(std::string("Socket 0 reported an error: ") + strerror(errno)); + throw Exception(std::string("Socket 0 reported an error: ") + strerror(errno)); } // Check if there is data to read @@ -108,7 +108,7 @@ Packet SocketProfilingConnection::ReadPacket(uint32_t timeout) if (8 != recv(m_Socket[0].fd, &header, sizeof(header), 0)) { // What do we do here if there's not a valid 8 byte header to read? - throw armnn::RuntimeException("The received packet did not contains a valid MIPE header"); + throw RuntimeException("The received packet did not contains a valid MIPE header"); } // stream_metadata_identifier is the first 4 bytes @@ -133,7 +133,7 @@ Packet SocketProfilingConnection::ReadPacket(uint32_t timeout) if (dataLength != static_cast(receivedLength)) { // What do we do here if we can't read in a full packet? - throw armnn::RuntimeException("Invalid MIPE packet"); + throw RuntimeException("Invalid MIPE packet"); } return Packet(metadataIdentifier, dataLength, packetData); diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index 5ef9811b2c..d14791c43d 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -659,19 +659,18 @@ BOOST_AUTO_TEST_CASE(CaptureDataMethods) BOOST_AUTO_TEST_CASE(CheckProfilingServiceDisabled) { armnn::Runtime::CreationOptions::ExternalProfilingOptions options; - ProfilingService service(options); - BOOST_CHECK(service.GetCurrentState() == ProfilingState::Uninitialised); - service.Run(); - BOOST_CHECK(service.GetCurrentState() == ProfilingState::Uninitialised); + ProfilingService& profilingService = ProfilingService::Instance(); + profilingService.ResetExternalProfilingOptions(options, true); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); + profilingService.Run(); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); } -struct cerr_redirect { +struct cerr_redirect +{ cerr_redirect(std::streambuf* new_buffer) - : old( std::cerr.rdbuf(new_buffer)) {} - - ~cerr_redirect( ) { - std::cerr.rdbuf(old); - } + : old(std::cerr.rdbuf(new_buffer)) {} + ~cerr_redirect() { std::cerr.rdbuf(old); } private: std::streambuf* old; @@ -681,46 +680,49 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabled) { armnn::Runtime::CreationOptions::ExternalProfilingOptions options; options.m_EnableProfiling = true; - ProfilingService service(options); - BOOST_CHECK(service.GetCurrentState() == ProfilingState::NotConnected); + ProfilingService& profilingService = ProfilingService::Instance(); + profilingService.ResetExternalProfilingOptions(options, true); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); // As there is no daemon running a connection cannot be made so expect a std::cerr to console std::stringstream ss; cerr_redirect guard(ss.rdbuf()); - service.Run(); + profilingService.Run(); BOOST_CHECK(boost::contains(ss.str(), "Cannot connect to stream socket: Connection refused")); } BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabledRuntime) { armnn::Runtime::CreationOptions::ExternalProfilingOptions options; - ProfilingService service(options); - BOOST_CHECK(service.GetCurrentState() == ProfilingState::Uninitialised); - service.Run(); - BOOST_CHECK(service.GetCurrentState() == ProfilingState::Uninitialised); + ProfilingService& profilingService = ProfilingService::Instance(); + profilingService.ResetExternalProfilingOptions(options, true); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); + profilingService.Run(); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); options.m_EnableProfiling = true; - service.ResetExternalProfilingOptions(options); - BOOST_CHECK(service.GetCurrentState() == ProfilingState::NotConnected); + profilingService.ResetExternalProfilingOptions(options); + BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected); // As there is no daemon running a connection cannot be made so expect a std::cerr to console std::stringstream ss; cerr_redirect guard(ss.rdbuf()); - service.Run(); + profilingService.Run(); BOOST_CHECK(boost::contains(ss.str(), "Cannot connect to stream socket: Connection refused")); } BOOST_AUTO_TEST_CASE(CheckProfilingServiceCounterDirectory) { armnn::Runtime::CreationOptions::ExternalProfilingOptions options; - ProfilingService service(options); + ProfilingService& profilingService = ProfilingService::Instance(); + profilingService.ResetExternalProfilingOptions(options, true); - const ICounterDirectory& counterDirectory0 = service.GetCounterDirectory(); + const ICounterDirectory& counterDirectory0 = profilingService.GetCounterDirectory(); BOOST_CHECK(counterDirectory0.GetCounterCount() == 0); options.m_EnableProfiling = true; - service.ResetExternalProfilingOptions(options); + profilingService.ResetExternalProfilingOptions(options); - const ICounterDirectory& counterDirectory1 = service.GetCounterDirectory(); + const ICounterDirectory& counterDirectory1 = profilingService.GetCounterDirectory(); BOOST_CHECK(counterDirectory1.GetCounterCount() != 0); } @@ -728,32 +730,38 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceCounterValues) { armnn::Runtime::CreationOptions::ExternalProfilingOptions options; options.m_EnableProfiling = true; - ProfilingService profilingService(options); + ProfilingService& profilingService = ProfilingService::Instance(); + profilingService.ResetExternalProfilingOptions(options, true); + + const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory(); + const Counters& counters = counterDirectory.GetCounters(); + BOOST_CHECK(!counters.empty()); + + // Get the UID of the first counter for testing + uint16_t counterUid = counters.begin()->first; ProfilingService* profilingServicePtr = &profilingService; std::vector writers; - for(int i = 0; i < 100 ; ++i) + for (int i = 0; i < 100 ; ++i) { - // Increment and decrement counter 0 - writers.push_back(std::thread(&ProfilingService::IncrementCounterValue, profilingServicePtr, 0)); - writers.push_back(std::thread(&ProfilingService::DecrementCounterValue, profilingServicePtr, 0)); - // Add 10 to counter 0 and subtract 5 from counter 0 - writers.push_back(std::thread(&ProfilingService::AddCounterValue, profilingServicePtr, 0, 10)); - writers.push_back(std::thread(&ProfilingService::SubtractCounterValue, profilingServicePtr, 0, 5)); + // Increment and decrement the first counter + writers.push_back(std::thread(&ProfilingService::IncrementCounterValue, profilingServicePtr, counterUid)); + writers.push_back(std::thread(&ProfilingService::DecrementCounterValue, profilingServicePtr, counterUid)); + // Add 10 and subtract 5 from the first counter + writers.push_back(std::thread(&ProfilingService::AddCounterValue, profilingServicePtr, counterUid, 10)); + writers.push_back(std::thread(&ProfilingService::SubtractCounterValue, profilingServicePtr, counterUid, 5)); } std::for_each(writers.begin(), writers.end(), mem_fn(&std::thread::join)); - uint32_t counterValue; - profilingService.GetCounterValue(0, counterValue); + uint32_t counterValue = 0; + BOOST_CHECK_NO_THROW(counterValue = profilingService.GetCounterValue(counterUid)); BOOST_CHECK(counterValue == 500); - profilingService.SetCounterValue(0, 0); - profilingService.GetCounterValue(0, counterValue); + BOOST_CHECK_NO_THROW(profilingService.SetCounterValue(counterUid, 0)); + BOOST_CHECK_NO_THROW(counterValue = profilingService.GetCounterValue(counterUid)); BOOST_CHECK(counterValue == 0); - - BOOST_CHECK_THROW(profilingService.SetCounterValue(profilingService.GetCounterCount(), 1), armnn::Exception); } BOOST_AUTO_TEST_CASE(CheckProfilingObjectUids) -- cgit v1.2.1