diff options
author | Finn Williams <Finn.Williams@arm.com> | 2021-05-26 18:38:12 +0100 |
---|---|---|
committer | Finn Williams <Finn.Williams@arm.com> | 2021-06-09 17:18:10 +0100 |
commit | b9af86ea42568ade799ee5529137e4756977b6c6 (patch) | |
tree | 261003078fd2191b22ee7465e07668cbed666553 /src/armnn | |
parent | 5b1bcc93820b442bc4035c1e030a8d4a0983df91 (diff) | |
download | armnn-b9af86ea42568ade799ee5529137e4756977b6c6.tar.gz |
IVGCVSW-5855 Refactor the reporting of capabilities from backends
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: I05fc331a8e91bdcb6b8a2f32cfb555060fc5d797
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/BackendHelper.cpp | 93 | ||||
-rw-r--r-- | src/armnn/test/OptimizerTests.cpp | 8 | ||||
-rw-r--r-- | src/armnn/test/UtilsTests.cpp | 12 |
3 files changed, 95 insertions, 18 deletions
diff --git a/src/armnn/BackendHelper.cpp b/src/armnn/BackendHelper.cpp index 31dfaa53a3..be21412e97 100644 --- a/src/armnn/BackendHelper.cpp +++ b/src/armnn/BackendHelper.cpp @@ -26,6 +26,89 @@ LayerSupportHandle GetILayerSupportByBackendId(const armnn::BackendId& backend) return LayerSupportHandle(backendObject->GetLayerSupport(), backend); } +Optional<const BackendOptions::BackendOption> GetCapability(const std::string& backendCapabilityName, + const BackendCapabilities& capabilities) +{ + for (size_t i=0; i < capabilities.GetOptionCount(); i++) + { + const auto& capability = capabilities.GetOption(i); + if (backendCapabilityName == capability.GetName()) + { + return capability; + } + } + return EmptyOptional(); +} + +Optional<const BackendOptions::BackendOption> GetCapability(const std::string& backendCapabilityName, + const armnn::BackendId& backend) +{ + auto const& backendRegistry = armnn::BackendRegistryInstance(); + if (backendRegistry.IsBackendRegistered(backend)) + { + auto factoryFunc = backendRegistry.GetFactory(backend); + auto backendObject = factoryFunc(); + auto capabilities = backendObject->GetCapabilities(); + return GetCapability(backendCapabilityName, capabilities); + } + return EmptyOptional(); +} + +bool HasCapability(const std::string& name, const BackendCapabilities& capabilities) +{ + return GetCapability(name, capabilities).has_value(); +} + +bool HasCapability(const std::string& name, const armnn::BackendId& backend) +{ + return GetCapability(name, backend).has_value(); +} + +bool HasCapability(const BackendOptions::BackendOption& capability, const BackendCapabilities& capabilities) +{ + for (size_t i=0; i < capabilities.GetOptionCount(); i++) + { + const auto& backendCapability = capabilities.GetOption(i); + if (capability.GetName() == backendCapability.GetName()) + { + if (capability.GetValue().IsBool() && backendCapability.GetValue().IsBool()) + { + return capability.GetValue().AsBool() == backendCapability.GetValue().AsBool(); + } + else if(capability.GetValue().IsFloat() && backendCapability.GetValue().IsFloat()) + { + return capability.GetValue().AsFloat() == backendCapability.GetValue().AsFloat(); + } + else if(capability.GetValue().IsInt() && backendCapability.GetValue().IsInt()) + { + return capability.GetValue().AsInt() == backendCapability.GetValue().AsInt(); + } + else if(capability.GetValue().IsString() && backendCapability.GetValue().IsString()) + { + return capability.GetValue().AsString() == backendCapability.GetValue().AsString(); + } + else if(capability.GetValue().IsUnsignedInt() && backendCapability.GetValue().IsUnsignedInt()) + { + return capability.GetValue().AsUnsignedInt() == backendCapability.GetValue().AsUnsignedInt(); + } + } + } + return false; +} + +bool HasCapability(const BackendOptions::BackendOption& backendOption, const armnn::BackendId& backend) +{ + auto const& backendRegistry = armnn::BackendRegistryInstance(); + if (backendRegistry.IsBackendRegistered(backend)) + { + auto factoryFunc = backendRegistry.GetFactory(backend); + auto backendObject = factoryFunc(); + auto capabilities = backendObject->GetCapabilities(); + return HasCapability(backendOption, capabilities); + } + return false; +} + /// Convenience function to check a capability on a backend bool IsCapabilitySupported(const armnn::BackendId& backend, armnn::BackendCapability capability) { @@ -35,7 +118,9 @@ bool IsCapabilitySupported(const armnn::BackendId& backend, armnn::BackendCapabi { auto factoryFunc = backendRegistry.GetFactory(backend); auto backendObject = factoryFunc(); + ARMNN_NO_DEPRECATE_WARN_BEGIN hasCapability = backendObject->HasCapability(capability); + ARMNN_NO_DEPRECATE_WARN_END } return hasCapability; } @@ -316,12 +401,12 @@ bool LayerSupportHandle::IsFullyConnectedSupported(const TensorInfo& input, { if(!descriptor.m_ConstantWeights && !m_BackendId.IsUndefined()) { - bool result = false; - result = IsCapabilitySupported(m_BackendId, BackendCapability::NonConstWeights); - if (!result) + auto capability = GetCapability("NonConstWeights", m_BackendId); + if (capability.has_value() && capability.value().GetValue().AsBool() == true) { - return result; + return true; } + return false; } return m_LayerSupport->IsFullyConnectedSupported(input, diff --git a/src/armnn/test/OptimizerTests.cpp b/src/armnn/test/OptimizerTests.cpp index fcfff1a807..7fe69a9380 100644 --- a/src/armnn/test/OptimizerTests.cpp +++ b/src/armnn/test/OptimizerTests.cpp @@ -615,11 +615,15 @@ public: BOOST_AUTO_TEST_CASE(BackendCapabilityTest) { BackendId backendId = "MockBackend"; + + armnn::BackendOptions::BackendOption nonConstWeights{"NonConstWeights", true}; + // MockBackend does not support the NonConstWeights capability - BOOST_CHECK(!armnn::IsCapabilitySupported(backendId, armnn::BackendCapability::NonConstWeights)); + BOOST_CHECK(!armnn::HasCapability(nonConstWeights, backendId)); + BOOST_CHECK(!armnn::HasCapability("NonConstWeights", backendId)); // MockBackend does not support the AsyncExecution capability - BOOST_CHECK(!armnn::IsCapabilitySupported(backendId, armnn::BackendCapability::AsyncExecution)); + BOOST_CHECK(!armnn::GetCapability("AsyncExecution", backendId).has_value()); } BOOST_AUTO_TEST_CASE(BackendHintTest) diff --git a/src/armnn/test/UtilsTests.cpp b/src/armnn/test/UtilsTests.cpp index 77883ba4fb..f2ca95d7bd 100644 --- a/src/armnn/test/UtilsTests.cpp +++ b/src/armnn/test/UtilsTests.cpp @@ -321,18 +321,6 @@ BOOST_AUTO_TEST_CASE(LayerSupportHandle) BOOST_CHECK(layerSupportObject.IsBackendRegistered()); } - -BOOST_AUTO_TEST_CASE(IsCapabilitySupported_CpuRef) -{ - BOOST_CHECK(armnn::IsCapabilitySupported(armnn::Compute::CpuRef, armnn::BackendCapability::NonConstWeights)); -} -#endif - -#if defined(ARMCOMPUTENEON_ENABLED) -BOOST_AUTO_TEST_CASE(IsCapabilitySupported_CpuAcc) -{ - BOOST_CHECK(!armnn::IsCapabilitySupported(armnn::Compute::CpuAcc, armnn::BackendCapability::NonConstWeights)); -} #endif BOOST_AUTO_TEST_SUITE_END() |