From c2728f95086c54aa842e4c1dae8f3b5c290a72fa Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Mon, 7 Oct 2019 12:35:21 +0100 Subject: IVGCVSW-3937 Refactor and improve the CommandHandleRegistry class * Added simplified RegisterFunctor method * Code refactoring * Updated the unit tests accordingly Signed-off-by: Matteo Martincigh Change-Id: Iee941d898facd9c1ab5366e87c611c99a0468830 --- src/profiling/CommandHandler.cpp | 27 +++++---------------- src/profiling/CommandHandler.hpp | 9 ++++--- src/profiling/CommandHandlerFunctor.hpp | 7 ++++-- src/profiling/CommandHandlerRegistry.cpp | 28 ++++++++++++++++++---- src/profiling/CommandHandlerRegistry.hpp | 2 ++ .../ConnectionAcknowledgedCommandHandler.cpp | 8 ++++--- .../ConnectionAcknowledgedCommandHandler.hpp | 6 +++-- src/profiling/test/ProfilingTests.cpp | 8 +++---- 8 files changed, 54 insertions(+), 41 deletions(-) (limited to 'src') diff --git a/src/profiling/CommandHandler.cpp b/src/profiling/CommandHandler.cpp index 5eddfd5ec3..49784056bf 100644 --- a/src/profiling/CommandHandler.cpp +++ b/src/profiling/CommandHandler.cpp @@ -18,14 +18,14 @@ void CommandHandler::Start(IProfilingConnection& profilingConnection) return; } - m_IsRunning.store(true, std::memory_order_relaxed); - m_KeepRunning.store(true, std::memory_order_relaxed); + m_IsRunning.store(true); + m_KeepRunning.store(true); m_CommandThread = std::thread(&CommandHandler::HandleCommands, this, std::ref(profilingConnection)); } void CommandHandler::Stop() { - m_KeepRunning.store(false, std::memory_order_relaxed); + m_KeepRunning.store(false); if (m_CommandThread.joinable()) { @@ -33,21 +33,6 @@ void CommandHandler::Stop() } } -bool CommandHandler::IsRunning() const -{ - return m_IsRunning.load(std::memory_order_relaxed); -} - -void CommandHandler::SetTimeout(uint32_t timeout) -{ - m_Timeout.store(timeout, std::memory_order_relaxed); -} - -void CommandHandler::SetStopAfterTimeout(bool stopAfterTimeout) -{ - m_StopAfterTimeout.store(stopAfterTimeout, std::memory_order_relaxed); -} - void CommandHandler::HandleCommands(IProfilingConnection& profilingConnection) { do @@ -72,12 +57,12 @@ void CommandHandler::HandleCommands(IProfilingConnection& profilingConnection) catch (...) { // Might want to differentiate the errors more - m_KeepRunning.store(false, std::memory_order_relaxed); + m_KeepRunning.store(false); } } - while (m_KeepRunning.load(std::memory_order_relaxed)); + while (m_KeepRunning.load()); - m_IsRunning.store(false, std::memory_order_relaxed); + m_IsRunning.store(false); } } // namespace profiling diff --git a/src/profiling/CommandHandler.hpp b/src/profiling/CommandHandler.hpp index 598eabde76..0cc23429cd 100644 --- a/src/profiling/CommandHandler.hpp +++ b/src/profiling/CommandHandler.hpp @@ -35,13 +35,12 @@ public: {} ~CommandHandler() { Stop(); } + void SetTimeout(uint32_t timeout) { m_Timeout.store(timeout); } + void SetStopAfterTimeout(bool stopAfterTimeout) { m_StopAfterTimeout.store(stopAfterTimeout); } + void Start(IProfilingConnection& profilingConnection); void Stop(); - - bool IsRunning() const; - - void SetTimeout(uint32_t timeout); - void SetStopAfterTimeout(bool stopAfterTimeout); + bool IsRunning() const { return m_IsRunning.load(); } private: void HandleCommands(IProfilingConnection& profilingConnection); diff --git a/src/profiling/CommandHandlerFunctor.hpp b/src/profiling/CommandHandlerFunctor.hpp index a9a59c145f..2e1e05fd32 100644 --- a/src/profiling/CommandHandlerFunctor.hpp +++ b/src/profiling/CommandHandlerFunctor.hpp @@ -18,12 +18,15 @@ namespace profiling class CommandHandlerFunctor { public: - CommandHandlerFunctor(uint32_t packetId, uint32_t version) : m_PacketId(packetId), m_Version(version) {}; + CommandHandlerFunctor(uint32_t packetId, uint32_t version) + : m_PacketId(packetId) + , m_Version(version) + {} uint32_t GetPacketId() const; uint32_t GetVersion() const; - virtual void operator()(const Packet& packet) {}; + virtual void operator()(const Packet& packet) {} private: uint32_t m_PacketId; diff --git a/src/profiling/CommandHandlerRegistry.cpp b/src/profiling/CommandHandlerRegistry.cpp index 97313475ff..bd9b318835 100644 --- a/src/profiling/CommandHandlerRegistry.cpp +++ b/src/profiling/CommandHandlerRegistry.cpp @@ -6,7 +6,7 @@ #include "CommandHandlerRegistry.hpp" #include -#include +#include namespace armnn { @@ -16,11 +16,19 @@ namespace profiling void CommandHandlerRegistry::RegisterFunctor(CommandHandlerFunctor* functor, uint32_t packetId, uint32_t version) { - BOOST_ASSERT_MSG(functor, "Provided functor should not be a nullptr."); + BOOST_ASSERT_MSG(functor, "Provided functor should not be a nullptr"); + CommandHandlerKey key(packetId, version); registry[key] = functor; } +void CommandHandlerRegistry::RegisterFunctor(CommandHandlerFunctor* functor) +{ + BOOST_ASSERT_MSG(functor, "Provided functor should not be a nullptr"); + + RegisterFunctor(functor, functor->GetPacketId(), functor->GetVersion()); +} + CommandHandlerFunctor* CommandHandlerRegistry::GetFunctor(uint32_t packetId, uint32_t version) const { CommandHandlerKey key(packetId, version); @@ -28,10 +36,22 @@ CommandHandlerFunctor* CommandHandlerRegistry::GetFunctor(uint32_t packetId, uin // Check that the requested key exists if (registry.find(key) == registry.end()) { - throw armnn::Exception("Functor with requested PacketId or Version does not exist."); + throw armnn::InvalidArgumentException( + boost::str(boost::format("Functor with requested PacketId=%1% and Version=%2% does not exist") + % packetId + % version)); + } + + CommandHandlerFunctor* commandHandlerFunctor = registry.at(key); + if (commandHandlerFunctor == nullptr) + { + throw RuntimeException( + boost::str(boost::format("Invalid functor registered for PacketId=%1% and Version=%2%") + % packetId + % version)); } - return registry.at(key); + return commandHandlerFunctor; } } // namespace profiling diff --git a/src/profiling/CommandHandlerRegistry.hpp b/src/profiling/CommandHandlerRegistry.hpp index 61d45b0fd2..9d514bfcc3 100644 --- a/src/profiling/CommandHandlerRegistry.hpp +++ b/src/profiling/CommandHandlerRegistry.hpp @@ -36,6 +36,8 @@ public: void RegisterFunctor(CommandHandlerFunctor* functor, uint32_t packetId, uint32_t version); + void RegisterFunctor(CommandHandlerFunctor* functor); + CommandHandlerFunctor* GetFunctor(uint32_t packetId, uint32_t version) const; private: diff --git a/src/profiling/ConnectionAcknowledgedCommandHandler.cpp b/src/profiling/ConnectionAcknowledgedCommandHandler.cpp index 0f83a3181b..f90b601b7e 100644 --- a/src/profiling/ConnectionAcknowledgedCommandHandler.cpp +++ b/src/profiling/ConnectionAcknowledgedCommandHandler.cpp @@ -17,10 +17,12 @@ void ConnectionAcknowledgedCommandHandler::operator()(const Packet& packet) { if (!(packet.GetPacketFamily() == 0u && packet.GetPacketId() == 1u)) { - throw armnn::Exception(std::string("Expected Packet family = 0, id = 1 but received family =") - + std::to_string(packet.GetPacketFamily()) - +" id = " + std::to_string(packet.GetPacketId())); + throw armnn::InvalidArgumentException(std::string("Expected Packet family = 0, id = 1 but received family = ") + + std::to_string(packet.GetPacketFamily()) + + " id = " + std::to_string(packet.GetPacketId())); } + + // Once a Connection Acknowledged packet has been received, move to the Active state immediately m_StateMachine.TransitionToState(ProfilingState::Active); } diff --git a/src/profiling/ConnectionAcknowledgedCommandHandler.hpp b/src/profiling/ConnectionAcknowledgedCommandHandler.hpp index f61495e5c3..d0dc07ae00 100644 --- a/src/profiling/ConnectionAcknowledgedCommandHandler.hpp +++ b/src/profiling/ConnectionAcknowledgedCommandHandler.hpp @@ -15,14 +15,16 @@ namespace armnn namespace profiling { -class ConnectionAcknowledgedCommandHandler : public CommandHandlerFunctor +class ConnectionAcknowledgedCommandHandler final : public CommandHandlerFunctor { public: ConnectionAcknowledgedCommandHandler(uint32_t packetId, uint32_t version, ProfilingStateMachine& profilingStateMachine) - : CommandHandlerFunctor(packetId, version), m_StateMachine(profilingStateMachine) {} + : CommandHandlerFunctor(packetId, version) + , m_StateMachine(profilingStateMachine) + {} void operator()(const Packet& packet) override; diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index ba1e6cfa5a..91568d111d 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -154,7 +154,7 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandler) ConnectionAcknowledgedCommandHandler connectionAcknowledgedCommandHandler(1, 4194304, profilingStateMachine); CommandHandlerRegistry commandHandlerRegistry; - commandHandlerRegistry.RegisterFunctor(&connectionAcknowledgedCommandHandler, 1, 4194304); + commandHandlerRegistry.RegisterFunctor(&connectionAcknowledgedCommandHandler); profilingStateMachine.TransitionToState(ProfilingState::NotConnected); profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck); @@ -388,9 +388,9 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandlerRegistry) CommandHandlerRegistry registry; // Register multiple different derived classes - registry.RegisterFunctor(&testFunctorA, testFunctorA.GetPacketId(), testFunctorA.GetVersion()); - registry.RegisterFunctor(&testFunctorB, testFunctorB.GetPacketId(), testFunctorB.GetVersion()); - registry.RegisterFunctor(&testFunctorC, testFunctorC.GetPacketId(), testFunctorC.GetVersion()); + registry.RegisterFunctor(&testFunctorA); + registry.RegisterFunctor(&testFunctorB); + registry.RegisterFunctor(&testFunctorC); std::unique_ptr packetDataA; std::unique_ptr packetDataB; -- cgit v1.2.1