diff options
Diffstat (limited to 'src/backends/backendsCommon')
-rw-r--r-- | src/backends/backendsCommon/BackendRegistry.cpp | 69 | ||||
-rw-r--r-- | src/backends/backendsCommon/BackendRegistry.hpp | 50 | ||||
-rw-r--r-- | src/backends/backendsCommon/CMakeLists.txt | 3 | ||||
-rw-r--r-- | src/backends/backendsCommon/IBackendInternal.hpp | 3 | ||||
-rw-r--r-- | src/backends/backendsCommon/LayerSupportRegistry.cpp | 17 | ||||
-rw-r--r-- | src/backends/backendsCommon/LayerSupportRegistry.hpp | 23 | ||||
-rw-r--r-- | src/backends/backendsCommon/RegistryCommon.hpp | 120 | ||||
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.cpp | 22 | ||||
-rw-r--r-- | src/backends/backendsCommon/common.mk | 1 | ||||
-rw-r--r-- | src/backends/backendsCommon/test/BackendRegistryTests.cpp | 2 |
10 files changed, 134 insertions, 176 deletions
diff --git a/src/backends/backendsCommon/BackendRegistry.cpp b/src/backends/backendsCommon/BackendRegistry.cpp index e9361210f2..80ab01ce1b 100644 --- a/src/backends/backendsCommon/BackendRegistry.cpp +++ b/src/backends/backendsCommon/BackendRegistry.cpp @@ -4,6 +4,7 @@ // #include "BackendRegistry.hpp" +#include <armnn/Exceptions.hpp> namespace armnn { @@ -14,4 +15,72 @@ BackendRegistry& BackendRegistryInstance() return instance; } +void BackendRegistry::Register(const BackendId& id, BackendRegistry::FactoryFunction factory) +{ + if (m_Factories.count(id) > 0) + { + throw InvalidArgumentException( + std::string(id) + " already registered as IBackend factory", + CHECK_LOCATION()); + } + + m_Factories[id] = factory; +} + +bool BackendRegistry::IsBackendRegistered(const BackendId& id) const +{ + return (m_Factories.find(id) != m_Factories.end()); +} + +BackendRegistry::FactoryFunction BackendRegistry::GetFactory(const BackendId& id) const +{ + auto it = m_Factories.find(id); + if (it == m_Factories.end()) + { + throw InvalidArgumentException( + std::string(id) + " has no IBackend factory registered", + CHECK_LOCATION()); + } + + return it->second; +} + +size_t BackendRegistry::Size() const +{ + return m_Factories.size(); +} + +BackendIdSet BackendRegistry::GetBackendIds() const +{ + BackendIdSet result; + for (const auto& it : m_Factories) + { + result.insert(it.first); + } + return result; +} + +std::string BackendRegistry::GetBackendIdsAsString() const +{ + static const std::string delimitator = ", "; + + std::stringstream output; + for (auto& backendId : GetBackendIds()) + { + if (output.tellp() != std::streampos(0)) + { + output << delimitator; + } + output << backendId; + } + + return output.str(); +} + +void BackendRegistry::Swap(BackendRegistry& instance, BackendRegistry::FactoryStorage& other) +{ + std::swap(instance.m_Factories, other); +} + + } // namespace armnn diff --git a/src/backends/backendsCommon/BackendRegistry.hpp b/src/backends/backendsCommon/BackendRegistry.hpp index 4b20cacbe0..2a52e24238 100644 --- a/src/backends/backendsCommon/BackendRegistry.hpp +++ b/src/backends/backendsCommon/BackendRegistry.hpp @@ -4,21 +4,57 @@ // #pragma once -#include "RegistryCommon.hpp" #include <armnn/Types.hpp> +#include <armnn/BackendId.hpp> + +#include <memory> +#include <unordered_map> namespace armnn { + class IBackendInternal; using IBackendInternalUniquePtr = std::unique_ptr<IBackendInternal>; -using BackendRegistry = RegistryCommon<IBackendInternal, IBackendInternalUniquePtr>; - -BackendRegistry& BackendRegistryInstance(); -template <> -struct RegisteredTypeName<IBackend> +class BackendRegistry { - static const char * Name() { return "IBackend"; } +public: + using PointerType = IBackendInternalUniquePtr; + using FactoryFunction = std::function<PointerType()>; + + void Register(const BackendId& id, FactoryFunction factory); + bool IsBackendRegistered(const BackendId& id) const; + FactoryFunction GetFactory(const BackendId& id) const; + size_t Size() const; + BackendIdSet GetBackendIds() const; + std::string GetBackendIdsAsString() const; + + BackendRegistry() {} + virtual ~BackendRegistry() {} + + struct StaticRegistryInitializer + { + StaticRegistryInitializer(BackendRegistry& instance, + const BackendId& id, + FactoryFunction factory) + { + instance.Register(id, factory); + } + }; + +protected: + using FactoryStorage = std::unordered_map<BackendId, FactoryFunction>; + + // For testing only + static void Swap(BackendRegistry& instance, FactoryStorage& other); + +private: + BackendRegistry(const BackendRegistry&) = delete; + BackendRegistry& operator=(const BackendRegistry&) = delete; + + FactoryStorage m_Factories; }; +BackendRegistry& BackendRegistryInstance(); + } // namespace armnn
\ No newline at end of file diff --git a/src/backends/backendsCommon/CMakeLists.txt b/src/backends/backendsCommon/CMakeLists.txt index f4ab45f8b4..e6ac01c0ac 100644 --- a/src/backends/backendsCommon/CMakeLists.txt +++ b/src/backends/backendsCommon/CMakeLists.txt @@ -13,14 +13,11 @@ list(APPEND armnnBackendsCommon_sources IBackendContext.hpp ILayerSupport.cpp ITensorHandle.hpp - LayerSupportRegistry.cpp - LayerSupportRegistry.hpp MakeWorkloadHelper.hpp MemCopyWorkload.cpp MemCopyWorkload.hpp OutputHandler.cpp OutputHandler.hpp - RegistryCommon.hpp StringMapping.cpp StringMapping.hpp WorkloadDataCollector.hpp diff --git a/src/backends/backendsCommon/IBackendInternal.hpp b/src/backends/backendsCommon/IBackendInternal.hpp index 9c54b821e7..9d649fcfe2 100644 --- a/src/backends/backendsCommon/IBackendInternal.hpp +++ b/src/backends/backendsCommon/IBackendInternal.hpp @@ -13,6 +13,7 @@ namespace armnn class IWorkloadFactory; class IBackendContext; class Optimization; +class ILayerSupport; class IBackendInternal : public IBackend { @@ -30,10 +31,12 @@ public: using IBackendContextPtr = std::unique_ptr<IBackendContext>; using OptimizationPtr = std::unique_ptr<Optimization>; using Optimizations = std::vector<OptimizationPtr>; + using ILayerSupportSharedPtr = std::shared_ptr<ILayerSupport>; virtual IWorkloadFactoryPtr CreateWorkloadFactory() const = 0; virtual IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const = 0; virtual Optimizations GetOptimizations() const = 0; + virtual ILayerSupportSharedPtr GetLayerSupport() const = 0; }; using IBackendInternalUniquePtr = std::unique_ptr<IBackendInternal>; diff --git a/src/backends/backendsCommon/LayerSupportRegistry.cpp b/src/backends/backendsCommon/LayerSupportRegistry.cpp deleted file mode 100644 index 63b4da7337..0000000000 --- a/src/backends/backendsCommon/LayerSupportRegistry.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "LayerSupportRegistry.hpp" - -namespace armnn -{ - -LayerSupportRegistry& LayerSupportRegistryInstance() -{ - static LayerSupportRegistry instance; - return instance; -} - -} // namespace armnn diff --git a/src/backends/backendsCommon/LayerSupportRegistry.hpp b/src/backends/backendsCommon/LayerSupportRegistry.hpp deleted file mode 100644 index a5efad05ef..0000000000 --- a/src/backends/backendsCommon/LayerSupportRegistry.hpp +++ /dev/null @@ -1,23 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// -#pragma once - -#include "RegistryCommon.hpp" -#include <armnn/ILayerSupport.hpp> - -namespace armnn -{ - -using LayerSupportRegistry = RegistryCommon<ILayerSupport, ILayerSupportSharedPtr>; - -LayerSupportRegistry& LayerSupportRegistryInstance(); - -template <> -struct RegisteredTypeName<ILayerSupport> -{ - static const char * Name() { return "ILayerSupport"; } -}; - -} // namespace armnn diff --git a/src/backends/backendsCommon/RegistryCommon.hpp b/src/backends/backendsCommon/RegistryCommon.hpp deleted file mode 100644 index 03bd338090..0000000000 --- a/src/backends/backendsCommon/RegistryCommon.hpp +++ /dev/null @@ -1,120 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// -#pragma once - -#include <armnn/BackendId.hpp> -#include <armnn/Exceptions.hpp> - -#include <functional> -#include <memory> -#include <sstream> -#include <string> -#include <unordered_map> - -namespace armnn -{ - -template <typename RegisteredType> -struct RegisteredTypeName -{ - static const char * Name() { return "UNKNOWN"; } -}; - -template <typename RegisteredType, typename PointerType> -class RegistryCommon -{ -public: - using FactoryFunction = std::function<PointerType()>; - - void Register(const BackendId& id, FactoryFunction factory) - { - if (m_Factories.count(id) > 0) - { - throw InvalidArgumentException( - std::string(id) + " already registered as " + RegisteredTypeName<RegisteredType>::Name() + " factory", - CHECK_LOCATION()); - } - - m_Factories[id] = factory; - } - - FactoryFunction GetFactory(const BackendId& id) const - { - auto it = m_Factories.find(id); - if (it == m_Factories.end()) - { - throw InvalidArgumentException( - std::string(id) + " has no " + RegisteredTypeName<RegisteredType>::Name() + " factory registered", - CHECK_LOCATION()); - } - - return it->second; - } - - size_t Size() const - { - return m_Factories.size(); - } - - BackendIdSet GetBackendIds() const - { - BackendIdSet result; - for (const auto& it : m_Factories) - { - result.insert(it.first); - } - return result; - } - - std::string GetBackendIdsAsString() const - { - static const std::string delimitator = ", "; - - std::stringstream output; - for (auto& backendId : GetBackendIds()) - { - if (output.tellp() != std::streampos(0)) - { - output << delimitator; - } - output << backendId; - } - - return output.str(); - } - - RegistryCommon() {} - virtual ~RegistryCommon() {} - -protected: - using FactoryStorage = std::unordered_map<BackendId, FactoryFunction>; - - // For testing only - static void Swap(RegistryCommon& instance, FactoryStorage& other) - { - std::swap(instance.m_Factories, other); - } - -private: - RegistryCommon(const RegistryCommon&) = delete; - RegistryCommon& operator=(const RegistryCommon&) = delete; - - FactoryStorage m_Factories; -}; - -template <typename RegistryType> -struct StaticRegistryInitializer -{ - using FactoryFunction = typename RegistryType::FactoryFunction; - - StaticRegistryInitializer(RegistryType& instance, - const BackendId& id, - FactoryFunction factory) - { - instance.Register(id, factory); - } -}; - -} // namespace armnn diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index ec30f34880..bb63b336e9 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -10,14 +10,17 @@ #include <armnn/Types.hpp> #include <armnn/LayerSupport.hpp> +#include <armnn/ILayerSupport.hpp> -#include <backendsCommon/LayerSupportRegistry.hpp> +#include <backendsCommon/BackendRegistry.hpp> #include <backendsCommon/WorkloadFactory.hpp> +#include <backendsCommon/IBackendInternal.hpp> #include <boost/cast.hpp> #include <boost/iterator/transform_iterator.hpp> #include <cstring> +#include <sstream> namespace armnn { @@ -66,9 +69,20 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, bool result; const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer)); - auto const& layerSupportRegistry = LayerSupportRegistryInstance(); - auto layerSupportFactory = layerSupportRegistry.GetFactory(backendId); - auto layerSupportObject = layerSupportFactory(); + auto const& backendRegistry = BackendRegistryInstance(); + if (!backendRegistry.IsBackendRegistered(backendId)) + { + std::stringstream ss; + ss << connectableLayer.GetName() << " is not supported on " << backendId + << " because this backend is not registered."; + + outReasonIfUnsupported = ss.str(); + return false; + } + + auto backendFactory = backendRegistry.GetFactory(backendId); + auto backendObject = backendFactory(); + auto layerSupportObject = backendObject->GetLayerSupport(); switch(layer.GetType()) { diff --git a/src/backends/backendsCommon/common.mk b/src/backends/backendsCommon/common.mk index b1583b987e..8d29316599 100644 --- a/src/backends/backendsCommon/common.mk +++ b/src/backends/backendsCommon/common.mk @@ -12,7 +12,6 @@ COMMON_SOURCES := \ CpuTensorHandle.cpp \ ILayerSupport.cpp \ MemCopyWorkload.cpp \ - LayerSupportRegistry.cpp \ OutputHandler.cpp \ StringMapping.cpp \ WorkloadData.cpp \ diff --git a/src/backends/backendsCommon/test/BackendRegistryTests.cpp b/src/backends/backendsCommon/test/BackendRegistryTests.cpp index 26175e015f..283caafaf9 100644 --- a/src/backends/backendsCommon/test/BackendRegistryTests.cpp +++ b/src/backends/backendsCommon/test/BackendRegistryTests.cpp @@ -52,7 +52,7 @@ BOOST_AUTO_TEST_CASE(TestRegistryHelper) bool called = false; - StaticRegistryInitializer<BackendRegistry> factoryHelper( + BackendRegistry::StaticRegistryInitializer factoryHelper( BackendRegistryInstance(), "HelloWorld", [&called]() |