From 8a837179ad883e9b5dd982a25cc5e94f245f79ed Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Fri, 4 Oct 2019 17:01:07 +0100 Subject: IVGCVSW-3937 Rename CommandThread to CommandHandler * Renamed files, class name and methods accordingly * Updated unit tests accordingly Signed-off-by: Matteo Martincigh Change-Id: Ifb88aa61edb93b852a07b1bd59bd259213677b44 --- CMakeLists.txt | 4 +- src/profiling/CommandHandler.cpp | 85 ++++++++++++++++++++ src/profiling/CommandHandler.hpp | 61 ++++++++++++++ src/profiling/CommandThread.cpp | 87 -------------------- src/profiling/CommandThread.hpp | 65 --------------- src/profiling/test/ProfilingTests.cpp | 146 ++++++++++++++++------------------ 6 files changed, 217 insertions(+), 231 deletions(-) create mode 100644 src/profiling/CommandHandler.cpp create mode 100644 src/profiling/CommandHandler.hpp delete mode 100644 src/profiling/CommandThread.cpp delete mode 100644 src/profiling/CommandThread.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 7c0c53bdda..8ef2949ed6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -428,14 +428,14 @@ list(APPEND armnn_sources src/armnn/optimizations/SquashEqualSiblings.hpp src/profiling/BufferManager.cpp src/profiling/BufferManager.hpp + src/profiling/CommandHandler.cpp + src/profiling/CommandHandler.hpp src/profiling/CommandHandlerFunctor.cpp src/profiling/CommandHandlerFunctor.hpp src/profiling/CommandHandlerKey.cpp src/profiling/CommandHandlerKey.hpp src/profiling/CommandHandlerRegistry.cpp src/profiling/CommandHandlerRegistry.hpp - src/profiling/CommandThread.cpp - src/profiling/CommandThread.hpp src/profiling/ConnectionAcknowledgedCommandHandler.cpp src/profiling/ConnectionAcknowledgedCommandHandler.hpp src/profiling/CounterDirectory.cpp diff --git a/src/profiling/CommandHandler.cpp b/src/profiling/CommandHandler.cpp new file mode 100644 index 0000000000..5eddfd5ec3 --- /dev/null +++ b/src/profiling/CommandHandler.cpp @@ -0,0 +1,85 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "CommandHandler.hpp" + +namespace armnn +{ + +namespace profiling +{ + +void CommandHandler::Start(IProfilingConnection& profilingConnection) +{ + if (IsRunning()) + { + return; + } + + m_IsRunning.store(true, std::memory_order_relaxed); + m_KeepRunning.store(true, std::memory_order_relaxed); + m_CommandThread = std::thread(&CommandHandler::HandleCommands, this, std::ref(profilingConnection)); +} + +void CommandHandler::Stop() +{ + m_KeepRunning.store(false, std::memory_order_relaxed); + + if (m_CommandThread.joinable()) + { + m_CommandThread.join(); + } +} + +bool CommandHandler::IsRunning() const +{ + return m_IsRunning.load(std::memory_order_relaxed); +} + +void CommandHandler::SetTimeout(uint32_t timeout) +{ + m_Timeout.store(timeout, std::memory_order_relaxed); +} + +void CommandHandler::SetStopAfterTimeout(bool stopAfterTimeout) +{ + m_StopAfterTimeout.store(stopAfterTimeout, std::memory_order_relaxed); +} + +void CommandHandler::HandleCommands(IProfilingConnection& profilingConnection) +{ + do + { + try + { + Packet packet = profilingConnection.ReadPacket(m_Timeout); + Version version = m_PacketVersionResolver.ResolvePacketVersion(packet.GetPacketId()); + + CommandHandlerFunctor* commandHandlerFunctor = + m_CommandHandlerRegistry.GetFunctor(packet.GetPacketId(), version.GetEncodedValue()); + BOOST_ASSERT(commandHandlerFunctor); + commandHandlerFunctor->operator()(packet); + } + catch (const armnn::TimeoutException&) + { + if (m_StopAfterTimeout) + { + m_KeepRunning.store(false, std::memory_order_relaxed); + } + } + catch (...) + { + // Might want to differentiate the errors more + m_KeepRunning.store(false, std::memory_order_relaxed); + } + } + while (m_KeepRunning.load(std::memory_order_relaxed)); + + m_IsRunning.store(false, std::memory_order_relaxed); +} + +} // namespace profiling + +} // namespace armnn diff --git a/src/profiling/CommandHandler.hpp b/src/profiling/CommandHandler.hpp new file mode 100644 index 0000000000..598eabde76 --- /dev/null +++ b/src/profiling/CommandHandler.hpp @@ -0,0 +1,61 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "CommandHandlerRegistry.hpp" +#include "IProfilingConnection.hpp" +#include "PacketVersionResolver.hpp" + +#include +#include + +namespace armnn +{ + +namespace profiling +{ + +class CommandHandler +{ +public: + CommandHandler(uint32_t timeout, + bool stopAfterTimeout, + CommandHandlerRegistry& commandHandlerRegistry, + PacketVersionResolver& packetVersionResolver) + : m_Timeout(timeout) + , m_StopAfterTimeout(stopAfterTimeout) + , m_IsRunning(false) + , m_KeepRunning(false) + , m_CommandThread() + , m_CommandHandlerRegistry(commandHandlerRegistry) + , m_PacketVersionResolver(packetVersionResolver) + {} + ~CommandHandler() { Stop(); } + + void Start(IProfilingConnection& profilingConnection); + void Stop(); + + bool IsRunning() const; + + void SetTimeout(uint32_t timeout); + void SetStopAfterTimeout(bool stopAfterTimeout); + +private: + void HandleCommands(IProfilingConnection& profilingConnection); + + std::atomic m_Timeout; + std::atomic m_StopAfterTimeout; + std::atomic m_IsRunning; + std::atomic m_KeepRunning; + std::thread m_CommandThread; + + CommandHandlerRegistry& m_CommandHandlerRegistry; + PacketVersionResolver& m_PacketVersionResolver; +}; + +} // namespace profiling + +} // namespace armnn diff --git a/src/profiling/CommandThread.cpp b/src/profiling/CommandThread.cpp deleted file mode 100644 index 320e4bcf5c..0000000000 --- a/src/profiling/CommandThread.cpp +++ /dev/null @@ -1,87 +0,0 @@ -// -// Copyright © 2019 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include -#include "CommandThread.hpp" - -namespace armnn -{ - -namespace profiling -{ - -void CommandThread::Start() -{ - if (IsRunning()) - { - return; - } - - m_IsRunning.store(true, std::memory_order_relaxed); - m_KeepRunning.store(true, std::memory_order_relaxed); - m_CommandThread = std::thread(&CommandThread::WaitForPacket, this); -} - -void CommandThread::Stop() -{ - m_KeepRunning.store(false, std::memory_order_relaxed); - - if (m_CommandThread.joinable()) - { - m_CommandThread.join(); - } -} - -bool CommandThread::IsRunning() const -{ - return m_IsRunning.load(std::memory_order_relaxed); -} - -void CommandThread::SetTimeout(uint32_t timeout) -{ - m_Timeout.store(timeout, std::memory_order_relaxed); -} - -void CommandThread::SetStopAfterTimeout(bool stopAfterTimeout) -{ - m_StopAfterTimeout.store(stopAfterTimeout, std::memory_order_relaxed); -} - -void CommandThread::WaitForPacket() -{ - do - { - try - { - Packet packet = m_SocketProfilingConnection.ReadPacket(m_Timeout); - Version version = m_PacketVersionResolver.ResolvePacketVersion(packet.GetPacketId()); - - CommandHandlerFunctor* commandHandlerFunctor = - m_CommandHandlerRegistry.GetFunctor(packet.GetPacketId(), version.GetEncodedValue()); - BOOST_ASSERT(commandHandlerFunctor); - commandHandlerFunctor->operator()(packet); - } - catch (const armnn::TimeoutException&) - { - if (m_StopAfterTimeout) - { - m_KeepRunning.store(false, std::memory_order_relaxed); - } - } - catch (...) - { - // Might want to differentiate the errors more - m_KeepRunning.store(false, std::memory_order_relaxed); - } - - } - while (m_KeepRunning.load(std::memory_order_relaxed)); - - m_IsRunning.store(false, std::memory_order_relaxed); -} - -} // namespace profiling - -} // namespace armnn diff --git a/src/profiling/CommandThread.hpp b/src/profiling/CommandThread.hpp deleted file mode 100644 index 0456ba4372..0000000000 --- a/src/profiling/CommandThread.hpp +++ /dev/null @@ -1,65 +0,0 @@ -// -// Copyright © 2019 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include "CommandHandlerRegistry.hpp" -#include "IProfilingConnection.hpp" -#include "PacketVersionResolver.hpp" -#include "ProfilingService.hpp" - -#include -#include - -namespace armnn -{ - -namespace profiling -{ - -class CommandThread -{ -public: - CommandThread(uint32_t timeout, - bool stopAfterTimeout, - CommandHandlerRegistry& commandHandlerRegistry, - PacketVersionResolver& packetVersionResolver, - IProfilingConnection& socketProfilingConnection) - : m_Timeout(timeout) - , m_StopAfterTimeout(stopAfterTimeout) - , m_IsRunning(false) - , m_KeepRunning(false) - , m_CommandThread() - , m_CommandHandlerRegistry(commandHandlerRegistry) - , m_PacketVersionResolver(packetVersionResolver) - , m_SocketProfilingConnection(socketProfilingConnection) - {} - ~CommandThread() { Stop(); } - - void Start(); - void Stop(); - - bool IsRunning() const; - - void SetTimeout(uint32_t timeout); - void SetStopAfterTimeout(bool stopAfterTimeout); - -private: - void WaitForPacket(); - - std::atomic m_Timeout; - std::atomic m_StopAfterTimeout; - std::atomic m_IsRunning; - std::atomic m_KeepRunning; - std::thread m_CommandThread; - - CommandHandlerRegistry& m_CommandHandlerRegistry; - PacketVersionResolver& m_PacketVersionResolver; - IProfilingConnection& m_SocketProfilingConnection; -}; - -} // namespace profiling - -} // namespace armnn diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index 9dd7cd3d64..bc962e3b17 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -4,8 +4,8 @@ // #include "SendCounterPacketTests.hpp" -#include "../CommandThread.hpp" +#include #include #include #include @@ -40,10 +40,10 @@ #include #include -BOOST_AUTO_TEST_SUITE(ExternalProfiling) - using namespace armnn::profiling; +BOOST_AUTO_TEST_SUITE(ExternalProfiling) + BOOST_AUTO_TEST_CASE(CheckCommandHandlerKeyComparisons) { CommandHandlerKey testKey0(1, 1); @@ -97,17 +97,11 @@ public: TestProfilingConnectionBase() = default; ~TestProfilingConnectionBase() = default; - bool IsOpen() - { - return true; - } + bool IsOpen() { return true; } - void Close(){} + void Close() {} - bool WritePacket(const unsigned char* buffer, uint32_t length) - { - return false; - } + bool WritePacket(const unsigned char* buffer, uint32_t length) { return false; } Packet ReadPacket(uint32_t timeout) { @@ -118,9 +112,8 @@ public: } }; -class TestProfilingConnectionTimeoutError :public TestProfilingConnectionBase +class TestProfilingConnectionTimeoutError : public TestProfilingConnectionBase { - int readRequests = 0; public: Packet ReadPacket(uint32_t timeout) { if (readRequests < 3) @@ -133,6 +126,9 @@ public: //Return connection acknowledged packet after three timeouts return {65536 ,0 , packetData}; } + +private: + int readRequests = 0; }; class TestProfilingConnectionArmnnError :public TestProfilingConnectionBase @@ -146,94 +142,90 @@ public: } }; -BOOST_AUTO_TEST_CASE(CheckCommandThread) +BOOST_AUTO_TEST_CASE(CheckCommandHandler) { - PacketVersionResolver packetVersionResolver; - ProfilingStateMachine profilingStateMachine; + PacketVersionResolver packetVersionResolver; + ProfilingStateMachine profilingStateMachine; - TestProfilingConnectionBase testProfilingConnectionBase; - TestProfilingConnectionTimeoutError testProfilingConnectionTimeOutError; - TestProfilingConnectionArmnnError testProfilingConnectionArmnnError; + TestProfilingConnectionBase testProfilingConnectionBase; + TestProfilingConnectionTimeoutError testProfilingConnectionTimeOutError; + TestProfilingConnectionArmnnError testProfilingConnectionArmnnError; - ConnectionAcknowledgedCommandHandler connectionAcknowledgedCommandHandler(1, 4194304, profilingStateMachine); - CommandHandlerRegistry commandHandlerRegistry; + ConnectionAcknowledgedCommandHandler connectionAcknowledgedCommandHandler(1, 4194304, profilingStateMachine); + CommandHandlerRegistry commandHandlerRegistry; - commandHandlerRegistry.RegisterFunctor(&connectionAcknowledgedCommandHandler, 1, 4194304); + commandHandlerRegistry.RegisterFunctor(&connectionAcknowledgedCommandHandler, 1, 4194304); - profilingStateMachine.TransitionToState(ProfilingState::NotConnected); - profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck); + profilingStateMachine.TransitionToState(ProfilingState::NotConnected); + profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck); - CommandThread commandThread0(1, - true, - commandHandlerRegistry, - packetVersionResolver, - testProfilingConnectionBase); + CommandHandler commandHandler0(1, + true, + commandHandlerRegistry, + packetVersionResolver); - commandThread0.Start(); - commandThread0.Start(); - commandThread0.Start(); + commandHandler0.Start(testProfilingConnectionBase); + commandHandler0.Start(testProfilingConnectionBase); + commandHandler0.Start(testProfilingConnectionBase); - commandThread0.Stop(); + commandHandler0.Stop(); - BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active); + BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active); - profilingStateMachine.TransitionToState(ProfilingState::NotConnected); - profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck); - //commandThread1 should give up after one timeout - CommandThread commandThread1(1, - true, - commandHandlerRegistry, - packetVersionResolver, - testProfilingConnectionTimeOutError); + profilingStateMachine.TransitionToState(ProfilingState::NotConnected); + profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck); + // commandHandler1 should give up after one timeout + CommandHandler commandHandler1(1, + true, + commandHandlerRegistry, + packetVersionResolver); - commandThread1.Start(); + commandHandler1.Start(testProfilingConnectionTimeOutError); - std::this_thread::sleep_for(std::chrono::milliseconds(100)); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); - BOOST_CHECK(!commandThread1.IsRunning()); - commandThread1.Stop(); + BOOST_CHECK(!commandHandler1.IsRunning()); + commandHandler1.Stop(); - BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck); - //now commandThread1 should persist after a timeout - commandThread1.SetStopAfterTimeout(false); - commandThread1.Start(); + BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck); + // Now commandHandler1 should persist after a timeout + commandHandler1.SetStopAfterTimeout(false); + commandHandler1.Start(testProfilingConnectionTimeOutError); - for (int i = 0; i < 100; i++) + for (int i = 0; i < 100; i++) + { + if (profilingStateMachine.GetCurrentState() == ProfilingState::Active) { - if (profilingStateMachine.GetCurrentState() == ProfilingState::Active) - { - break; - } - else - { - std::this_thread::sleep_for(std::chrono::milliseconds(5)); - } + break; } - commandThread1.Stop(); + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + } - BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active); + commandHandler1.Stop(); - CommandThread commandThread2(1, - false, - commandHandlerRegistry, - packetVersionResolver, - testProfilingConnectionArmnnError); + BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active); - commandThread2.Start(); + CommandHandler commandHandler2(1, + false, + commandHandlerRegistry, + packetVersionResolver); - for (int i = 0; i < 100; i++) + commandHandler2.Start(testProfilingConnectionArmnnError); + + for (int i = 0; i < 100; i++) + { + if (!commandHandler2.IsRunning()) { - if (!commandThread2.IsRunning()) - { - //commandThread2 should stop once it encounters a non timing error - return; - } - std::this_thread::sleep_for(std::chrono::milliseconds(5)); + // commandHandler2 should stop once it encounters a non timing error + return; } - BOOST_ERROR("commandThread2 has failed to stop"); - commandThread2.Stop(); + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + } + + BOOST_ERROR("commandHandler2 has failed to stop"); + commandHandler2.Stop(); } BOOST_AUTO_TEST_CASE(CheckEncodeVersion) -- cgit v1.2.1