diff options
-rw-r--r-- | CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/profiling/ProfilingStateMachine.cpp | 93 | ||||
-rw-r--r-- | src/profiling/ProfilingStateMachine.hpp | 69 | ||||
-rw-r--r-- | src/profiling/test/ProfilingTests.cpp | 91 |
4 files changed, 255 insertions, 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 1a07f6905f..a285c36a68 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -433,6 +433,8 @@ list(APPEND armnn_sources src/profiling/SendCounterPacket.cpp src/profiling/ProfilingUtils.hpp src/profiling/ProfilingUtils.cpp + src/profiling/ProfilingStateMachine.cpp + src/profiling/ProfilingStateMachine.hpp third-party/half/half.hpp ) diff --git a/src/profiling/ProfilingStateMachine.cpp b/src/profiling/ProfilingStateMachine.cpp new file mode 100644 index 0000000000..682e1b8894 --- /dev/null +++ b/src/profiling/ProfilingStateMachine.cpp @@ -0,0 +1,93 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "ProfilingStateMachine.hpp" + +#include <armnn/Exceptions.hpp> + +namespace armnn +{ + +namespace profiling +{ + +ProfilingState ProfilingStateMachine::GetCurrentState() const +{ + return m_State; +} + +void ProfilingStateMachine::TransitionToState(ProfilingState newState) +{ + switch (newState) + { + case ProfilingState::Uninitialised: + { + ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed); + do { + if (!IsOneOfStates(expectedState, ProfilingState::Uninitialised)) + { + throw armnn::Exception(std::string("Cannot transition from state [") + + GetProfilingStateName(expectedState) + +"] to [" + GetProfilingStateName(newState) + "]"); + } + } while (!m_State.compare_exchange_strong(expectedState, newState, + std::memory_order::memory_order_relaxed)); + + break; + } + case ProfilingState::NotConnected: + { + ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed); + do { + if (!IsOneOfStates(expectedState, ProfilingState::Uninitialised, ProfilingState::NotConnected, + ProfilingState::Active)) + { + throw armnn::Exception(std::string("Cannot transition from state [") + + GetProfilingStateName(expectedState) + +"] to [" + GetProfilingStateName(newState) + "]"); + } + } while (!m_State.compare_exchange_strong(expectedState, newState, + std::memory_order::memory_order_relaxed)); + + break; + } + case ProfilingState::WaitingForAck: + { + ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed); + do { + if (!IsOneOfStates(expectedState, ProfilingState::NotConnected, ProfilingState::WaitingForAck)) + { + throw armnn::Exception(std::string("Cannot transition from state [") + + GetProfilingStateName(expectedState) + +"] to [" + GetProfilingStateName(newState) + "]"); + } + } while (!m_State.compare_exchange_strong(expectedState, newState, + std::memory_order::memory_order_relaxed)); + + break; + } + case ProfilingState::Active: + { + ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed); + do { + if (!IsOneOfStates(expectedState, ProfilingState::WaitingForAck, ProfilingState::Active)) + { + throw armnn::Exception(std::string("Cannot transition from state [") + + GetProfilingStateName(expectedState) + +"] to [" + GetProfilingStateName(newState) + "]"); + } + } while (!m_State.compare_exchange_strong(expectedState, newState, + std::memory_order::memory_order_relaxed)); + + break; + } + default: + break; + } +} + +} //namespace profiling + +} //namespace armnn
\ No newline at end of file diff --git a/src/profiling/ProfilingStateMachine.hpp b/src/profiling/ProfilingStateMachine.hpp new file mode 100644 index 0000000000..66f8b2cd17 --- /dev/null +++ b/src/profiling/ProfilingStateMachine.hpp @@ -0,0 +1,69 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include <atomic> + +namespace armnn +{ + +namespace profiling +{ + +enum class ProfilingState +{ + Uninitialised, + NotConnected, + WaitingForAck, + Active +}; + +class ProfilingStateMachine +{ +public: + ProfilingStateMachine(): m_State(ProfilingState::Uninitialised) {}; + ProfilingStateMachine(ProfilingState state): m_State(state) {}; + + ProfilingState GetCurrentState() const; + void TransitionToState(ProfilingState newState); + + bool IsOneOfStates(ProfilingState state1) + { + return false; + } + + template<typename T, typename... Args > + bool IsOneOfStates(T state1, T state2, Args... args) + { + if (state1 == state2) + { + return true; + } + else + { + return IsOneOfStates(state1, args...); + } + } + +private: + std::atomic<ProfilingState> m_State; +}; + +constexpr char const* GetProfilingStateName(ProfilingState state) +{ + switch(state) + { + case ProfilingState::Uninitialised: return "Uninitialised"; + case ProfilingState::NotConnected: return "NotConnected"; + case ProfilingState::WaitingForAck: return "WaitingForAck"; + case ProfilingState::Active: return "Active"; + default: return "Unknown"; + } +} + +} //namespace profiling + +} //namespace armnn + diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index 3fd8d790a2..ce278abee7 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -9,6 +9,7 @@ #include "../EncodeVersion.hpp" #include "../Packet.hpp" #include "../PacketVersionResolver.hpp" +#include "../ProfilingStateMachine.hpp" #include <boost/test/unit_test.hpp> @@ -17,6 +18,7 @@ #include <limits> #include <map> #include <random> +#include <thread> BOOST_AUTO_TEST_SUITE(ExternalProfiling) @@ -265,5 +267,94 @@ BOOST_AUTO_TEST_CASE(CheckPacketVersionResolver) BOOST_TEST(resolvedVersion == expectedVersion); } } +void ProfilingCurrentStateThreadImpl(ProfilingStateMachine& states) +{ + ProfilingState newState = ProfilingState::NotConnected; + states.GetCurrentState(); + states.TransitionToState(newState); +} + +BOOST_AUTO_TEST_CASE(CheckProfilingStateMachine) +{ + ProfilingStateMachine profilingState1(ProfilingState::Uninitialised); + profilingState1.TransitionToState(ProfilingState::Uninitialised); + BOOST_CHECK(profilingState1.GetCurrentState() == ProfilingState::Uninitialised); + + ProfilingStateMachine profilingState2(ProfilingState::Uninitialised); + profilingState2.TransitionToState(ProfilingState::NotConnected); + BOOST_CHECK(profilingState2.GetCurrentState() == ProfilingState::NotConnected); + + ProfilingStateMachine profilingState3(ProfilingState::NotConnected); + profilingState3.TransitionToState(ProfilingState::NotConnected); + BOOST_CHECK(profilingState3.GetCurrentState() == ProfilingState::NotConnected); + + ProfilingStateMachine profilingState4(ProfilingState::NotConnected); + profilingState4.TransitionToState(ProfilingState::WaitingForAck); + BOOST_CHECK(profilingState4.GetCurrentState() == ProfilingState::WaitingForAck); + + ProfilingStateMachine profilingState5(ProfilingState::WaitingForAck); + profilingState5.TransitionToState(ProfilingState::WaitingForAck); + BOOST_CHECK(profilingState5.GetCurrentState() == ProfilingState::WaitingForAck); + + ProfilingStateMachine profilingState6(ProfilingState::WaitingForAck); + profilingState6.TransitionToState(ProfilingState::Active); + BOOST_CHECK(profilingState6.GetCurrentState() == ProfilingState::Active); + + ProfilingStateMachine profilingState7(ProfilingState::Active); + profilingState7.TransitionToState(ProfilingState::NotConnected); + BOOST_CHECK(profilingState7.GetCurrentState() == ProfilingState::NotConnected); + + ProfilingStateMachine profilingState8(ProfilingState::Active); + profilingState8.TransitionToState(ProfilingState::Active); + BOOST_CHECK(profilingState8.GetCurrentState() == ProfilingState::Active); + + ProfilingStateMachine profilingState9(ProfilingState::Uninitialised); + BOOST_CHECK_THROW(profilingState9.TransitionToState(ProfilingState::WaitingForAck), + armnn::Exception); + + ProfilingStateMachine profilingState10(ProfilingState::Uninitialised); + BOOST_CHECK_THROW(profilingState10.TransitionToState(ProfilingState::Active), + armnn::Exception); + + ProfilingStateMachine profilingState11(ProfilingState::NotConnected); + BOOST_CHECK_THROW(profilingState11.TransitionToState(ProfilingState::Uninitialised), + armnn::Exception); + + ProfilingStateMachine profilingState12(ProfilingState::NotConnected); + BOOST_CHECK_THROW(profilingState12.TransitionToState(ProfilingState::Active), + armnn::Exception); + + ProfilingStateMachine profilingState13(ProfilingState::WaitingForAck); + BOOST_CHECK_THROW(profilingState13.TransitionToState(ProfilingState::Uninitialised), + armnn::Exception); + + ProfilingStateMachine profilingState14(ProfilingState::WaitingForAck); + BOOST_CHECK_THROW(profilingState14.TransitionToState(ProfilingState::NotConnected), + armnn::Exception); + + ProfilingStateMachine profilingState15(ProfilingState::Active); + BOOST_CHECK_THROW(profilingState15.TransitionToState(ProfilingState::Uninitialised), + armnn::Exception); + + ProfilingStateMachine profilingState16(armnn::profiling::ProfilingState::Active); + BOOST_CHECK_THROW(profilingState16.TransitionToState(ProfilingState::WaitingForAck), + armnn::Exception); + + ProfilingStateMachine profilingState17(ProfilingState::Uninitialised); + + std::thread thread1 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17)); + std::thread thread2 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17)); + std::thread thread3 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17)); + std::thread thread4 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17)); + std::thread thread5 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17)); + + thread1.join(); + thread2.join(); + thread3.join(); + thread4.join(); + thread5.join(); + + BOOST_TEST((profilingState17.GetCurrentState() == ProfilingState::NotConnected)); +} BOOST_AUTO_TEST_SUITE_END() |