diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/backends/BackendRegistry.cpp | 24 | ||||
-rw-r--r-- | src/backends/BackendRegistry.hpp | 13 | ||||
-rw-r--r-- | src/backends/cl/ClBackend.cpp | 4 | ||||
-rw-r--r-- | src/backends/cl/ClBackend.hpp | 2 | ||||
-rw-r--r-- | src/backends/neon/NeonBackend.cpp | 4 | ||||
-rw-r--r-- | src/backends/neon/NeonBackend.hpp | 2 | ||||
-rw-r--r-- | src/backends/reference/RefBackend.cpp | 4 | ||||
-rw-r--r-- | src/backends/reference/RefBackend.hpp | 2 | ||||
-rw-r--r-- | src/backends/test/BackendIdTests.cpp | 27 |
9 files changed, 60 insertions, 22 deletions
diff --git a/src/backends/BackendRegistry.cpp b/src/backends/BackendRegistry.cpp index 68336c45b9..a5e9f0e1d9 100644 --- a/src/backends/BackendRegistry.cpp +++ b/src/backends/BackendRegistry.cpp @@ -15,22 +15,22 @@ BackendRegistry& BackendRegistry::Instance() return instance; } -void BackendRegistry::Register(const std::string& name, FactoryFunction factory) +void BackendRegistry::Register(const BackendId& id, FactoryFunction factory) { - if (m_BackendFactories.count(name) > 0) + if (m_BackendFactories.count(id) > 0) { - throw InvalidArgumentException(name + " already registered as backend"); + throw InvalidArgumentException(std::string(id) + " already registered as backend"); } - m_BackendFactories[name] = factory; + m_BackendFactories[id] = factory; } -BackendRegistry::FactoryFunction BackendRegistry::GetFactory(const std::string& name) const +BackendRegistry::FactoryFunction BackendRegistry::GetFactory(const BackendId& id) const { - auto it = m_BackendFactories.find(name); + auto it = m_BackendFactories.find(id); if (it == m_BackendFactories.end()) { - throw InvalidArgumentException(name + " has no backend factory registered"); + throw InvalidArgumentException(std::string(id) + " has no backend factory registered"); } return it->second; @@ -42,4 +42,14 @@ void BackendRegistry::Swap(BackendRegistry::FactoryStorage& other) std::swap(instance.m_BackendFactories, other); } +BackendIdSet BackendRegistry::GetBackendIds() const +{ + BackendIdSet result; + for (const auto& it : m_BackendFactories) + { + result.insert(it.first); + } + return result; } + +} // namespace armnn diff --git a/src/backends/BackendRegistry.hpp b/src/backends/BackendRegistry.hpp index ff01d21715..e2c526d293 100644 --- a/src/backends/BackendRegistry.hpp +++ b/src/backends/BackendRegistry.hpp @@ -5,7 +5,6 @@ #pragma once #include <armnn/Types.hpp> -#include <string> #include <functional> #include <memory> #include <unordered_map> @@ -21,21 +20,23 @@ public: using FactoryFunction = std::function<IBackendUniquePtr()>; static BackendRegistry& Instance(); - void Register(const std::string& name, FactoryFunction factory); - FactoryFunction GetFactory(const std::string& name) const; + + void Register(const BackendId& id, FactoryFunction factory); + FactoryFunction GetFactory(const BackendId& id) const; struct Helper { - Helper(const std::string& name, FactoryFunction factory) + Helper(const BackendId& id, FactoryFunction factory) { - BackendRegistry::Instance().Register(name, factory); + BackendRegistry::Instance().Register(id, factory); } }; size_t Size() const { return m_BackendFactories.size(); } + BackendIdSet GetBackendIds() const; protected: - using FactoryStorage = std::unordered_map<std::string, FactoryFunction>; + using FactoryStorage = std::unordered_map<BackendId, FactoryFunction>; // For testing only static void Swap(FactoryStorage& other); diff --git a/src/backends/cl/ClBackend.cpp b/src/backends/cl/ClBackend.cpp index d185c15b72..840da8bda3 100644 --- a/src/backends/cl/ClBackend.cpp +++ b/src/backends/cl/ClBackend.cpp @@ -12,7 +12,7 @@ namespace armnn namespace { -static const std::string s_Id = "GpuAcc"; +static const BackendId s_Id{"GpuAcc"}; static BackendRegistry::Helper g_RegisterHelper{ s_Id, @@ -24,7 +24,7 @@ static BackendRegistry::Helper g_RegisterHelper{ } -const std::string& ClBackend::GetId() const +const BackendId& ClBackend::GetId() const { return s_Id; } diff --git a/src/backends/cl/ClBackend.hpp b/src/backends/cl/ClBackend.hpp index c43b6a6ce0..b927db4b25 100644 --- a/src/backends/cl/ClBackend.hpp +++ b/src/backends/cl/ClBackend.hpp @@ -16,7 +16,7 @@ public: ClBackend() = default; ~ClBackend() = default; - const std::string& GetId() const override; + const BackendId& GetId() const override; const ILayerSupport& GetLayerSupport() const override; diff --git a/src/backends/neon/NeonBackend.cpp b/src/backends/neon/NeonBackend.cpp index b4f1897704..e35bf7ea45 100644 --- a/src/backends/neon/NeonBackend.cpp +++ b/src/backends/neon/NeonBackend.cpp @@ -13,7 +13,7 @@ namespace armnn namespace { -static const std::string s_Id = "CpuAcc"; +static const BackendId s_Id{"CpuAcc"}; static BackendRegistry::Helper g_RegisterHelper{ s_Id, @@ -25,7 +25,7 @@ static BackendRegistry::Helper g_RegisterHelper{ } -const std::string& NeonBackend::GetId() const +const BackendId& NeonBackend::GetId() const { return s_Id; } diff --git a/src/backends/neon/NeonBackend.hpp b/src/backends/neon/NeonBackend.hpp index 5d4bd5dfcc..fa2cad13ee 100644 --- a/src/backends/neon/NeonBackend.hpp +++ b/src/backends/neon/NeonBackend.hpp @@ -16,7 +16,7 @@ public: NeonBackend() = default; ~NeonBackend() = default; - const std::string& GetId() const override; + const BackendId& GetId() const override; const ILayerSupport& GetLayerSupport() const override; diff --git a/src/backends/reference/RefBackend.cpp b/src/backends/reference/RefBackend.cpp index b671e8bca8..63eff52b75 100644 --- a/src/backends/reference/RefBackend.cpp +++ b/src/backends/reference/RefBackend.cpp @@ -12,7 +12,7 @@ namespace armnn namespace { -const std::string s_Id = "CpuRef"; +const BackendId s_Id{"CpuRef"}; static BackendRegistry::Helper s_RegisterHelper{ s_Id, @@ -24,7 +24,7 @@ static BackendRegistry::Helper s_RegisterHelper{ } -const std::string& RefBackend::GetId() const +const BackendId& RefBackend::GetId() const { return s_Id; } diff --git a/src/backends/reference/RefBackend.hpp b/src/backends/reference/RefBackend.hpp index e4a11f10c9..dcc974167d 100644 --- a/src/backends/reference/RefBackend.hpp +++ b/src/backends/reference/RefBackend.hpp @@ -16,7 +16,7 @@ public: RefBackend() = default; ~RefBackend() = default; - const std::string& GetId() const override; + const BackendId& GetId() const override; const ILayerSupport& GetLayerSupport() const override; diff --git a/src/backends/test/BackendIdTests.cpp b/src/backends/test/BackendIdTests.cpp new file mode 100644 index 0000000000..0ef0a20d7f --- /dev/null +++ b/src/backends/test/BackendIdTests.cpp @@ -0,0 +1,27 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#include <boost/test/unit_test.hpp> + +#include <armnn/BackendId.hpp> +#include <armnn/Types.hpp> + +using namespace armnn; + +BOOST_AUTO_TEST_SUITE(BackendIdTests) + +BOOST_AUTO_TEST_CASE(CreateBackendIdFromCompute) +{ + BackendId fromCompute{Compute::GpuAcc}; + BOOST_TEST(fromCompute.Get() == GetComputeDeviceAsCString(Compute::GpuAcc)); +} + +BOOST_AUTO_TEST_CASE(CreateBackendIdVectorFromCompute) +{ + std::vector<BackendId> fromComputes = {Compute::GpuAcc, Compute::CpuRef}; + BOOST_TEST(fromComputes[0].Get() == GetComputeDeviceAsCString(Compute::GpuAcc)); + BOOST_TEST(fromComputes[1].Get() == GetComputeDeviceAsCString(Compute::CpuRef)); +} + +BOOST_AUTO_TEST_SUITE_END() |