aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2019-10-02 12:50:57 +0100
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-10-07 10:34:54 +0100
commita84edee4702c112a6e004b1987acc11144e2d6dd (patch)
tree738ce957b2fa26423df188b0d370664d15c86665
parentd66d68b13fb309e8d4eac9435a58b89dd6a55158 (diff)
downloadarmnn-a84edee4702c112a6e004b1987acc11144e2d6dd.tar.gz
IVGCVSW-3937 Initial ServiceProfiling refactoring
* Made the ServiceProfiling class a singleton * Registered basic category and counters * Code refactoring * Updated unit tests accordingly Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com> Change-Id: I648a6202eead2a3016aac14d905511bd945a90cb
-rw-r--r--src/profiling/CounterDirectory.cpp108
-rw-r--r--src/profiling/CounterDirectory.hpp24
-rw-r--r--src/profiling/CounterValues.hpp26
-rw-r--r--src/profiling/ProfilingService.cpp201
-rw-r--r--src/profiling/ProfilingService.hpp63
-rw-r--r--src/profiling/ProfilingStateMachine.cpp136
-rw-r--r--src/profiling/ProfilingStateMachine.hpp22
-rw-r--r--src/profiling/SendCounterPacket.cpp2
-rw-r--r--src/profiling/SocketProfilingConnection.cpp18
-rw-r--r--src/profiling/test/ProfilingTests.cpp84
10 files changed, 411 insertions, 273 deletions
diff --git a/src/profiling/CounterDirectory.cpp b/src/profiling/CounterDirectory.cpp
index cef3d6a76d..979b8046be 100644
--- a/src/profiling/CounterDirectory.cpp
+++ b/src/profiling/CounterDirectory.cpp
@@ -29,7 +29,7 @@ const Category* CounterDirectory::RegisterCategory(const std::string& categoryNa
}
// Check that the given category is not already registered
- if (CheckIfCategoryIsRegistered(categoryName))
+ if (IsCategoryRegistered(categoryName))
{
throw InvalidArgumentException(
boost::str(boost::format("Trying to register a category already registered (\"%1%\")")
@@ -41,7 +41,7 @@ const Category* CounterDirectory::RegisterCategory(const std::string& categoryNa
if (deviceUidValue > 0)
{
// Check that the (optional) device is already registered
- if (!CheckIfDeviceIsRegistered(deviceUidValue))
+ if (!IsDeviceRegistered(deviceUidValue))
{
throw InvalidArgumentException(
boost::str(boost::format("Trying to connect a category (\"%1%\") to a device that is "
@@ -56,7 +56,7 @@ const Category* CounterDirectory::RegisterCategory(const std::string& categoryNa
if (counterSetUidValue > 0)
{
// Check that the (optional) counter set is already registered
- if (!CheckIfCounterSetIsRegistered(counterSetUidValue))
+ if (!IsCounterSetRegistered(counterSetUidValue))
{
throw InvalidArgumentException(
boost::str(boost::format("Trying to connect a category (name: \"%1%\") to a counter set "
@@ -92,7 +92,7 @@ const Device* CounterDirectory::RegisterDevice(const std::string& deviceName,
}
// Check that a device with the given name is not already registered
- if (CheckIfDeviceIsRegistered(deviceName))
+ if (IsDeviceRegistered(deviceName))
{
throw InvalidArgumentException(
boost::str(boost::format("Trying to register a device already registered (\"%1%\")")
@@ -188,7 +188,7 @@ const CounterSet* CounterDirectory::RegisterCounterSet(const std::string& counte
}
// Check that a counter set with the given name is not already registered
- if (CheckIfCounterSetIsRegistered(counterSetName))
+ if (IsCounterSetRegistered(counterSetName))
{
throw InvalidArgumentException(
boost::str(boost::format("Trying to register a counter set already registered (\"%1%\")")
@@ -365,7 +365,7 @@ const Counter* CounterDirectory::RegisterCounter(const std::string& parentCatego
if (counterSetUidValue > 0)
{
// Check that the (optional) counter set is already registered
- if (!CheckIfCounterSetIsRegistered(counterSetUidValue))
+ if (!IsCounterSetRegistered(counterSetUidValue))
{
throw InvalidArgumentException(
boost::str(boost::format("Trying to connect a counter to a counter set that is "
@@ -476,6 +476,64 @@ const Counter* CounterDirectory::GetCounter(uint16_t counterUid) const
return counter;
}
+bool CounterDirectory::IsCategoryRegistered(const std::string& categoryName) const
+{
+ auto it = FindCategory(categoryName);
+
+ return it != m_Categories.end();
+}
+
+bool CounterDirectory::IsDeviceRegistered(uint16_t deviceUid) const
+{
+ auto it = FindDevice(deviceUid);
+
+ return it != m_Devices.end();
+}
+
+bool CounterDirectory::IsDeviceRegistered(const std::string& deviceName) const
+{
+ auto it = FindDevice(deviceName);
+
+ return it != m_Devices.end();
+}
+
+bool CounterDirectory::IsCounterSetRegistered(uint16_t counterSetUid) const
+{
+ auto it = FindCounterSet(counterSetUid);
+
+ return it != m_CounterSets.end();
+}
+
+bool CounterDirectory::IsCounterSetRegistered(const std::string& counterSetName) const
+{
+ auto it = FindCounterSet(counterSetName);
+
+ return it != m_CounterSets.end();
+}
+
+bool CounterDirectory::IsCounterRegistered(uint16_t counterUid) const
+{
+ auto it = FindCounter(counterUid);
+
+ return it != m_Counters.end();
+}
+
+bool CounterDirectory::IsCounterRegistered(const std::string& counterName) const
+{
+ auto it = FindCounter(counterName);
+
+ return it != m_Counters.end();
+}
+
+void CounterDirectory::Clear()
+{
+ // Clear all the counter directory contents
+ m_Categories.clear();
+ m_Devices.clear();
+ m_CounterSets.clear();
+ m_Counters.clear();
+}
+
CategoriesIt CounterDirectory::FindCategory(const std::string& categoryName) const
{
return std::find_if(m_Categories.begin(), m_Categories.end(), [&categoryName](const CategoryPtr& category)
@@ -523,39 +581,15 @@ CountersIt CounterDirectory::FindCounter(uint16_t counterUid) const
return m_Counters.find(counterUid);
}
-bool CounterDirectory::CheckIfCategoryIsRegistered(const std::string& categoryName) const
-{
- auto it = FindCategory(categoryName);
-
- return it != m_Categories.end();
-}
-
-bool CounterDirectory::CheckIfDeviceIsRegistered(uint16_t deviceUid) const
-{
- auto it = FindDevice(deviceUid);
-
- return it != m_Devices.end();
-}
-
-bool CounterDirectory::CheckIfDeviceIsRegistered(const std::string& deviceName) const
+CountersIt CounterDirectory::FindCounter(const std::string& counterName) const
{
- auto it = FindDevice(deviceName);
-
- return it != m_Devices.end();
-}
-
-bool CounterDirectory::CheckIfCounterSetIsRegistered(uint16_t counterSetUid) const
-{
- auto it = FindCounterSet(counterSetUid);
-
- return it != m_CounterSets.end();
-}
-
-bool CounterDirectory::CheckIfCounterSetIsRegistered(const std::string& counterSetName) const
-{
- auto it = FindCounterSet(counterSetName);
+ return std::find_if(m_Counters.begin(), m_Counters.end(), [&counterName](const auto& pair)
+ {
+ BOOST_ASSERT(pair.second);
+ BOOST_ASSERT(pair.second->m_Uid == pair.first);
- return it != m_CounterSets.end();
+ return pair.second->m_Name == counterName;
+ });
}
uint16_t CounterDirectory::GetNumberOfCores(const Optional<uint16_t>& numberOfCores,
diff --git a/src/profiling/CounterDirectory.hpp b/src/profiling/CounterDirectory.hpp
index a756a9a7bd..bff5cfef98 100644
--- a/src/profiling/CounterDirectory.hpp
+++ b/src/profiling/CounterDirectory.hpp
@@ -66,6 +66,18 @@ public:
const CounterSet* GetCounterSet(uint16_t uid) const override;
const Counter* GetCounter(uint16_t uid) const override;
+ // Queries for profiling objects
+ bool IsCategoryRegistered(const std::string& categoryName) const;
+ bool IsDeviceRegistered(uint16_t deviceUid) const;
+ bool IsDeviceRegistered(const std::string& deviceName) const;
+ bool IsCounterSetRegistered(uint16_t counterSetUid) const;
+ bool IsCounterSetRegistered(const std::string& counterSetName) const;
+ bool IsCounterRegistered(uint16_t counterUid) const;
+ bool IsCounterRegistered(const std::string& counterName) const;
+
+ // Clears all the counter directory contents
+ void Clear();
+
private:
// The profiling collections owned by the counter directory
Categories m_Categories;
@@ -80,14 +92,10 @@ private:
CounterSetsIt FindCounterSet(uint16_t counterSetUid) const;
CounterSetsIt FindCounterSet(const std::string& counterSetName) const;
CountersIt FindCounter(uint16_t counterUid) const;
- bool CheckIfCategoryIsRegistered(const std::string& categoryName) const;
- bool CheckIfDeviceIsRegistered(uint16_t deviceUid) const;
- bool CheckIfDeviceIsRegistered(const std::string& deviceName) const;
- bool CheckIfCounterSetIsRegistered(uint16_t counterSetUid) const;
- bool CheckIfCounterSetIsRegistered(const std::string& counterSetName) const;
- uint16_t GetNumberOfCores(const Optional<uint16_t>& numberOfCores,
- uint16_t deviceUid,
- const CategoryPtr& parentCategory);
+ CountersIt FindCounter(const std::string& counterName) const;
+ uint16_t GetNumberOfCores(const Optional<uint16_t>& numberOfCores,
+ uint16_t deviceUid,
+ const CategoryPtr& parentCategory);
};
} // namespace profiling
diff --git a/src/profiling/CounterValues.hpp b/src/profiling/CounterValues.hpp
index 75ecad9961..9c06ff0a7d 100644
--- a/src/profiling/CounterValues.hpp
+++ b/src/profiling/CounterValues.hpp
@@ -15,24 +15,30 @@ namespace profiling
class IReadCounterValues
{
public:
- virtual uint16_t GetCounterCount() const = 0;
- virtual void GetCounterValue(uint16_t index, uint32_t& value) const = 0;
virtual ~IReadCounterValues() {}
+
+ virtual uint16_t GetCounterCount() const = 0;
+ virtual uint32_t GetCounterValue(uint16_t counterUid) const = 0;
};
-class IWriteCounterValues : public IReadCounterValues
+class IWriteCounterValues
{
public:
- virtual void SetCounterValue(uint16_t index, uint32_t value) = 0;
- virtual void AddCounterValue(uint16_t index, uint32_t value) = 0;
- virtual void SubtractCounterValue(uint16_t index, uint32_t value) = 0;
- virtual void IncrementCounterValue(uint16_t index) = 0;
- virtual void DecrementCounterValue(uint16_t index) = 0;
virtual ~IWriteCounterValues() {}
+
+ virtual void SetCounterValue(uint16_t counterUid, uint32_t value) = 0;
+ virtual uint32_t AddCounterValue(uint16_t counterUid, uint32_t value) = 0;
+ virtual uint32_t SubtractCounterValue(uint16_t counterUid, uint32_t value) = 0;
+ virtual uint32_t IncrementCounterValue(uint16_t counterUid) = 0;
+ virtual uint32_t DecrementCounterValue(uint16_t counterUid) = 0;
+};
+
+class IReadWriteCounterValues : public IReadCounterValues, public IWriteCounterValues
+{
+public:
+ virtual ~IReadWriteCounterValues() {}
};
} // namespace profiling
} // namespace armnn
-
-
diff --git a/src/profiling/ProfilingService.cpp b/src/profiling/ProfilingService.cpp
index 786bfae12e..2da0f79da2 100644
--- a/src/profiling/ProfilingService.cpp
+++ b/src/profiling/ProfilingService.cpp
@@ -5,72 +5,53 @@
#include "ProfilingService.hpp"
+#include <boost/log/trivial.hpp>
+#include <boost/format.hpp>
+
namespace armnn
{
namespace profiling
{
-ProfilingService::ProfilingService(const Runtime::CreationOptions::ExternalProfilingOptions& options)
- : m_Options(options)
+void ProfilingService::ResetExternalProfilingOptions(const ExternalProfilingOptions& options,
+ bool resetProfilingService)
{
- Initialise();
-}
+ // Update the profiling options
+ m_Options = options;
-void ProfilingService::Initialise()
-{
- if (m_Options.m_EnableProfiling == true)
+ if (resetProfilingService)
{
- // Setup provisional Counter Directory example - this should only be created if profiling is enabled
- // Setup provisional Counter meta example
- const std::string categoryName = "Category";
-
- m_CounterDirectory.RegisterCategory(categoryName);
- m_CounterDirectory.RegisterDevice("device name", 0, categoryName);
- m_CounterDirectory.RegisterCounterSet("counterSet_name", 2, categoryName);
-
- m_CounterDirectory.RegisterCounter(categoryName,
- 0,
- 1,
- 123.45f,
- "counter name 1",
- "counter description");
-
- m_CounterDirectory.RegisterCounter(categoryName,
- 0,
- 1,
- 123.45f,
- "counter name 2",
- "counter description");
-
- for (unsigned short i = 0; i < m_CounterDirectory.GetCounterCount(); ++i)
- {
- m_CounterIdToValue[i] = 0;
- }
-
- // For now until CounterDirectory setup is implemented, change m_State once everything initialised
- m_State.TransitionToState(ProfilingState::NotConnected);
+ // Reset the profiling service
+ m_CounterDirectory.Clear();
+ m_ProfilingConnection.reset();
+ m_StateMachine.Reset();
+ m_CounterIndex.clear();
+ m_CounterValues.clear();
}
+
+ // Re-initialize the profiling service
+ Initialize();
}
void ProfilingService::Run()
{
- if (m_State.GetCurrentState() == ProfilingState::NotConnected)
+ if (m_StateMachine.GetCurrentState() == ProfilingState::Uninitialised)
+ {
+ Initialize();
+ }
+ else if (m_StateMachine.GetCurrentState() == ProfilingState::NotConnected)
{
try
{
- m_Factory.GetProfilingConnection(m_Options);
- m_State.TransitionToState(ProfilingState::WaitingForAck);
+ m_ProfilingConnectionFactory.GetProfilingConnection(m_Options);
+ m_StateMachine.TransitionToState(ProfilingState::WaitingForAck);
}
catch (const armnn::Exception& e)
{
std::cerr << e.what() << std::endl;
}
}
- else if (m_State.GetCurrentState() == ProfilingState::Uninitialised && m_Options.m_EnableProfiling == true)
- {
- Initialise();
- }
}
const ICounterDirectory& ProfilingService::GetCounterDirectory() const
@@ -78,71 +59,143 @@ const ICounterDirectory& ProfilingService::GetCounterDirectory() const
return m_CounterDirectory;
}
-void ProfilingService::SetCounterValue(uint16_t counterIndex, uint32_t value)
+ProfilingState ProfilingService::GetCurrentState() const
{
- CheckIndexSize(counterIndex);
- m_CounterIdToValue.at(counterIndex).store(value, std::memory_order::memory_order_relaxed);
+ return m_StateMachine.GetCurrentState();
}
-void ProfilingService::GetCounterValue(uint16_t counterIndex, uint32_t& value) const
+uint16_t ProfilingService::GetCounterCount() const
{
- CheckIndexSize(counterIndex);
- value = m_CounterIdToValue.at(counterIndex).load(std::memory_order::memory_order_relaxed);
+ return m_CounterDirectory.GetCounterCount();
}
-void ProfilingService::AddCounterValue(uint16_t counterIndex, uint32_t value)
+uint32_t ProfilingService::GetCounterValue(uint16_t counterUid) const
{
- CheckIndexSize(counterIndex);
- m_CounterIdToValue.at(counterIndex).fetch_add(value, std::memory_order::memory_order_relaxed);
+ BOOST_ASSERT(counterUid < m_CounterIndex.size());
+ std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
+ BOOST_ASSERT(counterValuePtr);
+ return counterValuePtr->load(std::memory_order::memory_order_relaxed);
}
-void ProfilingService::SubtractCounterValue(uint16_t counterIndex, uint32_t value)
+void ProfilingService::SetCounterValue(uint16_t counterUid, uint32_t value)
{
- CheckIndexSize(counterIndex);
- m_CounterIdToValue.at(counterIndex).fetch_sub(value, std::memory_order::memory_order_relaxed);
+ BOOST_ASSERT(counterUid < m_CounterIndex.size());
+ std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
+ BOOST_ASSERT(counterValuePtr);
+ counterValuePtr->store(value, std::memory_order::memory_order_relaxed);
}
-void ProfilingService::IncrementCounterValue(uint16_t counterIndex)
+uint32_t ProfilingService::AddCounterValue(uint16_t counterUid, uint32_t value)
{
- CheckIndexSize(counterIndex);
- m_CounterIdToValue.at(counterIndex).operator++(std::memory_order::memory_order_relaxed);
+ BOOST_ASSERT(counterUid < m_CounterIndex.size());
+ std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
+ BOOST_ASSERT(counterValuePtr);
+ return counterValuePtr->fetch_add(value, std::memory_order::memory_order_relaxed);
}
-void ProfilingService::DecrementCounterValue(uint16_t counterIndex)
+uint32_t ProfilingService::SubtractCounterValue(uint16_t counterUid, uint32_t value)
{
- CheckIndexSize(counterIndex);
- m_CounterIdToValue.at(counterIndex).operator--(std::memory_order::memory_order_relaxed);
+ BOOST_ASSERT(counterUid < m_CounterIndex.size());
+ std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
+ BOOST_ASSERT(counterValuePtr);
+ return counterValuePtr->fetch_sub(value, std::memory_order::memory_order_relaxed);
}
-uint16_t ProfilingService::GetCounterCount() const
+uint32_t ProfilingService::IncrementCounterValue(uint16_t counterUid)
{
- return m_CounterDirectory.GetCounterCount();
+ BOOST_ASSERT(counterUid < m_CounterIndex.size());
+ std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
+ BOOST_ASSERT(counterValuePtr);
+ return counterValuePtr->operator++(std::memory_order::memory_order_relaxed);
}
-ProfilingState ProfilingService::GetCurrentState() const
+uint32_t ProfilingService::DecrementCounterValue(uint16_t counterUid)
{
- return m_State.GetCurrentState();
+ BOOST_ASSERT(counterUid < m_CounterIndex.size());
+ std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
+ BOOST_ASSERT(counterValuePtr);
+ return counterValuePtr->operator--(std::memory_order::memory_order_relaxed);
}
-void ProfilingService::ResetExternalProfilingOptions(const Runtime::CreationOptions::ExternalProfilingOptions& options)
+void ProfilingService::Initialize()
{
- if(!m_Options.m_EnableProfiling)
+ if (!m_Options.m_EnableProfiling)
{
- m_Options = options;
- Initialise();
+ // Skip the initialization if profiling is disabled
return;
}
- m_Options = options;
+
+ // Register a category for the basic runtime counters
+ if (!m_CounterDirectory.IsCategoryRegistered("ArmNN_Runtime"))
+ {
+ m_CounterDirectory.RegisterCategory("ArmNN_Runtime");
+ }
+
+ // Register a counter for the number of loaded networks
+ if (!m_CounterDirectory.IsCounterRegistered("Loaded networks"))
+ {
+ const Counter* loadedNetworksCounter =
+ m_CounterDirectory.RegisterCounter("ArmNN_Runtime",
+ 0,
+ 0,
+ 1.f,
+ "Loaded networks",
+ "The number of networks loaded at runtime",
+ std::string("networks"));
+ BOOST_ASSERT(loadedNetworksCounter);
+ InitializeCounterValue(loadedNetworksCounter->m_Uid);
+ }
+
+ // Register a counter for the number of registered backends
+ if (!m_CounterDirectory.IsCounterRegistered("Registered backends"))
+ {
+ const Counter* registeredBackendsCounter =
+ m_CounterDirectory.RegisterCounter("ArmNN_Runtime",
+ 0,
+ 0,
+ 1.f,
+ "Registered backends",
+ "The number of registered backends",
+ std::string("backends"));
+ BOOST_ASSERT(registeredBackendsCounter);
+ InitializeCounterValue(registeredBackendsCounter->m_Uid);
+ }
+
+ // Register a counter for the number of inferences run
+ if (!m_CounterDirectory.IsCounterRegistered("Inferences run"))
+ {
+ const Counter* inferencesRunCounter =
+ m_CounterDirectory.RegisterCounter("ArmNN_Runtime",
+ 0,
+ 0,
+ 1.f,
+ "Inferences run",
+ "The number of inferences run",
+ std::string("inferences"));
+ BOOST_ASSERT(inferencesRunCounter);
+ InitializeCounterValue(inferencesRunCounter->m_Uid);
+ }
+
+ // Initialization is done, update the profiling service state
+ m_StateMachine.TransitionToState(ProfilingState::NotConnected);
}
-inline void ProfilingService::CheckIndexSize(uint16_t counterIndex) const
+void ProfilingService::InitializeCounterValue(uint16_t counterUid)
{
- if (counterIndex >= m_CounterDirectory.GetCounterCount())
+ // Increase the size of the counter index if necessary
+ if (counterUid >= m_CounterIndex.size())
{
- throw InvalidArgumentException("Counter index is out of range");
+ m_CounterIndex.resize(boost::numeric_cast<size_t>(counterUid) + 1);
}
+
+ // Create a new atomic counter and add it to the list
+ m_CounterValues.emplace_back(0);
+
+ // Register the new counter to the counter index for quick access
+ std::atomic<uint32_t>* counterValuePtr = &(m_CounterValues.back());
+ m_CounterIndex.at(counterUid) = counterValuePtr;
}
} // namespace profiling
-} // namespace armnn \ No newline at end of file
+} // namespace armnn
diff --git a/src/profiling/ProfilingService.hpp b/src/profiling/ProfilingService.hpp
index 6d617978e5..36d95e0b5e 100644
--- a/src/profiling/ProfilingService.hpp
+++ b/src/profiling/ProfilingService.hpp
@@ -16,38 +16,63 @@ namespace armnn
namespace profiling
{
-class ProfilingService : IWriteCounterValues
+class ProfilingService final : public IReadWriteCounterValues
{
public:
- ProfilingService(const Runtime::CreationOptions::ExternalProfilingOptions& options);
- ~ProfilingService() = default;
+ using ExternalProfilingOptions = Runtime::CreationOptions::ExternalProfilingOptions;
+ using IProfilingConnectionPtr = std::unique_ptr<IProfilingConnection>;
+ using CounterIndices = std::vector<std::atomic<uint32_t>*>;
+ using CounterValues = std::list<std::atomic<uint32_t>>;
+
+ // Getter for the singleton instance
+ static ProfilingService& Instance()
+ {
+ static ProfilingService instance;
+ return instance;
+ }
+
+ // Resets the profiling options, optionally clears the profiling service entirely
+ void ResetExternalProfilingOptions(const ExternalProfilingOptions& options, bool resetProfilingService = false);
+ // Runs the profiling service
void Run();
+ // Getters for the profiling service state
const ICounterDirectory& GetCounterDirectory() const;
ProfilingState GetCurrentState() const;
- void ResetExternalProfilingOptions(const Runtime::CreationOptions::ExternalProfilingOptions& options);
+ uint16_t GetCounterCount() const override;
+ uint32_t GetCounterValue(uint16_t counterUid) const override;
- uint16_t GetCounterCount() const;
- void GetCounterValue(uint16_t index, uint32_t& value) const;
- void SetCounterValue(uint16_t index, uint32_t value);
- void AddCounterValue(uint16_t index, uint32_t value);
- void SubtractCounterValue(uint16_t index, uint32_t value);
- void IncrementCounterValue(uint16_t index);
- void DecrementCounterValue(uint16_t index);
+ // Setters for the profiling service state
+ void SetCounterValue(uint16_t counterUid, uint32_t value) override;
+ uint32_t AddCounterValue(uint16_t counterUid, uint32_t value) override;
+ uint32_t SubtractCounterValue(uint16_t counterUid, uint32_t value) override;
+ uint32_t IncrementCounterValue(uint16_t counterUid) override;
+ uint32_t DecrementCounterValue(uint16_t counterUid) override;
private:
- void Initialise();
- void CheckIndexSize(uint16_t counterIndex) const;
+ // Default/copy/move constructors/destructors and copy/move assignment operators are kept private
+ ProfilingService() = default;
+ ProfilingService(const ProfilingService&) = delete;
+ ProfilingService(ProfilingService&&) = delete;
+ ProfilingService& operator=(const ProfilingService&) = delete;
+ ProfilingService& operator=(ProfilingService&&) = delete;
+ ~ProfilingService() = default;
- CounterDirectory m_CounterDirectory;
- ProfilingConnectionFactory m_Factory;
- Runtime::CreationOptions::ExternalProfilingOptions m_Options;
- ProfilingStateMachine m_State;
+ // Initialization functions
+ void Initialize();
+ void InitializeCounterValue(uint16_t counterUid);
- std::unordered_map<uint16_t, std::atomic<uint32_t>> m_CounterIdToValue;
+ // Profiling service state variables
+ ExternalProfilingOptions m_Options;
+ CounterDirectory m_CounterDirectory;
+ ProfilingConnectionFactory m_ProfilingConnectionFactory;
+ IProfilingConnectionPtr m_ProfilingConnection;
+ ProfilingStateMachine m_StateMachine;
+ CounterIndices m_CounterIndex;
+ CounterValues m_CounterValues;
};
} // namespace profiling
-} // namespace armnn \ No newline at end of file
+} // namespace armnn
diff --git a/src/profiling/ProfilingStateMachine.cpp b/src/profiling/ProfilingStateMachine.cpp
index 682e1b8894..5af5bfbed0 100644
--- a/src/profiling/ProfilingStateMachine.cpp
+++ b/src/profiling/ProfilingStateMachine.cpp
@@ -7,87 +7,89 @@
#include <armnn/Exceptions.hpp>
+#include <sstream>
+
namespace armnn
{
namespace profiling
{
+namespace
+{
+
+void ThrowStateTransitionException(ProfilingState expectedState, ProfilingState newState)
+{
+ std::stringstream ss;
+ ss << "Cannot transition from state [" << GetProfilingStateName(expectedState) << "] "
+ << "to state [" << GetProfilingStateName(newState) << "]";
+ throw armnn::RuntimeException(ss.str());
+}
+
+} // Anonymous namespace
+
ProfilingState ProfilingStateMachine::GetCurrentState() const
{
- return m_State;
+ return m_State.load();
}
void ProfilingStateMachine::TransitionToState(ProfilingState newState)
{
- switch (newState)
- {
- case ProfilingState::Uninitialised:
- {
- ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed);
- do {
- if (!IsOneOfStates(expectedState, ProfilingState::Uninitialised))
- {
- throw armnn::Exception(std::string("Cannot transition from state [")
- + GetProfilingStateName(expectedState)
- +"] to [" + GetProfilingStateName(newState) + "]");
- }
- } while (!m_State.compare_exchange_strong(expectedState, newState,
- std::memory_order::memory_order_relaxed));
-
- break;
- }
- case ProfilingState::NotConnected:
- {
- ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed);
- do {
- if (!IsOneOfStates(expectedState, ProfilingState::Uninitialised, ProfilingState::NotConnected,
- ProfilingState::Active))
- {
- throw armnn::Exception(std::string("Cannot transition from state [")
- + GetProfilingStateName(expectedState)
- +"] to [" + GetProfilingStateName(newState) + "]");
- }
- } while (!m_State.compare_exchange_strong(expectedState, newState,
- std::memory_order::memory_order_relaxed));
-
- break;
- }
- case ProfilingState::WaitingForAck:
- {
- ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed);
- do {
- if (!IsOneOfStates(expectedState, ProfilingState::NotConnected, ProfilingState::WaitingForAck))
- {
- throw armnn::Exception(std::string("Cannot transition from state [")
- + GetProfilingStateName(expectedState)
- +"] to [" + GetProfilingStateName(newState) + "]");
- }
- } while (!m_State.compare_exchange_strong(expectedState, newState,
- std::memory_order::memory_order_relaxed));
+ ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed);
- break;
- }
- case ProfilingState::Active:
- {
- ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed);
- do {
- if (!IsOneOfStates(expectedState, ProfilingState::WaitingForAck, ProfilingState::Active))
- {
- throw armnn::Exception(std::string("Cannot transition from state [")
- + GetProfilingStateName(expectedState)
- +"] to [" + GetProfilingStateName(newState) + "]");
- }
- } while (!m_State.compare_exchange_strong(expectedState, newState,
- std::memory_order::memory_order_relaxed));
+ switch (newState)
+ {
+ case ProfilingState::Uninitialised:
+ do
+ {
+ if (!IsOneOfStates(expectedState, ProfilingState::Uninitialised))
+ {
+ ThrowStateTransitionException(expectedState, newState);
+ }
+ }
+ while (!m_State.compare_exchange_strong(expectedState, newState, std::memory_order::memory_order_relaxed));
+ break;
+ case ProfilingState::NotConnected:
+ do
+ {
+ if (!IsOneOfStates(expectedState, ProfilingState::Uninitialised, ProfilingState::NotConnected,
+ ProfilingState::Active))
+ {
+ ThrowStateTransitionException(expectedState, newState);
+ }
+ }
+ while (!m_State.compare_exchange_strong(expectedState, newState, std::memory_order::memory_order_relaxed));
+ break;
+ case ProfilingState::WaitingForAck:
+ do
+ {
+ if (!IsOneOfStates(expectedState, ProfilingState::NotConnected, ProfilingState::WaitingForAck))
+ {
+ ThrowStateTransitionException(expectedState, newState);
+ }
+ }
+ while (!m_State.compare_exchange_strong(expectedState, newState, std::memory_order::memory_order_relaxed));
+ break;
+ case ProfilingState::Active:
+ do
+ {
+ if (!IsOneOfStates(expectedState, ProfilingState::WaitingForAck, ProfilingState::Active))
+ {
+ ThrowStateTransitionException(expectedState, newState);
+ }
+ }
+ while (!m_State.compare_exchange_strong(expectedState, newState, std::memory_order::memory_order_relaxed));
+ break;
+ default:
+ break;
+ }
+}
- break;
- }
- default:
- break;
- }
+void ProfilingStateMachine::Reset()
+{
+ m_State.store(ProfilingState::Uninitialised);
}
-} //namespace profiling
+} // namespace profiling
-} //namespace armnn \ No newline at end of file
+} // namespace armnn
diff --git a/src/profiling/ProfilingStateMachine.hpp b/src/profiling/ProfilingStateMachine.hpp
index 66f8b2cd17..d070744b1b 100644
--- a/src/profiling/ProfilingStateMachine.hpp
+++ b/src/profiling/ProfilingStateMachine.hpp
@@ -2,6 +2,7 @@
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
+
#pragma once
#include <atomic>
@@ -23,11 +24,12 @@ enum class ProfilingState
class ProfilingStateMachine
{
public:
- ProfilingStateMachine(): m_State(ProfilingState::Uninitialised) {};
- ProfilingStateMachine(ProfilingState state): m_State(state) {};
+ ProfilingStateMachine() : m_State(ProfilingState::Uninitialised) {}
+ ProfilingStateMachine(ProfilingState state) : m_State(state) {}
ProfilingState GetCurrentState() const;
void TransitionToState(ProfilingState newState);
+ void Reset();
bool IsOneOfStates(ProfilingState state1)
{
@@ -53,17 +55,17 @@ private:
constexpr char const* GetProfilingStateName(ProfilingState state)
{
- switch(state)
+ switch (state)
{
- case ProfilingState::Uninitialised: return "Uninitialised";
- case ProfilingState::NotConnected: return "NotConnected";
- case ProfilingState::WaitingForAck: return "WaitingForAck";
- case ProfilingState::Active: return "Active";
- default: return "Unknown";
+ case ProfilingState::Uninitialised: return "Uninitialised";
+ case ProfilingState::NotConnected: return "NotConnected";
+ case ProfilingState::WaitingForAck: return "WaitingForAck";
+ case ProfilingState::Active: return "Active";
+ default: return "Unknown";
}
}
-} //namespace profiling
+} // namespace profiling
-} //namespace armnn
+} // namespace armnn
diff --git a/src/profiling/SendCounterPacket.cpp b/src/profiling/SendCounterPacket.cpp
index 9aafa2ccc8..7f3696a940 100644
--- a/src/profiling/SendCounterPacket.cpp
+++ b/src/profiling/SendCounterPacket.cpp
@@ -903,7 +903,7 @@ void SendCounterPacket::SetReadyToRead()
void SendCounterPacket::Start()
{
- // Check is the send thread is already running
+ // Check if the send thread is already running
if (m_IsRunning.load())
{
// The send thread is already running
diff --git a/src/profiling/SocketProfilingConnection.cpp b/src/profiling/SocketProfilingConnection.cpp
index 45b7f9dcfe..6955f70a48 100644
--- a/src/profiling/SocketProfilingConnection.cpp
+++ b/src/profiling/SocketProfilingConnection.cpp
@@ -23,7 +23,7 @@ SocketProfilingConnection::SocketProfilingConnection()
m_Socket[0].fd = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
if (m_Socket[0].fd == -1)
{
- throw armnn::RuntimeException(std::string("Socket construction failed: ") + strerror(errno));
+ throw RuntimeException(std::string("Socket construction failed: ") + strerror(errno));
}
// Connect to the named unix domain socket.
@@ -35,7 +35,7 @@ SocketProfilingConnection::SocketProfilingConnection()
if (0 != connect(m_Socket[0].fd, reinterpret_cast<const sockaddr*>(&server), sizeof(sockaddr_un)))
{
close(m_Socket[0].fd);
- throw armnn::RuntimeException(std::string("Cannot connect to stream socket: ") + strerror(errno));
+ throw RuntimeException(std::string("Cannot connect to stream socket: ") + strerror(errno));
}
// Our socket will only be interested in polling reads.
@@ -46,7 +46,7 @@ SocketProfilingConnection::SocketProfilingConnection()
if (0 != fcntl(m_Socket[0].fd, F_SETFL, currentFlags | O_NONBLOCK))
{
close(m_Socket[0].fd);
- throw armnn::RuntimeException(std::string("Failed to set socket as non blocking: ") + strerror(errno));
+ throw RuntimeException(std::string("Failed to set socket as non blocking: ") + strerror(errno));
}
}
@@ -59,7 +59,7 @@ void SocketProfilingConnection::Close()
{
if (close(m_Socket[0].fd) != 0)
{
- throw armnn::RuntimeException(std::string("Cannot close stream socket: ") + strerror(errno));
+ throw RuntimeException(std::string("Cannot close stream socket: ") + strerror(errno));
}
memset(m_Socket, 0, sizeof(m_Socket));
@@ -83,17 +83,17 @@ Packet SocketProfilingConnection::ReadPacket(uint32_t timeout)
switch (pollResult)
{
case -1: // Error
- throw armnn::RuntimeException(std::string("Read failure from socket: ") + strerror(errno));
+ throw RuntimeException(std::string("Read failure from socket: ") + strerror(errno));
case 0: // Timeout
- throw armnn::RuntimeException("Timeout while reading from socket");
+ throw RuntimeException("Timeout while reading from socket");
default: // Normal poll return but it could still contain an error signal
// Check if the socket reported an error
if (m_Socket[0].revents & (POLLNVAL | POLLERR | POLLHUP))
{
- throw armnn::Exception(std::string("Socket 0 reported an error: ") + strerror(errno));
+ throw Exception(std::string("Socket 0 reported an error: ") + strerror(errno));
}
// Check if there is data to read
@@ -108,7 +108,7 @@ Packet SocketProfilingConnection::ReadPacket(uint32_t timeout)
if (8 != recv(m_Socket[0].fd, &header, sizeof(header), 0))
{
// What do we do here if there's not a valid 8 byte header to read?
- throw armnn::RuntimeException("The received packet did not contains a valid MIPE header");
+ throw RuntimeException("The received packet did not contains a valid MIPE header");
}
// stream_metadata_identifier is the first 4 bytes
@@ -133,7 +133,7 @@ Packet SocketProfilingConnection::ReadPacket(uint32_t timeout)
if (dataLength != static_cast<uint32_t>(receivedLength))
{
// What do we do here if we can't read in a full packet?
- throw armnn::RuntimeException("Invalid MIPE packet");
+ throw RuntimeException("Invalid MIPE packet");
}
return Packet(metadataIdentifier, dataLength, packetData);
diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp
index 5ef9811b2c..d14791c43d 100644
--- a/src/profiling/test/ProfilingTests.cpp
+++ b/src/profiling/test/ProfilingTests.cpp
@@ -659,19 +659,18 @@ BOOST_AUTO_TEST_CASE(CaptureDataMethods)
BOOST_AUTO_TEST_CASE(CheckProfilingServiceDisabled)
{
armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
- ProfilingService service(options);
- BOOST_CHECK(service.GetCurrentState() == ProfilingState::Uninitialised);
- service.Run();
- BOOST_CHECK(service.GetCurrentState() == ProfilingState::Uninitialised);
+ ProfilingService& profilingService = ProfilingService::Instance();
+ profilingService.ResetExternalProfilingOptions(options, true);
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
+ profilingService.Run();
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
}
-struct cerr_redirect {
+struct cerr_redirect
+{
cerr_redirect(std::streambuf* new_buffer)
- : old( std::cerr.rdbuf(new_buffer)) {}
-
- ~cerr_redirect( ) {
- std::cerr.rdbuf(old);
- }
+ : old(std::cerr.rdbuf(new_buffer)) {}
+ ~cerr_redirect() { std::cerr.rdbuf(old); }
private:
std::streambuf* old;
@@ -681,46 +680,49 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabled)
{
armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
options.m_EnableProfiling = true;
- ProfilingService service(options);
- BOOST_CHECK(service.GetCurrentState() == ProfilingState::NotConnected);
+ ProfilingService& profilingService = ProfilingService::Instance();
+ profilingService.ResetExternalProfilingOptions(options, true);
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
// As there is no daemon running a connection cannot be made so expect a std::cerr to console
std::stringstream ss;
cerr_redirect guard(ss.rdbuf());
- service.Run();
+ profilingService.Run();
BOOST_CHECK(boost::contains(ss.str(), "Cannot connect to stream socket: Connection refused"));
}
BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabledRuntime)
{
armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
- ProfilingService service(options);
- BOOST_CHECK(service.GetCurrentState() == ProfilingState::Uninitialised);
- service.Run();
- BOOST_CHECK(service.GetCurrentState() == ProfilingState::Uninitialised);
+ ProfilingService& profilingService = ProfilingService::Instance();
+ profilingService.ResetExternalProfilingOptions(options, true);
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
+ profilingService.Run();
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
options.m_EnableProfiling = true;
- service.ResetExternalProfilingOptions(options);
- BOOST_CHECK(service.GetCurrentState() == ProfilingState::NotConnected);
+ profilingService.ResetExternalProfilingOptions(options);
+ BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
// As there is no daemon running a connection cannot be made so expect a std::cerr to console
std::stringstream ss;
cerr_redirect guard(ss.rdbuf());
- service.Run();
+ profilingService.Run();
BOOST_CHECK(boost::contains(ss.str(), "Cannot connect to stream socket: Connection refused"));
}
BOOST_AUTO_TEST_CASE(CheckProfilingServiceCounterDirectory)
{
armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
- ProfilingService service(options);
+ ProfilingService& profilingService = ProfilingService::Instance();
+ profilingService.ResetExternalProfilingOptions(options, true);
- const ICounterDirectory& counterDirectory0 = service.GetCounterDirectory();
+ const ICounterDirectory& counterDirectory0 = profilingService.GetCounterDirectory();
BOOST_CHECK(counterDirectory0.GetCounterCount() == 0);
options.m_EnableProfiling = true;
- service.ResetExternalProfilingOptions(options);
+ profilingService.ResetExternalProfilingOptions(options);
- const ICounterDirectory& counterDirectory1 = service.GetCounterDirectory();
+ const ICounterDirectory& counterDirectory1 = profilingService.GetCounterDirectory();
BOOST_CHECK(counterDirectory1.GetCounterCount() != 0);
}
@@ -728,32 +730,38 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceCounterValues)
{
armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
options.m_EnableProfiling = true;
- ProfilingService profilingService(options);
+ ProfilingService& profilingService = ProfilingService::Instance();
+ profilingService.ResetExternalProfilingOptions(options, true);
+
+ const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory();
+ const Counters& counters = counterDirectory.GetCounters();
+ BOOST_CHECK(!counters.empty());
+
+ // Get the UID of the first counter for testing
+ uint16_t counterUid = counters.begin()->first;
ProfilingService* profilingServicePtr = &profilingService;
std::vector<std::thread> writers;
- for(int i = 0; i < 100 ; ++i)
+ for (int i = 0; i < 100 ; ++i)
{
- // Increment and decrement counter 0
- writers.push_back(std::thread(&ProfilingService::IncrementCounterValue, profilingServicePtr, 0));
- writers.push_back(std::thread(&ProfilingService::DecrementCounterValue, profilingServicePtr, 0));
- // Add 10 to counter 0 and subtract 5 from counter 0
- writers.push_back(std::thread(&ProfilingService::AddCounterValue, profilingServicePtr, 0, 10));
- writers.push_back(std::thread(&ProfilingService::SubtractCounterValue, profilingServicePtr, 0, 5));
+ // Increment and decrement the first counter
+ writers.push_back(std::thread(&ProfilingService::IncrementCounterValue, profilingServicePtr, counterUid));
+ writers.push_back(std::thread(&ProfilingService::DecrementCounterValue, profilingServicePtr, counterUid));
+ // Add 10 and subtract 5 from the first counter
+ writers.push_back(std::thread(&ProfilingService::AddCounterValue, profilingServicePtr, counterUid, 10));
+ writers.push_back(std::thread(&ProfilingService::SubtractCounterValue, profilingServicePtr, counterUid, 5));
}
std::for_each(writers.begin(), writers.end(), mem_fn(&std::thread::join));
- uint32_t counterValue;
- profilingService.GetCounterValue(0, counterValue);
+ uint32_t counterValue = 0;
+ BOOST_CHECK_NO_THROW(counterValue = profilingService.GetCounterValue(counterUid));
BOOST_CHECK(counterValue == 500);
- profilingService.SetCounterValue(0, 0);
- profilingService.GetCounterValue(0, counterValue);
+ BOOST_CHECK_NO_THROW(profilingService.SetCounterValue(counterUid, 0));
+ BOOST_CHECK_NO_THROW(counterValue = profilingService.GetCounterValue(counterUid));
BOOST_CHECK(counterValue == 0);
-
- BOOST_CHECK_THROW(profilingService.SetCounterValue(profilingService.GetCounterCount(), 1), armnn::Exception);
}
BOOST_AUTO_TEST_CASE(CheckProfilingObjectUids)