aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Beck <david.beck@arm.com>2018-10-09 15:46:08 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-10-22 16:57:53 +0100
commit32cbb0c7cd99786191c080f5a619b3dab23b4cd0 (patch)
tree0e7db143cdf3faa6430a1ffa5e948165ff60c045
parent43095f31edf103d71a8e2420b549d21fd349b49e (diff)
downloadarmnn-32cbb0c7cd99786191c080f5a619b3dab23b4cd0.tar.gz
IVGCVSW-1987 : registry for backend creation functions (factories)
Change-Id: I13d2d3dc763e1d05dffddb34472bd4f9e632c776
-rw-r--r--CMakeLists.txt1
-rw-r--r--include/armnn/Types.hpp3
-rw-r--r--src/armnn/DeviceSpec.hpp4
-rw-r--r--src/backends/BackendRegistry.cpp45
-rw-r--r--src/backends/BackendRegistry.hpp52
-rw-r--r--src/backends/CMakeLists.txt2
-rw-r--r--src/backends/common.mk1
-rw-r--r--src/backends/test/BackendRegistryTests.cpp91
8 files changed, 196 insertions, 3 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 669c92fd3f..501da806ad 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -378,6 +378,7 @@ if(BUILD_UNIT_TESTS)
src/armnn/test/InstrumentTests.cpp
src/armnn/test/ObservableTest.cpp
src/armnn/test/OptionalTest.cpp
+ src/backends/test/BackendRegistryTests.cpp
src/backends/test/IsLayerSupportedTestImpl.hpp
src/backends/test/WorkloadDataValidation.cpp
src/backends/test/TensorCopyUtils.hpp
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index 12ecda0c39..b7ee9472a3 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -117,7 +117,8 @@ public:
virtual const ILayerSupport& GetLayerSupport() const = 0;
};
-using IBackendPtr = std::shared_ptr<IBackend>;
+using IBackendSharedPtr = std::shared_ptr<IBackend>;
+using IBackendUniquePtr = std::unique_ptr<IBackend, void(*)(IBackend* backend)>;
/// Device specific knowledge to be passed to the optimizer.
class IDeviceSpec
diff --git a/src/armnn/DeviceSpec.hpp b/src/armnn/DeviceSpec.hpp
index dbc04f0af6..34acbcbdec 100644
--- a/src/armnn/DeviceSpec.hpp
+++ b/src/armnn/DeviceSpec.hpp
@@ -17,9 +17,9 @@ public:
DeviceSpec() {}
virtual ~DeviceSpec() {}
- virtual std::vector<IBackendPtr> GetBackends() const
+ virtual std::vector<IBackendSharedPtr> GetBackends() const
{
- return std::vector<IBackendPtr>();
+ return std::vector<IBackendSharedPtr>();
}
std::set<Compute> m_SupportedComputeDevices;
diff --git a/src/backends/BackendRegistry.cpp b/src/backends/BackendRegistry.cpp
new file mode 100644
index 0000000000..68336c45b9
--- /dev/null
+++ b/src/backends/BackendRegistry.cpp
@@ -0,0 +1,45 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "BackendRegistry.hpp"
+#include <armnn/Exceptions.hpp>
+
+namespace armnn
+{
+
+BackendRegistry& BackendRegistry::Instance()
+{
+ static BackendRegistry instance;
+ return instance;
+}
+
+void BackendRegistry::Register(const std::string& name, FactoryFunction factory)
+{
+ if (m_BackendFactories.count(name) > 0)
+ {
+ throw InvalidArgumentException(name + " already registered as backend");
+ }
+
+ m_BackendFactories[name] = factory;
+}
+
+BackendRegistry::FactoryFunction BackendRegistry::GetFactory(const std::string& name) const
+{
+ auto it = m_BackendFactories.find(name);
+ if (it == m_BackendFactories.end())
+ {
+ throw InvalidArgumentException(name + " has no backend factory registered");
+ }
+
+ return it->second;
+}
+
+void BackendRegistry::Swap(BackendRegistry::FactoryStorage& other)
+{
+ BackendRegistry& instance = Instance();
+ std::swap(instance.m_BackendFactories, other);
+}
+
+}
diff --git a/src/backends/BackendRegistry.hpp b/src/backends/BackendRegistry.hpp
new file mode 100644
index 0000000000..ff01d21715
--- /dev/null
+++ b/src/backends/BackendRegistry.hpp
@@ -0,0 +1,52 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include <armnn/Types.hpp>
+#include <string>
+#include <functional>
+#include <memory>
+#include <unordered_map>
+
+namespace armnn
+{
+
+class IBackend;
+
+class BackendRegistry
+{
+public:
+ using FactoryFunction = std::function<IBackendUniquePtr()>;
+
+ static BackendRegistry& Instance();
+ void Register(const std::string& name, FactoryFunction factory);
+ FactoryFunction GetFactory(const std::string& name) const;
+
+ struct Helper
+ {
+ Helper(const std::string& name, FactoryFunction factory)
+ {
+ BackendRegistry::Instance().Register(name, factory);
+ }
+ };
+
+ size_t Size() const { return m_BackendFactories.size(); }
+
+protected:
+ using FactoryStorage = std::unordered_map<std::string, FactoryFunction>;
+
+ // For testing only
+ static void Swap(FactoryStorage& other);
+ BackendRegistry() {}
+ ~BackendRegistry() {}
+
+private:
+ BackendRegistry(const BackendRegistry&) = delete;
+ BackendRegistry& operator=(const BackendRegistry&) = delete;
+
+ FactoryStorage m_BackendFactories;
+};
+
+} // namespace armnn
diff --git a/src/backends/CMakeLists.txt b/src/backends/CMakeLists.txt
index ea5ad7814c..0bc6888899 100644
--- a/src/backends/CMakeLists.txt
+++ b/src/backends/CMakeLists.txt
@@ -4,6 +4,8 @@
#
list(APPEND armnnBackendsCommon_sources
+ BackendRegistry.cpp
+ BackendRegistry.hpp
CpuTensorHandle.cpp
CpuTensorHandleFwd.hpp
CpuTensorHandle.hpp
diff --git a/src/backends/common.mk b/src/backends/common.mk
index e65281a1a4..99c6e12a3e 100644
--- a/src/backends/common.mk
+++ b/src/backends/common.mk
@@ -8,6 +8,7 @@
# file in the root of ArmNN
COMMON_SOURCES := \
+ BackendRegistry.cpp \
CpuTensorHandle.cpp \
MemCopyWorkload.cpp \
OutputHandler.cpp \
diff --git a/src/backends/test/BackendRegistryTests.cpp b/src/backends/test/BackendRegistryTests.cpp
new file mode 100644
index 0000000000..e895df63a6
--- /dev/null
+++ b/src/backends/test/BackendRegistryTests.cpp
@@ -0,0 +1,91 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#include <boost/test/unit_test.hpp>
+
+#include <backends/BackendRegistry.hpp>
+#include <armnn/Types.hpp>
+
+namespace
+{
+
+class SwapRegistryStorage : public armnn::BackendRegistry
+{
+public:
+ SwapRegistryStorage() : armnn::BackendRegistry()
+ {
+ Swap(m_TempStorage);
+ }
+
+ ~SwapRegistryStorage()
+ {
+ Swap(m_TempStorage);
+ }
+
+private:
+ BackendRegistry::FactoryStorage m_TempStorage;
+};
+
+}
+
+BOOST_AUTO_TEST_SUITE(BackendRegistryTests)
+
+BOOST_AUTO_TEST_CASE(SwapRegistry)
+{
+ using armnn::BackendRegistry;
+ auto nFactories = BackendRegistry::Instance().Size();
+ {
+ SwapRegistryStorage helper;
+ BOOST_TEST(BackendRegistry::Instance().Size() == 0);
+ }
+ BOOST_TEST(BackendRegistry::Instance().Size() == nFactories);
+}
+
+BOOST_AUTO_TEST_CASE(TestRegistryHelper)
+{
+ using armnn::BackendRegistry;
+ SwapRegistryStorage helper;
+
+ bool called = false;
+ BackendRegistry::Helper factoryHelper("HelloWorld", [&called]() {
+ called = true;
+ return armnn::IBackendUniquePtr(nullptr, nullptr);
+ } );
+
+ // sanity check: the factory has not been called yet
+ BOOST_TEST(called == false);
+
+ auto factoryFunction = BackendRegistry::Instance().GetFactory("HelloWorld");
+
+ // sanity check: the factory still not called
+ BOOST_TEST(called == false);
+
+ factoryFunction();
+ BOOST_TEST(called == true);
+}
+
+BOOST_AUTO_TEST_CASE(TestDirectCallToRegistry)
+{
+ using armnn::BackendRegistry;
+ SwapRegistryStorage helper;
+
+ bool called = false;
+ BackendRegistry::Instance().Register("HelloWorld", [&called]() {
+ called = true;
+ return armnn::IBackendUniquePtr(nullptr, nullptr);
+ } );
+
+ // sanity check: the factory has not been called yet
+ BOOST_TEST(called == false);
+
+ auto factoryFunction = BackendRegistry::Instance().GetFactory("HelloWorld");
+
+ // sanity check: the factory still not called
+ BOOST_TEST(called == false);
+
+ factoryFunction();
+ BOOST_TEST(called == true);
+}
+
+BOOST_AUTO_TEST_SUITE_END()