From 3ecc5104a088fe9fd0b504ca0c6c3a932ed342c4 Mon Sep 17 00:00:00 2001 From: Nikhil Raj Date: Tue, 3 Sep 2019 15:55:33 +0100 Subject: IVGCVSW-3431 Create Profiling Service State Machine Change-Id: I30ae52d38181a91ce642e24919ad788902e42eb4 Signed-off-by: Nikhil Raj --- src/profiling/ProfilingStateMachine.cpp | 93 +++++++++++++++++++++++++++++++++ src/profiling/ProfilingStateMachine.hpp | 69 ++++++++++++++++++++++++ src/profiling/test/ProfilingTests.cpp | 91 ++++++++++++++++++++++++++++++++ 3 files changed, 253 insertions(+) create mode 100644 src/profiling/ProfilingStateMachine.cpp create mode 100644 src/profiling/ProfilingStateMachine.hpp (limited to 'src/profiling') diff --git a/src/profiling/ProfilingStateMachine.cpp b/src/profiling/ProfilingStateMachine.cpp new file mode 100644 index 0000000000..682e1b8894 --- /dev/null +++ b/src/profiling/ProfilingStateMachine.cpp @@ -0,0 +1,93 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "ProfilingStateMachine.hpp" + +#include + +namespace armnn +{ + +namespace profiling +{ + +ProfilingState ProfilingStateMachine::GetCurrentState() const +{ + return m_State; +} + +void ProfilingStateMachine::TransitionToState(ProfilingState newState) +{ + switch (newState) + { + case ProfilingState::Uninitialised: + { + ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed); + do { + if (!IsOneOfStates(expectedState, ProfilingState::Uninitialised)) + { + throw armnn::Exception(std::string("Cannot transition from state [") + + GetProfilingStateName(expectedState) + +"] to [" + GetProfilingStateName(newState) + "]"); + } + } while (!m_State.compare_exchange_strong(expectedState, newState, + std::memory_order::memory_order_relaxed)); + + break; + } + case ProfilingState::NotConnected: + { + ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed); + do { + if (!IsOneOfStates(expectedState, ProfilingState::Uninitialised, ProfilingState::NotConnected, + ProfilingState::Active)) + { + throw armnn::Exception(std::string("Cannot transition from state [") + + GetProfilingStateName(expectedState) + +"] to [" + GetProfilingStateName(newState) + "]"); + } + } while (!m_State.compare_exchange_strong(expectedState, newState, + std::memory_order::memory_order_relaxed)); + + break; + } + case ProfilingState::WaitingForAck: + { + ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed); + do { + if (!IsOneOfStates(expectedState, ProfilingState::NotConnected, ProfilingState::WaitingForAck)) + { + throw armnn::Exception(std::string("Cannot transition from state [") + + GetProfilingStateName(expectedState) + +"] to [" + GetProfilingStateName(newState) + "]"); + } + } while (!m_State.compare_exchange_strong(expectedState, newState, + std::memory_order::memory_order_relaxed)); + + break; + } + case ProfilingState::Active: + { + ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed); + do { + if (!IsOneOfStates(expectedState, ProfilingState::WaitingForAck, ProfilingState::Active)) + { + throw armnn::Exception(std::string("Cannot transition from state [") + + GetProfilingStateName(expectedState) + +"] to [" + GetProfilingStateName(newState) + "]"); + } + } while (!m_State.compare_exchange_strong(expectedState, newState, + std::memory_order::memory_order_relaxed)); + + break; + } + default: + break; + } +} + +} //namespace profiling + +} //namespace armnn \ No newline at end of file diff --git a/src/profiling/ProfilingStateMachine.hpp b/src/profiling/ProfilingStateMachine.hpp new file mode 100644 index 0000000000..66f8b2cd17 --- /dev/null +++ b/src/profiling/ProfilingStateMachine.hpp @@ -0,0 +1,69 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include + +namespace armnn +{ + +namespace profiling +{ + +enum class ProfilingState +{ + Uninitialised, + NotConnected, + WaitingForAck, + Active +}; + +class ProfilingStateMachine +{ +public: + ProfilingStateMachine(): m_State(ProfilingState::Uninitialised) {}; + ProfilingStateMachine(ProfilingState state): m_State(state) {}; + + ProfilingState GetCurrentState() const; + void TransitionToState(ProfilingState newState); + + bool IsOneOfStates(ProfilingState state1) + { + return false; + } + + template + bool IsOneOfStates(T state1, T state2, Args... args) + { + if (state1 == state2) + { + return true; + } + else + { + return IsOneOfStates(state1, args...); + } + } + +private: + std::atomic m_State; +}; + +constexpr char const* GetProfilingStateName(ProfilingState state) +{ + switch(state) + { + case ProfilingState::Uninitialised: return "Uninitialised"; + case ProfilingState::NotConnected: return "NotConnected"; + case ProfilingState::WaitingForAck: return "WaitingForAck"; + case ProfilingState::Active: return "Active"; + default: return "Unknown"; + } +} + +} //namespace profiling + +} //namespace armnn + diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index 3fd8d790a2..ce278abee7 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -9,6 +9,7 @@ #include "../EncodeVersion.hpp" #include "../Packet.hpp" #include "../PacketVersionResolver.hpp" +#include "../ProfilingStateMachine.hpp" #include @@ -17,6 +18,7 @@ #include #include #include +#include BOOST_AUTO_TEST_SUITE(ExternalProfiling) @@ -265,5 +267,94 @@ BOOST_AUTO_TEST_CASE(CheckPacketVersionResolver) BOOST_TEST(resolvedVersion == expectedVersion); } } +void ProfilingCurrentStateThreadImpl(ProfilingStateMachine& states) +{ + ProfilingState newState = ProfilingState::NotConnected; + states.GetCurrentState(); + states.TransitionToState(newState); +} + +BOOST_AUTO_TEST_CASE(CheckProfilingStateMachine) +{ + ProfilingStateMachine profilingState1(ProfilingState::Uninitialised); + profilingState1.TransitionToState(ProfilingState::Uninitialised); + BOOST_CHECK(profilingState1.GetCurrentState() == ProfilingState::Uninitialised); + + ProfilingStateMachine profilingState2(ProfilingState::Uninitialised); + profilingState2.TransitionToState(ProfilingState::NotConnected); + BOOST_CHECK(profilingState2.GetCurrentState() == ProfilingState::NotConnected); + + ProfilingStateMachine profilingState3(ProfilingState::NotConnected); + profilingState3.TransitionToState(ProfilingState::NotConnected); + BOOST_CHECK(profilingState3.GetCurrentState() == ProfilingState::NotConnected); + + ProfilingStateMachine profilingState4(ProfilingState::NotConnected); + profilingState4.TransitionToState(ProfilingState::WaitingForAck); + BOOST_CHECK(profilingState4.GetCurrentState() == ProfilingState::WaitingForAck); + + ProfilingStateMachine profilingState5(ProfilingState::WaitingForAck); + profilingState5.TransitionToState(ProfilingState::WaitingForAck); + BOOST_CHECK(profilingState5.GetCurrentState() == ProfilingState::WaitingForAck); + + ProfilingStateMachine profilingState6(ProfilingState::WaitingForAck); + profilingState6.TransitionToState(ProfilingState::Active); + BOOST_CHECK(profilingState6.GetCurrentState() == ProfilingState::Active); + + ProfilingStateMachine profilingState7(ProfilingState::Active); + profilingState7.TransitionToState(ProfilingState::NotConnected); + BOOST_CHECK(profilingState7.GetCurrentState() == ProfilingState::NotConnected); + + ProfilingStateMachine profilingState8(ProfilingState::Active); + profilingState8.TransitionToState(ProfilingState::Active); + BOOST_CHECK(profilingState8.GetCurrentState() == ProfilingState::Active); + + ProfilingStateMachine profilingState9(ProfilingState::Uninitialised); + BOOST_CHECK_THROW(profilingState9.TransitionToState(ProfilingState::WaitingForAck), + armnn::Exception); + + ProfilingStateMachine profilingState10(ProfilingState::Uninitialised); + BOOST_CHECK_THROW(profilingState10.TransitionToState(ProfilingState::Active), + armnn::Exception); + + ProfilingStateMachine profilingState11(ProfilingState::NotConnected); + BOOST_CHECK_THROW(profilingState11.TransitionToState(ProfilingState::Uninitialised), + armnn::Exception); + + ProfilingStateMachine profilingState12(ProfilingState::NotConnected); + BOOST_CHECK_THROW(profilingState12.TransitionToState(ProfilingState::Active), + armnn::Exception); + + ProfilingStateMachine profilingState13(ProfilingState::WaitingForAck); + BOOST_CHECK_THROW(profilingState13.TransitionToState(ProfilingState::Uninitialised), + armnn::Exception); + + ProfilingStateMachine profilingState14(ProfilingState::WaitingForAck); + BOOST_CHECK_THROW(profilingState14.TransitionToState(ProfilingState::NotConnected), + armnn::Exception); + + ProfilingStateMachine profilingState15(ProfilingState::Active); + BOOST_CHECK_THROW(profilingState15.TransitionToState(ProfilingState::Uninitialised), + armnn::Exception); + + ProfilingStateMachine profilingState16(armnn::profiling::ProfilingState::Active); + BOOST_CHECK_THROW(profilingState16.TransitionToState(ProfilingState::WaitingForAck), + armnn::Exception); + + ProfilingStateMachine profilingState17(ProfilingState::Uninitialised); + + std::thread thread1 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17)); + std::thread thread2 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17)); + std::thread thread3 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17)); + std::thread thread4 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17)); + std::thread thread5 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17)); + + thread1.join(); + thread2.join(); + thread3.join(); + thread4.join(); + thread5.join(); + + BOOST_TEST((profilingState17.GetCurrentState() == ProfilingState::NotConnected)); +} BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1