From 4833cea9036df428634cf64d8f1c4b54fc5da41f Mon Sep 17 00:00:00 2001 From: FinnWilliamsArm Date: Tue, 17 Sep 2019 16:53:53 +0100 Subject: IVGCVSW-3439 Create the Command Thread Signed-off-by: FinnWilliamsArm Change-Id: I9548c5937967f4c25841bb851273168379687bcd --- src/profiling/CommandThread.cpp | 97 +++++++++++++++++++ src/profiling/CommandThread.hpp | 53 ++++++++++ src/profiling/SocketProfilingConnection.cpp | 2 +- src/profiling/test/ProfilingTests.cpp | 145 ++++++++++++++++++++++++++++ 4 files changed, 296 insertions(+), 1 deletion(-) create mode 100644 src/profiling/CommandThread.cpp create mode 100644 src/profiling/CommandThread.hpp (limited to 'src') diff --git a/src/profiling/CommandThread.cpp b/src/profiling/CommandThread.cpp new file mode 100644 index 0000000000..4cd622c477 --- /dev/null +++ b/src/profiling/CommandThread.cpp @@ -0,0 +1,97 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include +#include "CommandThread.hpp" + +namespace armnn +{ + +namespace profiling +{ + +CommandThread::CommandThread(uint32_t timeout, + bool stopAfterTimeout, + CommandHandlerRegistry& commandHandlerRegistry, + PacketVersionResolver& packetVersionResolver, + IProfilingConnection& socketProfilingConnection) + : m_Timeout(timeout) + , m_StopAfterTimeout(stopAfterTimeout) + , m_IsRunning(false) + , m_CommandHandlerRegistry(commandHandlerRegistry) + , m_PacketVersionResolver(packetVersionResolver) + , m_SocketProfilingConnection(socketProfilingConnection) +{}; + +void CommandThread::WaitForPacket() +{ + do { + try + { + Packet packet = m_SocketProfilingConnection.ReadPacket(m_Timeout); + Version version = m_PacketVersionResolver.ResolvePacketVersion(packet.GetPacketId()); + + CommandHandlerFunctor* commandHandlerFunctor = + m_CommandHandlerRegistry.GetFunctor(packet.GetPacketId(), version.GetEncodedValue()); + commandHandlerFunctor->operator()(packet); + } + catch(armnn::TimeoutException) + { + if(m_StopAfterTimeout) + { + m_IsRunning.store(false, std::memory_order_relaxed); + return; + } + } + catch(...) + { + //might want to differentiate the errors more + m_IsRunning.store(false, std::memory_order_relaxed); + return; + } + + } while(m_KeepRunning.load(std::memory_order_relaxed)); + + m_IsRunning.store(false, std::memory_order_relaxed); +} + +void CommandThread::Start() +{ + if (!m_CommandThread.joinable() && !IsRunning()) + { + m_IsRunning.store(true, std::memory_order_relaxed); + m_KeepRunning.store(true, std::memory_order_relaxed); + m_CommandThread = std::thread(&CommandThread::WaitForPacket, this); + } +} + +void CommandThread::Stop() +{ + m_KeepRunning.store(false, std::memory_order_relaxed); +} + +void CommandThread::Join() +{ + m_CommandThread.join(); +} + +bool CommandThread::IsRunning() const +{ + return m_IsRunning.load(std::memory_order_relaxed); +} + +bool CommandThread::StopAfterTimeout(bool stopAfterTimeout) +{ + if (!IsRunning()) + { + m_StopAfterTimeout = stopAfterTimeout; + return true; + } + return false; +} + +}//namespace profiling + +}//namespace armnn \ No newline at end of file diff --git a/src/profiling/CommandThread.hpp b/src/profiling/CommandThread.hpp new file mode 100644 index 0000000000..6237cd2914 --- /dev/null +++ b/src/profiling/CommandThread.hpp @@ -0,0 +1,53 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "CommandHandlerRegistry.hpp" +#include "IProfilingConnection.hpp" +#include "PacketVersionResolver.hpp" +#include "ProfilingService.hpp" + +#include +#include + +namespace armnn +{ + +namespace profiling +{ + +class CommandThread +{ +public: + CommandThread(uint32_t timeout, + bool stopAfterTimeout, + CommandHandlerRegistry& commandHandlerRegistry, + PacketVersionResolver& packetVersionResolver, + IProfilingConnection& socketProfilingConnection); + + void Start(); + void Stop(); + void Join(); + bool IsRunning() const; + bool StopAfterTimeout(bool StopAfterTimeout); + +private: + void WaitForPacket(); + + uint32_t m_Timeout; + bool m_StopAfterTimeout; + std::atomic m_IsRunning; + std::atomic m_KeepRunning; + std::thread m_CommandThread; + + CommandHandlerRegistry& m_CommandHandlerRegistry; + PacketVersionResolver& m_PacketVersionResolver; + IProfilingConnection& m_SocketProfilingConnection; +}; + +}//namespace profiling + +}//namespace armnn \ No newline at end of file diff --git a/src/profiling/SocketProfilingConnection.cpp b/src/profiling/SocketProfilingConnection.cpp index 188ca23e12..91d57cc9bd 100644 --- a/src/profiling/SocketProfilingConnection.cpp +++ b/src/profiling/SocketProfilingConnection.cpp @@ -135,7 +135,7 @@ Packet SocketProfilingConnection::ReadPacket(uint32_t timeout) } else // it's 0 so a timeout. { - throw armnn::Exception(": Timeout while reading from socket."); + throw armnn::TimeoutException(": Timeout while reading from socket."); } } diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index 32a41f37c2..48723dbc34 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -4,6 +4,7 @@ // #include "SendCounterPacketTests.hpp" +#include "../CommandThread.hpp" #include #include @@ -87,6 +88,150 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandlerKeyComparisons) BOOST_CHECK(vect == expectedVect); } +class TestProfilingConnectionBase :public IProfilingConnection +{ +public: + TestProfilingConnectionBase() = default; + ~TestProfilingConnectionBase() = default; + + bool IsOpen() + { + return true; + } + + void Close(){} + + bool WritePacket(const char* buffer, uint32_t length) + { + return false; + } + + Packet ReadPacket(uint32_t timeout) + { + std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); + std::unique_ptr packetData; + //Return connection acknowledged packet + return {65536 ,0 , packetData}; + } +}; + +class TestProfilingConnectionTimeoutError :public TestProfilingConnectionBase +{ + int readRequests = 0; +public: + Packet ReadPacket(uint32_t timeout) { + if (readRequests < 3) + { + readRequests++; + throw armnn::TimeoutException(": Simulate a timeout"); + } + std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); + std::unique_ptr packetData; + //Return connection acknowledged packet after three timeouts + return {65536 ,0 , packetData}; + } +}; + +class TestProfilingConnectionArmnnError :public TestProfilingConnectionBase +{ +public: + + Packet ReadPacket(uint32_t timeout) + { + std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); + throw armnn::Exception(": Simulate a non timeout error"); + } +}; + +BOOST_AUTO_TEST_CASE(CheckCommandThread) +{ + PacketVersionResolver packetVersionResolver; + ProfilingStateMachine profilingStateMachine; + + TestProfilingConnectionBase testProfilingConnectionBase; + TestProfilingConnectionTimeoutError testProfilingConnectionTimeOutError; + TestProfilingConnectionArmnnError testProfilingConnectionArmnnError; + + ConnectionAcknowledgedCommandHandler connectionAcknowledgedCommandHandler(1, 4194304, profilingStateMachine); + CommandHandlerRegistry commandHandlerRegistry; + + commandHandlerRegistry.RegisterFunctor(&connectionAcknowledgedCommandHandler, 1, 4194304); + + profilingStateMachine.TransitionToState(ProfilingState::NotConnected); + profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck); + + CommandThread commandThread0(1, + true, + commandHandlerRegistry, + packetVersionResolver, + testProfilingConnectionBase); + + commandThread0.Start(); + commandThread0.Start(); + commandThread0.Start(); + + commandThread0.Stop(); + commandThread0.Join(); + + BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active); + + profilingStateMachine.TransitionToState(ProfilingState::NotConnected); + profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck); + //commandThread1 should give up after one timeout + CommandThread commandThread1(1, + true, + commandHandlerRegistry, + packetVersionResolver, + testProfilingConnectionTimeOutError); + + commandThread1.Start(); + commandThread1.Join(); + + BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck); + //now commandThread1 should persist after a timeout + commandThread1.StopAfterTimeout(false); + commandThread1.Start(); + + for (int i = 0; i < 100; i++) + { + if (profilingStateMachine.GetCurrentState() == ProfilingState::Active) + { + break; + } + else + { + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + } + } + + commandThread1.Stop(); + commandThread1.Join(); + + BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active); + + + CommandThread commandThread2(1, + false, + commandHandlerRegistry, + packetVersionResolver, + testProfilingConnectionArmnnError); + + commandThread2.Start(); + + for (int i = 0; i < 100; i++) + { + if (!commandThread2.IsRunning()) + { + //commandThread2 should stop once it encounters a non timing error + commandThread2.Join(); + return; + } + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + } + + BOOST_ERROR("commandThread2 has failed to stop"); +} + BOOST_AUTO_TEST_CASE(CheckEncodeVersion) { Version version1(12); -- cgit v1.2.1