From 32cbb0c7cd99786191c080f5a619b3dab23b4cd0 Mon Sep 17 00:00:00 2001 From: David Beck Date: Tue, 9 Oct 2018 15:46:08 +0100 Subject: IVGCVSW-1987 : registry for backend creation functions (factories) Change-Id: I13d2d3dc763e1d05dffddb34472bd4f9e632c776 --- CMakeLists.txt | 1 + include/armnn/Types.hpp | 3 +- src/armnn/DeviceSpec.hpp | 4 +- src/backends/BackendRegistry.cpp | 45 +++++++++++++++ src/backends/BackendRegistry.hpp | 52 +++++++++++++++++ src/backends/CMakeLists.txt | 2 + src/backends/common.mk | 1 + src/backends/test/BackendRegistryTests.cpp | 91 ++++++++++++++++++++++++++++++ 8 files changed, 196 insertions(+), 3 deletions(-) create mode 100644 src/backends/BackendRegistry.cpp create mode 100644 src/backends/BackendRegistry.hpp create mode 100644 src/backends/test/BackendRegistryTests.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 669c92fd3f..501da806ad 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -378,6 +378,7 @@ if(BUILD_UNIT_TESTS) src/armnn/test/InstrumentTests.cpp src/armnn/test/ObservableTest.cpp src/armnn/test/OptionalTest.cpp + src/backends/test/BackendRegistryTests.cpp src/backends/test/IsLayerSupportedTestImpl.hpp src/backends/test/WorkloadDataValidation.cpp src/backends/test/TensorCopyUtils.hpp diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp index 12ecda0c39..b7ee9472a3 100644 --- a/include/armnn/Types.hpp +++ b/include/armnn/Types.hpp @@ -117,7 +117,8 @@ public: virtual const ILayerSupport& GetLayerSupport() const = 0; }; -using IBackendPtr = std::shared_ptr; +using IBackendSharedPtr = std::shared_ptr; +using IBackendUniquePtr = std::unique_ptr; /// Device specific knowledge to be passed to the optimizer. class IDeviceSpec diff --git a/src/armnn/DeviceSpec.hpp b/src/armnn/DeviceSpec.hpp index dbc04f0af6..34acbcbdec 100644 --- a/src/armnn/DeviceSpec.hpp +++ b/src/armnn/DeviceSpec.hpp @@ -17,9 +17,9 @@ public: DeviceSpec() {} virtual ~DeviceSpec() {} - virtual std::vector GetBackends() const + virtual std::vector GetBackends() const { - return std::vector(); + return std::vector(); } std::set m_SupportedComputeDevices; diff --git a/src/backends/BackendRegistry.cpp b/src/backends/BackendRegistry.cpp new file mode 100644 index 0000000000..68336c45b9 --- /dev/null +++ b/src/backends/BackendRegistry.cpp @@ -0,0 +1,45 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "BackendRegistry.hpp" +#include + +namespace armnn +{ + +BackendRegistry& BackendRegistry::Instance() +{ + static BackendRegistry instance; + return instance; +} + +void BackendRegistry::Register(const std::string& name, FactoryFunction factory) +{ + if (m_BackendFactories.count(name) > 0) + { + throw InvalidArgumentException(name + " already registered as backend"); + } + + m_BackendFactories[name] = factory; +} + +BackendRegistry::FactoryFunction BackendRegistry::GetFactory(const std::string& name) const +{ + auto it = m_BackendFactories.find(name); + if (it == m_BackendFactories.end()) + { + throw InvalidArgumentException(name + " has no backend factory registered"); + } + + return it->second; +} + +void BackendRegistry::Swap(BackendRegistry::FactoryStorage& other) +{ + BackendRegistry& instance = Instance(); + std::swap(instance.m_BackendFactories, other); +} + +} diff --git a/src/backends/BackendRegistry.hpp b/src/backends/BackendRegistry.hpp new file mode 100644 index 0000000000..ff01d21715 --- /dev/null +++ b/src/backends/BackendRegistry.hpp @@ -0,0 +1,52 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include +#include +#include +#include +#include + +namespace armnn +{ + +class IBackend; + +class BackendRegistry +{ +public: + using FactoryFunction = std::function; + + static BackendRegistry& Instance(); + void Register(const std::string& name, FactoryFunction factory); + FactoryFunction GetFactory(const std::string& name) const; + + struct Helper + { + Helper(const std::string& name, FactoryFunction factory) + { + BackendRegistry::Instance().Register(name, factory); + } + }; + + size_t Size() const { return m_BackendFactories.size(); } + +protected: + using FactoryStorage = std::unordered_map; + + // For testing only + static void Swap(FactoryStorage& other); + BackendRegistry() {} + ~BackendRegistry() {} + +private: + BackendRegistry(const BackendRegistry&) = delete; + BackendRegistry& operator=(const BackendRegistry&) = delete; + + FactoryStorage m_BackendFactories; +}; + +} // namespace armnn diff --git a/src/backends/CMakeLists.txt b/src/backends/CMakeLists.txt index ea5ad7814c..0bc6888899 100644 --- a/src/backends/CMakeLists.txt +++ b/src/backends/CMakeLists.txt @@ -4,6 +4,8 @@ # list(APPEND armnnBackendsCommon_sources + BackendRegistry.cpp + BackendRegistry.hpp CpuTensorHandle.cpp CpuTensorHandleFwd.hpp CpuTensorHandle.hpp diff --git a/src/backends/common.mk b/src/backends/common.mk index e65281a1a4..99c6e12a3e 100644 --- a/src/backends/common.mk +++ b/src/backends/common.mk @@ -8,6 +8,7 @@ # file in the root of ArmNN COMMON_SOURCES := \ + BackendRegistry.cpp \ CpuTensorHandle.cpp \ MemCopyWorkload.cpp \ OutputHandler.cpp \ diff --git a/src/backends/test/BackendRegistryTests.cpp b/src/backends/test/BackendRegistryTests.cpp new file mode 100644 index 0000000000..e895df63a6 --- /dev/null +++ b/src/backends/test/BackendRegistryTests.cpp @@ -0,0 +1,91 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#include + +#include +#include + +namespace +{ + +class SwapRegistryStorage : public armnn::BackendRegistry +{ +public: + SwapRegistryStorage() : armnn::BackendRegistry() + { + Swap(m_TempStorage); + } + + ~SwapRegistryStorage() + { + Swap(m_TempStorage); + } + +private: + BackendRegistry::FactoryStorage m_TempStorage; +}; + +} + +BOOST_AUTO_TEST_SUITE(BackendRegistryTests) + +BOOST_AUTO_TEST_CASE(SwapRegistry) +{ + using armnn::BackendRegistry; + auto nFactories = BackendRegistry::Instance().Size(); + { + SwapRegistryStorage helper; + BOOST_TEST(BackendRegistry::Instance().Size() == 0); + } + BOOST_TEST(BackendRegistry::Instance().Size() == nFactories); +} + +BOOST_AUTO_TEST_CASE(TestRegistryHelper) +{ + using armnn::BackendRegistry; + SwapRegistryStorage helper; + + bool called = false; + BackendRegistry::Helper factoryHelper("HelloWorld", [&called]() { + called = true; + return armnn::IBackendUniquePtr(nullptr, nullptr); + } ); + + // sanity check: the factory has not been called yet + BOOST_TEST(called == false); + + auto factoryFunction = BackendRegistry::Instance().GetFactory("HelloWorld"); + + // sanity check: the factory still not called + BOOST_TEST(called == false); + + factoryFunction(); + BOOST_TEST(called == true); +} + +BOOST_AUTO_TEST_CASE(TestDirectCallToRegistry) +{ + using armnn::BackendRegistry; + SwapRegistryStorage helper; + + bool called = false; + BackendRegistry::Instance().Register("HelloWorld", [&called]() { + called = true; + return armnn::IBackendUniquePtr(nullptr, nullptr); + } ); + + // sanity check: the factory has not been called yet + BOOST_TEST(called == false); + + auto factoryFunction = BackendRegistry::Instance().GetFactory("HelloWorld"); + + // sanity check: the factory still not called + BOOST_TEST(called == false); + + factoryFunction(); + BOOST_TEST(called == true); +} + +BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1