From 94d7915bef33ad59d1bdfa791490268c682c5359 Mon Sep 17 00:00:00 2001 From: Francis Murtagh Date: Fri, 16 Aug 2019 17:45:07 +0100 Subject: IVGCVSW-3550 Create Command Handler Registry Change-Id: I51e34068d79ba660ae2f16b22ad2bb8191d473fa Signed-off-by: Francis Murtagh --- CMakeLists.txt | 2 + src/profiling/CommandHandlerRegistry.cpp | 29 ++++++++++ src/profiling/CommandHandlerRegistry.hpp | 36 +++++++++++++ src/profiling/test/ProfilingTests.cpp | 93 ++++++++++++++++++++++++-------- 4 files changed, 138 insertions(+), 22 deletions(-) create mode 100644 src/profiling/CommandHandlerRegistry.cpp create mode 100644 src/profiling/CommandHandlerRegistry.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index a77b112bb8..df4b742cda 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -415,6 +415,8 @@ list(APPEND armnn_sources src/profiling/CommandHandlerFunctor.hpp src/profiling/CommandHandlerKey.cpp src/profiling/CommandHandlerKey.hpp + src/profiling/CommandHandlerRegistry.cpp + src/profiling/CommandHandlerRegistry.hpp src/profiling/Packet.cpp src/profiling/Packet.hpp third-party/half/half.hpp diff --git a/src/profiling/CommandHandlerRegistry.cpp b/src/profiling/CommandHandlerRegistry.cpp new file mode 100644 index 0000000000..d392db0534 --- /dev/null +++ b/src/profiling/CommandHandlerRegistry.cpp @@ -0,0 +1,29 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "CommandHandlerRegistry.hpp" + +#include +#include + +void CommandHandlerRegistry::RegisterFunctor(CommandHandlerFunctor* functor, uint32_t packetId, uint32_t version) +{ + BOOST_ASSERT_MSG(functor, "Provided functor should not be a nullptr."); + CommandHandlerKey key(packetId, version); + registry[key] = functor; +} + +CommandHandlerFunctor* CommandHandlerRegistry::GetFunctor(uint32_t packetId, uint32_t version) const +{ + CommandHandlerKey key(packetId, version); + + // Check that the requested key exists + if (registry.find(key) == registry.end()) + { + throw armnn::Exception("Functor with requested PacketId or Version does not exist."); + } + + return registry.at(key); +} diff --git a/src/profiling/CommandHandlerRegistry.hpp b/src/profiling/CommandHandlerRegistry.hpp new file mode 100644 index 0000000000..ba81f1790f --- /dev/null +++ b/src/profiling/CommandHandlerRegistry.hpp @@ -0,0 +1,36 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "CommandHandlerFunctor.hpp" +#include "CommandHandlerKey.hpp" + +#include +#include + +struct CommandHandlerHash +{ + std::size_t operator() (const CommandHandlerKey& commandHandlerKey) const + { + std::size_t seed = 0; + boost::hash_combine(seed, commandHandlerKey.GetPacketId()); + boost::hash_combine(seed, commandHandlerKey.GetVersion()); + return seed; + } +}; + +class CommandHandlerRegistry +{ +public: + CommandHandlerRegistry() = default; + + void RegisterFunctor(CommandHandlerFunctor* functor, uint32_t packetId, uint32_t version); + + CommandHandlerFunctor* GetFunctor(uint32_t packetId, uint32_t version) const; + +private: + std::unordered_map registry; +}; \ No newline at end of file diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index 26cbfd7183..a8ec0277d2 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -5,6 +5,7 @@ #include "../CommandHandlerKey.hpp" #include "../CommandHandlerFunctor.hpp" +#include "../CommandHandlerRegistry.hpp" #include "../Packet.hpp" #include @@ -80,35 +81,35 @@ BOOST_AUTO_TEST_CASE(CheckPacketClass) BOOST_CHECK(packetTest1.GetPacketClass() == 5); } -BOOST_AUTO_TEST_CASE(CheckCommandHandlerFunctor) +// Create Derived Classes +class TestFunctorA : public CommandHandlerFunctor { - // Create Derived Classes - class TestFunctorA : public CommandHandlerFunctor - { - public: - using CommandHandlerFunctor::CommandHandlerFunctor; +public: + using CommandHandlerFunctor::CommandHandlerFunctor; - int GetCount() { return m_Count; } + int GetCount() { return m_Count; } - void operator()(const Packet& packet) override - { - m_Count++; - } + void operator()(const Packet& packet) override + { + m_Count++; + } - private: - int m_Count = 0; - }; +private: + int m_Count = 0; +}; - class TestFunctorB : public TestFunctorA - { - using TestFunctorA::TestFunctorA; - }; +class TestFunctorB : public TestFunctorA +{ + using TestFunctorA::TestFunctorA; +}; - class TestFunctorC : public TestFunctorA - { - using TestFunctorA::TestFunctorA; - }; +class TestFunctorC : public TestFunctorA +{ + using TestFunctorA::TestFunctorA; +}; +BOOST_AUTO_TEST_CASE(CheckCommandHandlerFunctor) +{ // Hard code the version as it will be the same during a single profiling session uint32_t version = 1; @@ -156,4 +157,52 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandlerFunctor) BOOST_CHECK(testFunctorC.GetCount() == 1); } +BOOST_AUTO_TEST_CASE(CheckCommandHandlerRegistry) +{ + // Hard code the version as it will be the same during a single profiling session + uint32_t version = 1; + + TestFunctorA testFunctorA(461, version); + TestFunctorB testFunctorB(963, version); + TestFunctorC testFunctorC(983, version); + + // Create the Command Handler Registry + CommandHandlerRegistry registry; + + // Register multiple different derived classes + registry.RegisterFunctor(&testFunctorA, testFunctorA.GetPacketId(), testFunctorA.GetVersion()); + registry.RegisterFunctor(&testFunctorB, testFunctorB.GetPacketId(), testFunctorB.GetVersion()); + registry.RegisterFunctor(&testFunctorC, testFunctorC.GetPacketId(), testFunctorC.GetVersion()); + + Packet packetA(500000000, 0, nullptr); + Packet packetB(600000000, 0, nullptr); + Packet packetC(400000000, 0, nullptr); + + // Check the correct operator of derived class is called + registry.GetFunctor(packetA.GetPacketId(), version)->operator()(packetA); + BOOST_CHECK(testFunctorA.GetCount() == 1); + BOOST_CHECK(testFunctorB.GetCount() == 0); + BOOST_CHECK(testFunctorC.GetCount() == 0); + + registry.GetFunctor(packetB.GetPacketId(), version)->operator()(packetB); + BOOST_CHECK(testFunctorA.GetCount() == 1); + BOOST_CHECK(testFunctorB.GetCount() == 1); + BOOST_CHECK(testFunctorC.GetCount() == 0); + + registry.GetFunctor(packetC.GetPacketId(), version)->operator()(packetC); + BOOST_CHECK(testFunctorA.GetCount() == 1); + BOOST_CHECK(testFunctorB.GetCount() == 1); + BOOST_CHECK(testFunctorC.GetCount() == 1); + + // Re-register an existing key with a new function + registry.RegisterFunctor(&testFunctorC, testFunctorA.GetPacketId(), version); + registry.GetFunctor(packetA.GetPacketId(), version)->operator()(packetC); + BOOST_CHECK(testFunctorA.GetCount() == 1); + BOOST_CHECK(testFunctorB.GetCount() == 1); + BOOST_CHECK(testFunctorC.GetCount() == 2); + + // Check that non-existent key returns nullptr for its functor + BOOST_CHECK_THROW(registry.GetFunctor(0, 0), armnn::Exception); +} + BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1