From 88813936232bc47fc7768800c6895191585570e8 Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Fri, 4 Oct 2019 14:40:04 +0100 Subject: IVGCVSW-3937 Refactor the command thread * Integrated the Join method into Stop * Updated the unit tests accordingly * General code refactoring Signed-off-by: Matteo Martincigh Change-Id: If8537e77b3d3ff2b780f58a07df01191a91d83d2 --- src/profiling/CommandThread.cpp | 108 +++++++++++++++------------------- src/profiling/CommandThread.hpp | 26 +++++--- src/profiling/test/ProfilingTests.cpp | 13 ++-- 3 files changed, 75 insertions(+), 72 deletions(-) (limited to 'src') diff --git a/src/profiling/CommandThread.cpp b/src/profiling/CommandThread.cpp index bd4aa96c7c..320e4bcf5c 100644 --- a/src/profiling/CommandThread.cpp +++ b/src/profiling/CommandThread.cpp @@ -12,86 +12,76 @@ namespace armnn namespace profiling { -CommandThread::CommandThread(uint32_t timeout, - bool stopAfterTimeout, - CommandHandlerRegistry& commandHandlerRegistry, - PacketVersionResolver& packetVersionResolver, - IProfilingConnection& socketProfilingConnection) - : m_Timeout(timeout) - , m_StopAfterTimeout(stopAfterTimeout) - , m_IsRunning(false) - , m_CommandHandlerRegistry(commandHandlerRegistry) - , m_PacketVersionResolver(packetVersionResolver) - , m_SocketProfilingConnection(socketProfilingConnection) -{}; - -void CommandThread::WaitForPacket() +void CommandThread::Start() { - do { - try - { - Packet packet = m_SocketProfilingConnection.ReadPacket(m_Timeout); - Version version = m_PacketVersionResolver.ResolvePacketVersion(packet.GetPacketId()); - - CommandHandlerFunctor* commandHandlerFunctor = - m_CommandHandlerRegistry.GetFunctor(packet.GetPacketId(), version.GetEncodedValue()); - commandHandlerFunctor->operator()(packet); - } - catch(const armnn::TimeoutException&) - { - if(m_StopAfterTimeout) - { - m_IsRunning.store(false, std::memory_order_relaxed); - return; - } - } - catch(...) - { - //might want to differentiate the errors more - m_IsRunning.store(false, std::memory_order_relaxed); - return; - } - - } while(m_KeepRunning.load(std::memory_order_relaxed)); + if (IsRunning()) + { + return; + } - m_IsRunning.store(false, std::memory_order_relaxed); + m_IsRunning.store(true, std::memory_order_relaxed); + m_KeepRunning.store(true, std::memory_order_relaxed); + m_CommandThread = std::thread(&CommandThread::WaitForPacket, this); } -void CommandThread::Start() +void CommandThread::Stop() { - if (!m_CommandThread.joinable() && !IsRunning()) + m_KeepRunning.store(false, std::memory_order_relaxed); + + if (m_CommandThread.joinable()) { - m_IsRunning.store(true, std::memory_order_relaxed); - m_KeepRunning.store(true, std::memory_order_relaxed); - m_CommandThread = std::thread(&CommandThread::WaitForPacket, this); + m_CommandThread.join(); } } -void CommandThread::Stop() +bool CommandThread::IsRunning() const { - m_KeepRunning.store(false, std::memory_order_relaxed); + return m_IsRunning.load(std::memory_order_relaxed); } -void CommandThread::Join() +void CommandThread::SetTimeout(uint32_t timeout) { - m_CommandThread.join(); + m_Timeout.store(timeout, std::memory_order_relaxed); } -bool CommandThread::IsRunning() const +void CommandThread::SetStopAfterTimeout(bool stopAfterTimeout) { - return m_IsRunning.load(std::memory_order_relaxed); + m_StopAfterTimeout.store(stopAfterTimeout, std::memory_order_relaxed); } -bool CommandThread::StopAfterTimeout(bool stopAfterTimeout) +void CommandThread::WaitForPacket() { - if (!IsRunning()) + do { - m_StopAfterTimeout = stopAfterTimeout; - return true; + try + { + Packet packet = m_SocketProfilingConnection.ReadPacket(m_Timeout); + Version version = m_PacketVersionResolver.ResolvePacketVersion(packet.GetPacketId()); + + CommandHandlerFunctor* commandHandlerFunctor = + m_CommandHandlerRegistry.GetFunctor(packet.GetPacketId(), version.GetEncodedValue()); + BOOST_ASSERT(commandHandlerFunctor); + commandHandlerFunctor->operator()(packet); + } + catch (const armnn::TimeoutException&) + { + if (m_StopAfterTimeout) + { + m_KeepRunning.store(false, std::memory_order_relaxed); + } + } + catch (...) + { + // Might want to differentiate the errors more + m_KeepRunning.store(false, std::memory_order_relaxed); + } + } - return false; + while (m_KeepRunning.load(std::memory_order_relaxed)); + + m_IsRunning.store(false, std::memory_order_relaxed); } -}//namespace profiling +} // namespace profiling -}//namespace armnn +} // namespace armnn diff --git a/src/profiling/CommandThread.hpp b/src/profiling/CommandThread.hpp index 6237cd2914..0456ba4372 100644 --- a/src/profiling/CommandThread.hpp +++ b/src/profiling/CommandThread.hpp @@ -26,19 +26,31 @@ public: bool stopAfterTimeout, CommandHandlerRegistry& commandHandlerRegistry, PacketVersionResolver& packetVersionResolver, - IProfilingConnection& socketProfilingConnection); + IProfilingConnection& socketProfilingConnection) + : m_Timeout(timeout) + , m_StopAfterTimeout(stopAfterTimeout) + , m_IsRunning(false) + , m_KeepRunning(false) + , m_CommandThread() + , m_CommandHandlerRegistry(commandHandlerRegistry) + , m_PacketVersionResolver(packetVersionResolver) + , m_SocketProfilingConnection(socketProfilingConnection) + {} + ~CommandThread() { Stop(); } void Start(); void Stop(); - void Join(); + bool IsRunning() const; - bool StopAfterTimeout(bool StopAfterTimeout); + + void SetTimeout(uint32_t timeout); + void SetStopAfterTimeout(bool stopAfterTimeout); private: void WaitForPacket(); - uint32_t m_Timeout; - bool m_StopAfterTimeout; + std::atomic m_Timeout; + std::atomic m_StopAfterTimeout; std::atomic m_IsRunning; std::atomic m_KeepRunning; std::thread m_CommandThread; @@ -48,6 +60,6 @@ private: IProfilingConnection& m_SocketProfilingConnection; }; -}//namespace profiling +} // namespace profiling -}//namespace armnn \ No newline at end of file +} // namespace armnn diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index d14791c43d..9dd7cd3d64 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -174,7 +174,6 @@ BOOST_AUTO_TEST_CASE(CheckCommandThread) commandThread0.Start(); commandThread0.Stop(); - commandThread0.Join(); BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active); @@ -188,11 +187,15 @@ BOOST_AUTO_TEST_CASE(CheckCommandThread) testProfilingConnectionTimeOutError); commandThread1.Start(); - commandThread1.Join(); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + BOOST_CHECK(!commandThread1.IsRunning()); + commandThread1.Stop(); BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck); //now commandThread1 should persist after a timeout - commandThread1.StopAfterTimeout(false); + commandThread1.SetStopAfterTimeout(false); commandThread1.Start(); for (int i = 0; i < 100; i++) @@ -208,11 +211,9 @@ BOOST_AUTO_TEST_CASE(CheckCommandThread) } commandThread1.Stop(); - commandThread1.Join(); BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active); - CommandThread commandThread2(1, false, commandHandlerRegistry, @@ -226,13 +227,13 @@ BOOST_AUTO_TEST_CASE(CheckCommandThread) if (!commandThread2.IsRunning()) { //commandThread2 should stop once it encounters a non timing error - commandThread2.Join(); return; } std::this_thread::sleep_for(std::chrono::milliseconds(5)); } BOOST_ERROR("commandThread2 has failed to stop"); + commandThread2.Stop(); } BOOST_AUTO_TEST_CASE(CheckEncodeVersion) -- cgit v1.2.1