aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorNikhil Raj <nikhil.raj@arm.com>2019-09-03 15:55:33 +0100
committerNikhil Raj <nikhil.raj@arm.com>2019-09-03 15:55:33 +0100
commit3ecc5104a088fe9fd0b504ca0c6c3a932ed342c4 (patch)
tree7ea04111888647bc53f88003f5d2b44ceed22279 /src
parent7388217c07ccbd387d3bcfbde76cca744c4fd6fe (diff)
downloadarmnn-3ecc5104a088fe9fd0b504ca0c6c3a932ed342c4.tar.gz
IVGCVSW-3431 Create Profiling Service State Machine
Change-Id: I30ae52d38181a91ce642e24919ad788902e42eb4 Signed-off-by: Nikhil Raj <nikhil.raj@arm.com>
Diffstat (limited to 'src')
-rw-r--r--src/profiling/ProfilingStateMachine.cpp93
-rw-r--r--src/profiling/ProfilingStateMachine.hpp69
-rw-r--r--src/profiling/test/ProfilingTests.cpp91
3 files changed, 253 insertions, 0 deletions
diff --git a/src/profiling/ProfilingStateMachine.cpp b/src/profiling/ProfilingStateMachine.cpp
new file mode 100644
index 0000000000..682e1b8894
--- /dev/null
+++ b/src/profiling/ProfilingStateMachine.cpp
@@ -0,0 +1,93 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ProfilingStateMachine.hpp"
+
+#include <armnn/Exceptions.hpp>
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+ProfilingState ProfilingStateMachine::GetCurrentState() const
+{
+ return m_State;
+}
+
+void ProfilingStateMachine::TransitionToState(ProfilingState newState)
+{
+ switch (newState)
+ {
+ case ProfilingState::Uninitialised:
+ {
+ ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed);
+ do {
+ if (!IsOneOfStates(expectedState, ProfilingState::Uninitialised))
+ {
+ throw armnn::Exception(std::string("Cannot transition from state [")
+ + GetProfilingStateName(expectedState)
+ +"] to [" + GetProfilingStateName(newState) + "]");
+ }
+ } while (!m_State.compare_exchange_strong(expectedState, newState,
+ std::memory_order::memory_order_relaxed));
+
+ break;
+ }
+ case ProfilingState::NotConnected:
+ {
+ ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed);
+ do {
+ if (!IsOneOfStates(expectedState, ProfilingState::Uninitialised, ProfilingState::NotConnected,
+ ProfilingState::Active))
+ {
+ throw armnn::Exception(std::string("Cannot transition from state [")
+ + GetProfilingStateName(expectedState)
+ +"] to [" + GetProfilingStateName(newState) + "]");
+ }
+ } while (!m_State.compare_exchange_strong(expectedState, newState,
+ std::memory_order::memory_order_relaxed));
+
+ break;
+ }
+ case ProfilingState::WaitingForAck:
+ {
+ ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed);
+ do {
+ if (!IsOneOfStates(expectedState, ProfilingState::NotConnected, ProfilingState::WaitingForAck))
+ {
+ throw armnn::Exception(std::string("Cannot transition from state [")
+ + GetProfilingStateName(expectedState)
+ +"] to [" + GetProfilingStateName(newState) + "]");
+ }
+ } while (!m_State.compare_exchange_strong(expectedState, newState,
+ std::memory_order::memory_order_relaxed));
+
+ break;
+ }
+ case ProfilingState::Active:
+ {
+ ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed);
+ do {
+ if (!IsOneOfStates(expectedState, ProfilingState::WaitingForAck, ProfilingState::Active))
+ {
+ throw armnn::Exception(std::string("Cannot transition from state [")
+ + GetProfilingStateName(expectedState)
+ +"] to [" + GetProfilingStateName(newState) + "]");
+ }
+ } while (!m_State.compare_exchange_strong(expectedState, newState,
+ std::memory_order::memory_order_relaxed));
+
+ break;
+ }
+ default:
+ break;
+ }
+}
+
+} //namespace profiling
+
+} //namespace armnn \ No newline at end of file
diff --git a/src/profiling/ProfilingStateMachine.hpp b/src/profiling/ProfilingStateMachine.hpp
new file mode 100644
index 0000000000..66f8b2cd17
--- /dev/null
+++ b/src/profiling/ProfilingStateMachine.hpp
@@ -0,0 +1,69 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include <atomic>
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+enum class ProfilingState
+{
+ Uninitialised,
+ NotConnected,
+ WaitingForAck,
+ Active
+};
+
+class ProfilingStateMachine
+{
+public:
+ ProfilingStateMachine(): m_State(ProfilingState::Uninitialised) {};
+ ProfilingStateMachine(ProfilingState state): m_State(state) {};
+
+ ProfilingState GetCurrentState() const;
+ void TransitionToState(ProfilingState newState);
+
+ bool IsOneOfStates(ProfilingState state1)
+ {
+ return false;
+ }
+
+ template<typename T, typename... Args >
+ bool IsOneOfStates(T state1, T state2, Args... args)
+ {
+ if (state1 == state2)
+ {
+ return true;
+ }
+ else
+ {
+ return IsOneOfStates(state1, args...);
+ }
+ }
+
+private:
+ std::atomic<ProfilingState> m_State;
+};
+
+constexpr char const* GetProfilingStateName(ProfilingState state)
+{
+ switch(state)
+ {
+ case ProfilingState::Uninitialised: return "Uninitialised";
+ case ProfilingState::NotConnected: return "NotConnected";
+ case ProfilingState::WaitingForAck: return "WaitingForAck";
+ case ProfilingState::Active: return "Active";
+ default: return "Unknown";
+ }
+}
+
+} //namespace profiling
+
+} //namespace armnn
+
diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp
index 3fd8d790a2..ce278abee7 100644
--- a/src/profiling/test/ProfilingTests.cpp
+++ b/src/profiling/test/ProfilingTests.cpp
@@ -9,6 +9,7 @@
#include "../EncodeVersion.hpp"
#include "../Packet.hpp"
#include "../PacketVersionResolver.hpp"
+#include "../ProfilingStateMachine.hpp"
#include <boost/test/unit_test.hpp>
@@ -17,6 +18,7 @@
#include <limits>
#include <map>
#include <random>
+#include <thread>
BOOST_AUTO_TEST_SUITE(ExternalProfiling)
@@ -265,5 +267,94 @@ BOOST_AUTO_TEST_CASE(CheckPacketVersionResolver)
BOOST_TEST(resolvedVersion == expectedVersion);
}
}
+void ProfilingCurrentStateThreadImpl(ProfilingStateMachine& states)
+{
+ ProfilingState newState = ProfilingState::NotConnected;
+ states.GetCurrentState();
+ states.TransitionToState(newState);
+}
+
+BOOST_AUTO_TEST_CASE(CheckProfilingStateMachine)
+{
+ ProfilingStateMachine profilingState1(ProfilingState::Uninitialised);
+ profilingState1.TransitionToState(ProfilingState::Uninitialised);
+ BOOST_CHECK(profilingState1.GetCurrentState() == ProfilingState::Uninitialised);
+
+ ProfilingStateMachine profilingState2(ProfilingState::Uninitialised);
+ profilingState2.TransitionToState(ProfilingState::NotConnected);
+ BOOST_CHECK(profilingState2.GetCurrentState() == ProfilingState::NotConnected);
+
+ ProfilingStateMachine profilingState3(ProfilingState::NotConnected);
+ profilingState3.TransitionToState(ProfilingState::NotConnected);
+ BOOST_CHECK(profilingState3.GetCurrentState() == ProfilingState::NotConnected);
+
+ ProfilingStateMachine profilingState4(ProfilingState::NotConnected);
+ profilingState4.TransitionToState(ProfilingState::WaitingForAck);
+ BOOST_CHECK(profilingState4.GetCurrentState() == ProfilingState::WaitingForAck);
+
+ ProfilingStateMachine profilingState5(ProfilingState::WaitingForAck);
+ profilingState5.TransitionToState(ProfilingState::WaitingForAck);
+ BOOST_CHECK(profilingState5.GetCurrentState() == ProfilingState::WaitingForAck);
+
+ ProfilingStateMachine profilingState6(ProfilingState::WaitingForAck);
+ profilingState6.TransitionToState(ProfilingState::Active);
+ BOOST_CHECK(profilingState6.GetCurrentState() == ProfilingState::Active);
+
+ ProfilingStateMachine profilingState7(ProfilingState::Active);
+ profilingState7.TransitionToState(ProfilingState::NotConnected);
+ BOOST_CHECK(profilingState7.GetCurrentState() == ProfilingState::NotConnected);
+
+ ProfilingStateMachine profilingState8(ProfilingState::Active);
+ profilingState8.TransitionToState(ProfilingState::Active);
+ BOOST_CHECK(profilingState8.GetCurrentState() == ProfilingState::Active);
+
+ ProfilingStateMachine profilingState9(ProfilingState::Uninitialised);
+ BOOST_CHECK_THROW(profilingState9.TransitionToState(ProfilingState::WaitingForAck),
+ armnn::Exception);
+
+ ProfilingStateMachine profilingState10(ProfilingState::Uninitialised);
+ BOOST_CHECK_THROW(profilingState10.TransitionToState(ProfilingState::Active),
+ armnn::Exception);
+
+ ProfilingStateMachine profilingState11(ProfilingState::NotConnected);
+ BOOST_CHECK_THROW(profilingState11.TransitionToState(ProfilingState::Uninitialised),
+ armnn::Exception);
+
+ ProfilingStateMachine profilingState12(ProfilingState::NotConnected);
+ BOOST_CHECK_THROW(profilingState12.TransitionToState(ProfilingState::Active),
+ armnn::Exception);
+
+ ProfilingStateMachine profilingState13(ProfilingState::WaitingForAck);
+ BOOST_CHECK_THROW(profilingState13.TransitionToState(ProfilingState::Uninitialised),
+ armnn::Exception);
+
+ ProfilingStateMachine profilingState14(ProfilingState::WaitingForAck);
+ BOOST_CHECK_THROW(profilingState14.TransitionToState(ProfilingState::NotConnected),
+ armnn::Exception);
+
+ ProfilingStateMachine profilingState15(ProfilingState::Active);
+ BOOST_CHECK_THROW(profilingState15.TransitionToState(ProfilingState::Uninitialised),
+ armnn::Exception);
+
+ ProfilingStateMachine profilingState16(armnn::profiling::ProfilingState::Active);
+ BOOST_CHECK_THROW(profilingState16.TransitionToState(ProfilingState::WaitingForAck),
+ armnn::Exception);
+
+ ProfilingStateMachine profilingState17(ProfilingState::Uninitialised);
+
+ std::thread thread1 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
+ std::thread thread2 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
+ std::thread thread3 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
+ std::thread thread4 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
+ std::thread thread5 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
+
+ thread1.join();
+ thread2.join();
+ thread3.join();
+ thread4.join();
+ thread5.join();
+
+ BOOST_TEST((profilingState17.GetCurrentState() == ProfilingState::NotConnected));
+}
BOOST_AUTO_TEST_SUITE_END()