aboutsummaryrefslogtreecommitdiff
path: root/src/armnn
diff options
context:
space:
mode:
authorFinn Williams <Finn.Williams@arm.com>2021-05-26 18:38:12 +0100
committerFinn Williams <Finn.Williams@arm.com>2021-06-09 17:18:10 +0100
commitb9af86ea42568ade799ee5529137e4756977b6c6 (patch)
tree261003078fd2191b22ee7465e07668cbed666553 /src/armnn
parent5b1bcc93820b442bc4035c1e030a8d4a0983df91 (diff)
downloadarmnn-b9af86ea42568ade799ee5529137e4756977b6c6.tar.gz
IVGCVSW-5855 Refactor the reporting of capabilities from backends
Signed-off-by: Finn Williams <Finn.Williams@arm.com> Change-Id: I05fc331a8e91bdcb6b8a2f32cfb555060fc5d797
Diffstat (limited to 'src/armnn')
-rw-r--r--src/armnn/BackendHelper.cpp93
-rw-r--r--src/armnn/test/OptimizerTests.cpp8
-rw-r--r--src/armnn/test/UtilsTests.cpp12
3 files changed, 95 insertions, 18 deletions
diff --git a/src/armnn/BackendHelper.cpp b/src/armnn/BackendHelper.cpp
index 31dfaa53a3..be21412e97 100644
--- a/src/armnn/BackendHelper.cpp
+++ b/src/armnn/BackendHelper.cpp
@@ -26,6 +26,89 @@ LayerSupportHandle GetILayerSupportByBackendId(const armnn::BackendId& backend)
return LayerSupportHandle(backendObject->GetLayerSupport(), backend);
}
+Optional<const BackendOptions::BackendOption> GetCapability(const std::string& backendCapabilityName,
+ const BackendCapabilities& capabilities)
+{
+ for (size_t i=0; i < capabilities.GetOptionCount(); i++)
+ {
+ const auto& capability = capabilities.GetOption(i);
+ if (backendCapabilityName == capability.GetName())
+ {
+ return capability;
+ }
+ }
+ return EmptyOptional();
+}
+
+Optional<const BackendOptions::BackendOption> GetCapability(const std::string& backendCapabilityName,
+ const armnn::BackendId& backend)
+{
+ auto const& backendRegistry = armnn::BackendRegistryInstance();
+ if (backendRegistry.IsBackendRegistered(backend))
+ {
+ auto factoryFunc = backendRegistry.GetFactory(backend);
+ auto backendObject = factoryFunc();
+ auto capabilities = backendObject->GetCapabilities();
+ return GetCapability(backendCapabilityName, capabilities);
+ }
+ return EmptyOptional();
+}
+
+bool HasCapability(const std::string& name, const BackendCapabilities& capabilities)
+{
+ return GetCapability(name, capabilities).has_value();
+}
+
+bool HasCapability(const std::string& name, const armnn::BackendId& backend)
+{
+ return GetCapability(name, backend).has_value();
+}
+
+bool HasCapability(const BackendOptions::BackendOption& capability, const BackendCapabilities& capabilities)
+{
+ for (size_t i=0; i < capabilities.GetOptionCount(); i++)
+ {
+ const auto& backendCapability = capabilities.GetOption(i);
+ if (capability.GetName() == backendCapability.GetName())
+ {
+ if (capability.GetValue().IsBool() && backendCapability.GetValue().IsBool())
+ {
+ return capability.GetValue().AsBool() == backendCapability.GetValue().AsBool();
+ }
+ else if(capability.GetValue().IsFloat() && backendCapability.GetValue().IsFloat())
+ {
+ return capability.GetValue().AsFloat() == backendCapability.GetValue().AsFloat();
+ }
+ else if(capability.GetValue().IsInt() && backendCapability.GetValue().IsInt())
+ {
+ return capability.GetValue().AsInt() == backendCapability.GetValue().AsInt();
+ }
+ else if(capability.GetValue().IsString() && backendCapability.GetValue().IsString())
+ {
+ return capability.GetValue().AsString() == backendCapability.GetValue().AsString();
+ }
+ else if(capability.GetValue().IsUnsignedInt() && backendCapability.GetValue().IsUnsignedInt())
+ {
+ return capability.GetValue().AsUnsignedInt() == backendCapability.GetValue().AsUnsignedInt();
+ }
+ }
+ }
+ return false;
+}
+
+bool HasCapability(const BackendOptions::BackendOption& backendOption, const armnn::BackendId& backend)
+{
+ auto const& backendRegistry = armnn::BackendRegistryInstance();
+ if (backendRegistry.IsBackendRegistered(backend))
+ {
+ auto factoryFunc = backendRegistry.GetFactory(backend);
+ auto backendObject = factoryFunc();
+ auto capabilities = backendObject->GetCapabilities();
+ return HasCapability(backendOption, capabilities);
+ }
+ return false;
+}
+
/// Convenience function to check a capability on a backend
bool IsCapabilitySupported(const armnn::BackendId& backend, armnn::BackendCapability capability)
{
@@ -35,7 +118,9 @@ bool IsCapabilitySupported(const armnn::BackendId& backend, armnn::BackendCapabi
{
auto factoryFunc = backendRegistry.GetFactory(backend);
auto backendObject = factoryFunc();
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
hasCapability = backendObject->HasCapability(capability);
+ ARMNN_NO_DEPRECATE_WARN_END
}
return hasCapability;
}
@@ -316,12 +401,12 @@ bool LayerSupportHandle::IsFullyConnectedSupported(const TensorInfo& input,
{
if(!descriptor.m_ConstantWeights && !m_BackendId.IsUndefined())
{
- bool result = false;
- result = IsCapabilitySupported(m_BackendId, BackendCapability::NonConstWeights);
- if (!result)
+ auto capability = GetCapability("NonConstWeights", m_BackendId);
+ if (capability.has_value() && capability.value().GetValue().AsBool() == true)
{
- return result;
+ return true;
}
+ return false;
}
return m_LayerSupport->IsFullyConnectedSupported(input,
diff --git a/src/armnn/test/OptimizerTests.cpp b/src/armnn/test/OptimizerTests.cpp
index fcfff1a807..7fe69a9380 100644
--- a/src/armnn/test/OptimizerTests.cpp
+++ b/src/armnn/test/OptimizerTests.cpp
@@ -615,11 +615,15 @@ public:
BOOST_AUTO_TEST_CASE(BackendCapabilityTest)
{
BackendId backendId = "MockBackend";
+
+ armnn::BackendOptions::BackendOption nonConstWeights{"NonConstWeights", true};
+
// MockBackend does not support the NonConstWeights capability
- BOOST_CHECK(!armnn::IsCapabilitySupported(backendId, armnn::BackendCapability::NonConstWeights));
+ BOOST_CHECK(!armnn::HasCapability(nonConstWeights, backendId));
+ BOOST_CHECK(!armnn::HasCapability("NonConstWeights", backendId));
// MockBackend does not support the AsyncExecution capability
- BOOST_CHECK(!armnn::IsCapabilitySupported(backendId, armnn::BackendCapability::AsyncExecution));
+ BOOST_CHECK(!armnn::GetCapability("AsyncExecution", backendId).has_value());
}
BOOST_AUTO_TEST_CASE(BackendHintTest)
diff --git a/src/armnn/test/UtilsTests.cpp b/src/armnn/test/UtilsTests.cpp
index 77883ba4fb..f2ca95d7bd 100644
--- a/src/armnn/test/UtilsTests.cpp
+++ b/src/armnn/test/UtilsTests.cpp
@@ -321,18 +321,6 @@ BOOST_AUTO_TEST_CASE(LayerSupportHandle)
BOOST_CHECK(layerSupportObject.IsBackendRegistered());
}
-
-BOOST_AUTO_TEST_CASE(IsCapabilitySupported_CpuRef)
-{
- BOOST_CHECK(armnn::IsCapabilitySupported(armnn::Compute::CpuRef, armnn::BackendCapability::NonConstWeights));
-}
-#endif
-
-#if defined(ARMCOMPUTENEON_ENABLED)
-BOOST_AUTO_TEST_CASE(IsCapabilitySupported_CpuAcc)
-{
- BOOST_CHECK(!armnn::IsCapabilitySupported(armnn::Compute::CpuAcc, armnn::BackendCapability::NonConstWeights));
-}
#endif
BOOST_AUTO_TEST_SUITE_END()