aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/RegistryCommon.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/RegistryCommon.hpp')
-rw-r--r--src/backends/backendsCommon/RegistryCommon.hpp134
1 files changed, 134 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/RegistryCommon.hpp b/src/backends/backendsCommon/RegistryCommon.hpp
new file mode 100644
index 0000000000..3dbfad2a66
--- /dev/null
+++ b/src/backends/backendsCommon/RegistryCommon.hpp
@@ -0,0 +1,134 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include <armnn/BackendId.hpp>
+#include <armnn/Exceptions.hpp>
+
+#include <functional>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+
+namespace armnn
+{
+
+template <typename RegisteredType>
+struct RegisteredTypeName
+{
+ static const char * Name() { return "UNKNOWN"; }
+};
+
+template <typename RegisteredType, typename PointerType, typename ParamType>
+class RegistryCommon
+{
+public:
+ using FactoryFunction = std::function<PointerType(const ParamType&)>;
+
+ void Register(const BackendId& id, FactoryFunction factory)
+ {
+ if (m_Factories.count(id) > 0)
+ {
+ throw InvalidArgumentException(
+ std::string(id) + " already registered as " + RegisteredTypeName<RegisteredType>::Name() + " factory",
+ CHECK_LOCATION());
+ }
+
+ m_Factories[id] = factory;
+ }
+
+ FactoryFunction GetFactory(const BackendId& id) const
+ {
+ auto it = m_Factories.find(id);
+ if (it == m_Factories.end())
+ {
+ throw InvalidArgumentException(
+ std::string(id) + " has no " + RegisteredTypeName<RegisteredType>::Name() + " factory registered",
+ CHECK_LOCATION());
+ }
+
+ return it->second;
+ }
+
+ FactoryFunction GetFactory(const BackendId& id,
+ FactoryFunction defaultFactory) const
+ {
+ auto it = m_Factories.find(id);
+ if (it == m_Factories.end())
+ {
+ return defaultFactory;
+ }
+ else
+ {
+ return it->second;
+ }
+ }
+
+ size_t Size() const
+ {
+ return m_Factories.size();
+ }
+
+ BackendIdSet GetBackendIds() const
+ {
+ BackendIdSet result;
+ for (const auto& it : m_Factories)
+ {
+ result.insert(it.first);
+ }
+ return result;
+ }
+
+ std::string GetBackendIdsAsString() const
+ {
+ static const std::string delimitator = ", ";
+
+ std::stringstream output;
+ for (auto& backendId : GetBackendIds())
+ {
+ if (output.tellp() != std::streampos(0))
+ {
+ output << delimitator;
+ }
+ output << backendId;
+ }
+
+ return output.str();
+ }
+
+ RegistryCommon() {}
+ virtual ~RegistryCommon() {}
+
+protected:
+ using FactoryStorage = std::unordered_map<BackendId, FactoryFunction>;
+
+ // For testing only
+ static void Swap(RegistryCommon& instance, FactoryStorage& other)
+ {
+ std::swap(instance.m_Factories, other);
+ }
+
+private:
+ RegistryCommon(const RegistryCommon&) = delete;
+ RegistryCommon& operator=(const RegistryCommon&) = delete;
+
+ FactoryStorage m_Factories;
+};
+
+template <typename RegistryType>
+struct StaticRegistryInitializer
+{
+ using FactoryFunction = typename RegistryType::FactoryFunction;
+
+ StaticRegistryInitializer(RegistryType& instance,
+ const BackendId& id,
+ FactoryFunction factory)
+ {
+ instance.Register(id, factory);
+ }
+};
+
+} // namespace armnn \ No newline at end of file