aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2019-10-07 12:35:21 +0100
committerJim Flynn Arm <jim.flynn@arm.com>2019-10-08 08:22:51 +0000
commitc2728f95086c54aa842e4c1dae8f3b5c290a72fa (patch)
tree82002c3d0c97abfeed905d0e922579dab09b2c31
parente61ffd00a37f02338129e92d65be2f01600014c0 (diff)
downloadarmnn-c2728f95086c54aa842e4c1dae8f3b5c290a72fa.tar.gz
IVGCVSW-3937 Refactor and improve the CommandHandleRegistry class
* Added simplified RegisterFunctor method * Code refactoring * Updated the unit tests accordingly Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com> Change-Id: Iee941d898facd9c1ab5366e87c611c99a0468830
-rw-r--r--src/profiling/CommandHandler.cpp27
-rw-r--r--src/profiling/CommandHandler.hpp9
-rw-r--r--src/profiling/CommandHandlerFunctor.hpp7
-rw-r--r--src/profiling/CommandHandlerRegistry.cpp28
-rw-r--r--src/profiling/CommandHandlerRegistry.hpp2
-rw-r--r--src/profiling/ConnectionAcknowledgedCommandHandler.cpp8
-rw-r--r--src/profiling/ConnectionAcknowledgedCommandHandler.hpp6
-rw-r--r--src/profiling/test/ProfilingTests.cpp8
8 files changed, 54 insertions, 41 deletions
diff --git a/src/profiling/CommandHandler.cpp b/src/profiling/CommandHandler.cpp
index 5eddfd5ec3..49784056bf 100644
--- a/src/profiling/CommandHandler.cpp
+++ b/src/profiling/CommandHandler.cpp
@@ -18,14 +18,14 @@ void CommandHandler::Start(IProfilingConnection& profilingConnection)
return;
}
- m_IsRunning.store(true, std::memory_order_relaxed);
- m_KeepRunning.store(true, std::memory_order_relaxed);
+ m_IsRunning.store(true);
+ m_KeepRunning.store(true);
m_CommandThread = std::thread(&CommandHandler::HandleCommands, this, std::ref(profilingConnection));
}
void CommandHandler::Stop()
{
- m_KeepRunning.store(false, std::memory_order_relaxed);
+ m_KeepRunning.store(false);
if (m_CommandThread.joinable())
{
@@ -33,21 +33,6 @@ void CommandHandler::Stop()
}
}
-bool CommandHandler::IsRunning() const
-{
- return m_IsRunning.load(std::memory_order_relaxed);
-}
-
-void CommandHandler::SetTimeout(uint32_t timeout)
-{
- m_Timeout.store(timeout, std::memory_order_relaxed);
-}
-
-void CommandHandler::SetStopAfterTimeout(bool stopAfterTimeout)
-{
- m_StopAfterTimeout.store(stopAfterTimeout, std::memory_order_relaxed);
-}
-
void CommandHandler::HandleCommands(IProfilingConnection& profilingConnection)
{
do
@@ -72,12 +57,12 @@ void CommandHandler::HandleCommands(IProfilingConnection& profilingConnection)
catch (...)
{
// Might want to differentiate the errors more
- m_KeepRunning.store(false, std::memory_order_relaxed);
+ m_KeepRunning.store(false);
}
}
- while (m_KeepRunning.load(std::memory_order_relaxed));
+ while (m_KeepRunning.load());
- m_IsRunning.store(false, std::memory_order_relaxed);
+ m_IsRunning.store(false);
}
} // namespace profiling
diff --git a/src/profiling/CommandHandler.hpp b/src/profiling/CommandHandler.hpp
index 598eabde76..0cc23429cd 100644
--- a/src/profiling/CommandHandler.hpp
+++ b/src/profiling/CommandHandler.hpp
@@ -35,13 +35,12 @@ public:
{}
~CommandHandler() { Stop(); }
+ void SetTimeout(uint32_t timeout) { m_Timeout.store(timeout); }
+ void SetStopAfterTimeout(bool stopAfterTimeout) { m_StopAfterTimeout.store(stopAfterTimeout); }
+
void Start(IProfilingConnection& profilingConnection);
void Stop();
-
- bool IsRunning() const;
-
- void SetTimeout(uint32_t timeout);
- void SetStopAfterTimeout(bool stopAfterTimeout);
+ bool IsRunning() const { return m_IsRunning.load(); }
private:
void HandleCommands(IProfilingConnection& profilingConnection);
diff --git a/src/profiling/CommandHandlerFunctor.hpp b/src/profiling/CommandHandlerFunctor.hpp
index a9a59c145f..2e1e05fd32 100644
--- a/src/profiling/CommandHandlerFunctor.hpp
+++ b/src/profiling/CommandHandlerFunctor.hpp
@@ -18,12 +18,15 @@ namespace profiling
class CommandHandlerFunctor
{
public:
- CommandHandlerFunctor(uint32_t packetId, uint32_t version) : m_PacketId(packetId), m_Version(version) {};
+ CommandHandlerFunctor(uint32_t packetId, uint32_t version)
+ : m_PacketId(packetId)
+ , m_Version(version)
+ {}
uint32_t GetPacketId() const;
uint32_t GetVersion() const;
- virtual void operator()(const Packet& packet) {};
+ virtual void operator()(const Packet& packet) {}
private:
uint32_t m_PacketId;
diff --git a/src/profiling/CommandHandlerRegistry.cpp b/src/profiling/CommandHandlerRegistry.cpp
index 97313475ff..bd9b318835 100644
--- a/src/profiling/CommandHandlerRegistry.cpp
+++ b/src/profiling/CommandHandlerRegistry.cpp
@@ -6,7 +6,7 @@
#include "CommandHandlerRegistry.hpp"
#include <boost/assert.hpp>
-#include <boost/log/trivial.hpp>
+#include <boost/format.hpp>
namespace armnn
{
@@ -16,11 +16,19 @@ namespace profiling
void CommandHandlerRegistry::RegisterFunctor(CommandHandlerFunctor* functor, uint32_t packetId, uint32_t version)
{
- BOOST_ASSERT_MSG(functor, "Provided functor should not be a nullptr.");
+ BOOST_ASSERT_MSG(functor, "Provided functor should not be a nullptr");
+
CommandHandlerKey key(packetId, version);
registry[key] = functor;
}
+void CommandHandlerRegistry::RegisterFunctor(CommandHandlerFunctor* functor)
+{
+ BOOST_ASSERT_MSG(functor, "Provided functor should not be a nullptr");
+
+ RegisterFunctor(functor, functor->GetPacketId(), functor->GetVersion());
+}
+
CommandHandlerFunctor* CommandHandlerRegistry::GetFunctor(uint32_t packetId, uint32_t version) const
{
CommandHandlerKey key(packetId, version);
@@ -28,10 +36,22 @@ CommandHandlerFunctor* CommandHandlerRegistry::GetFunctor(uint32_t packetId, uin
// Check that the requested key exists
if (registry.find(key) == registry.end())
{
- throw armnn::Exception("Functor with requested PacketId or Version does not exist.");
+ throw armnn::InvalidArgumentException(
+ boost::str(boost::format("Functor with requested PacketId=%1% and Version=%2% does not exist")
+ % packetId
+ % version));
+ }
+
+ CommandHandlerFunctor* commandHandlerFunctor = registry.at(key);
+ if (commandHandlerFunctor == nullptr)
+ {
+ throw RuntimeException(
+ boost::str(boost::format("Invalid functor registered for PacketId=%1% and Version=%2%")
+ % packetId
+ % version));
}
- return registry.at(key);
+ return commandHandlerFunctor;
}
} // namespace profiling
diff --git a/src/profiling/CommandHandlerRegistry.hpp b/src/profiling/CommandHandlerRegistry.hpp
index 61d45b0fd2..9d514bfcc3 100644
--- a/src/profiling/CommandHandlerRegistry.hpp
+++ b/src/profiling/CommandHandlerRegistry.hpp
@@ -36,6 +36,8 @@ public:
void RegisterFunctor(CommandHandlerFunctor* functor, uint32_t packetId, uint32_t version);
+ void RegisterFunctor(CommandHandlerFunctor* functor);
+
CommandHandlerFunctor* GetFunctor(uint32_t packetId, uint32_t version) const;
private:
diff --git a/src/profiling/ConnectionAcknowledgedCommandHandler.cpp b/src/profiling/ConnectionAcknowledgedCommandHandler.cpp
index 0f83a3181b..f90b601b7e 100644
--- a/src/profiling/ConnectionAcknowledgedCommandHandler.cpp
+++ b/src/profiling/ConnectionAcknowledgedCommandHandler.cpp
@@ -17,10 +17,12 @@ void ConnectionAcknowledgedCommandHandler::operator()(const Packet& packet)
{
if (!(packet.GetPacketFamily() == 0u && packet.GetPacketId() == 1u))
{
- throw armnn::Exception(std::string("Expected Packet family = 0, id = 1 but received family =")
- + std::to_string(packet.GetPacketFamily())
- +" id = " + std::to_string(packet.GetPacketId()));
+ throw armnn::InvalidArgumentException(std::string("Expected Packet family = 0, id = 1 but received family = ")
+ + std::to_string(packet.GetPacketFamily())
+ + " id = " + std::to_string(packet.GetPacketId()));
}
+
+ // Once a Connection Acknowledged packet has been received, move to the Active state immediately
m_StateMachine.TransitionToState(ProfilingState::Active);
}
diff --git a/src/profiling/ConnectionAcknowledgedCommandHandler.hpp b/src/profiling/ConnectionAcknowledgedCommandHandler.hpp
index f61495e5c3..d0dc07ae00 100644
--- a/src/profiling/ConnectionAcknowledgedCommandHandler.hpp
+++ b/src/profiling/ConnectionAcknowledgedCommandHandler.hpp
@@ -15,14 +15,16 @@ namespace armnn
namespace profiling
{
-class ConnectionAcknowledgedCommandHandler : public CommandHandlerFunctor
+class ConnectionAcknowledgedCommandHandler final : public CommandHandlerFunctor
{
public:
ConnectionAcknowledgedCommandHandler(uint32_t packetId,
uint32_t version,
ProfilingStateMachine& profilingStateMachine)
- : CommandHandlerFunctor(packetId, version), m_StateMachine(profilingStateMachine) {}
+ : CommandHandlerFunctor(packetId, version)
+ , m_StateMachine(profilingStateMachine)
+ {}
void operator()(const Packet& packet) override;
diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp
index ba1e6cfa5a..91568d111d 100644
--- a/src/profiling/test/ProfilingTests.cpp
+++ b/src/profiling/test/ProfilingTests.cpp
@@ -154,7 +154,7 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandler)
ConnectionAcknowledgedCommandHandler connectionAcknowledgedCommandHandler(1, 4194304, profilingStateMachine);
CommandHandlerRegistry commandHandlerRegistry;
- commandHandlerRegistry.RegisterFunctor(&connectionAcknowledgedCommandHandler, 1, 4194304);
+ commandHandlerRegistry.RegisterFunctor(&connectionAcknowledgedCommandHandler);
profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
@@ -388,9 +388,9 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandlerRegistry)
CommandHandlerRegistry registry;
// Register multiple different derived classes
- registry.RegisterFunctor(&testFunctorA, testFunctorA.GetPacketId(), testFunctorA.GetVersion());
- registry.RegisterFunctor(&testFunctorB, testFunctorB.GetPacketId(), testFunctorB.GetVersion());
- registry.RegisterFunctor(&testFunctorC, testFunctorC.GetPacketId(), testFunctorC.GetVersion());
+ registry.RegisterFunctor(&testFunctorA);
+ registry.RegisterFunctor(&testFunctorB);
+ registry.RegisterFunctor(&testFunctorC);
std::unique_ptr<char[]> packetDataA;
std::unique_ptr<char[]> packetDataB;