aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/BackendHelper.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/BackendHelper.cpp')
-rw-r--r--src/armnn/BackendHelper.cpp93
1 files changed, 89 insertions, 4 deletions
diff --git a/src/armnn/BackendHelper.cpp b/src/armnn/BackendHelper.cpp
index 31dfaa53a3..be21412e97 100644
--- a/src/armnn/BackendHelper.cpp
+++ b/src/armnn/BackendHelper.cpp
@@ -26,6 +26,89 @@ LayerSupportHandle GetILayerSupportByBackendId(const armnn::BackendId& backend)
return LayerSupportHandle(backendObject->GetLayerSupport(), backend);
}
+Optional<const BackendOptions::BackendOption> GetCapability(const std::string& backendCapabilityName,
+ const BackendCapabilities& capabilities)
+{
+ for (size_t i=0; i < capabilities.GetOptionCount(); i++)
+ {
+ const auto& capability = capabilities.GetOption(i);
+ if (backendCapabilityName == capability.GetName())
+ {
+ return capability;
+ }
+ }
+ return EmptyOptional();
+}
+
+Optional<const BackendOptions::BackendOption> GetCapability(const std::string& backendCapabilityName,
+ const armnn::BackendId& backend)
+{
+ auto const& backendRegistry = armnn::BackendRegistryInstance();
+ if (backendRegistry.IsBackendRegistered(backend))
+ {
+ auto factoryFunc = backendRegistry.GetFactory(backend);
+ auto backendObject = factoryFunc();
+ auto capabilities = backendObject->GetCapabilities();
+ return GetCapability(backendCapabilityName, capabilities);
+ }
+ return EmptyOptional();
+}
+
+bool HasCapability(const std::string& name, const BackendCapabilities& capabilities)
+{
+ return GetCapability(name, capabilities).has_value();
+}
+
+bool HasCapability(const std::string& name, const armnn::BackendId& backend)
+{
+ return GetCapability(name, backend).has_value();
+}
+
+bool HasCapability(const BackendOptions::BackendOption& capability, const BackendCapabilities& capabilities)
+{
+ for (size_t i=0; i < capabilities.GetOptionCount(); i++)
+ {
+ const auto& backendCapability = capabilities.GetOption(i);
+ if (capability.GetName() == backendCapability.GetName())
+ {
+ if (capability.GetValue().IsBool() && backendCapability.GetValue().IsBool())
+ {
+ return capability.GetValue().AsBool() == backendCapability.GetValue().AsBool();
+ }
+ else if(capability.GetValue().IsFloat() && backendCapability.GetValue().IsFloat())
+ {
+ return capability.GetValue().AsFloat() == backendCapability.GetValue().AsFloat();
+ }
+ else if(capability.GetValue().IsInt() && backendCapability.GetValue().IsInt())
+ {
+ return capability.GetValue().AsInt() == backendCapability.GetValue().AsInt();
+ }
+ else if(capability.GetValue().IsString() && backendCapability.GetValue().IsString())
+ {
+ return capability.GetValue().AsString() == backendCapability.GetValue().AsString();
+ }
+ else if(capability.GetValue().IsUnsignedInt() && backendCapability.GetValue().IsUnsignedInt())
+ {
+ return capability.GetValue().AsUnsignedInt() == backendCapability.GetValue().AsUnsignedInt();
+ }
+ }
+ }
+ return false;
+}
+
+bool HasCapability(const BackendOptions::BackendOption& backendOption, const armnn::BackendId& backend)
+{
+ auto const& backendRegistry = armnn::BackendRegistryInstance();
+ if (backendRegistry.IsBackendRegistered(backend))
+ {
+ auto factoryFunc = backendRegistry.GetFactory(backend);
+ auto backendObject = factoryFunc();
+ auto capabilities = backendObject->GetCapabilities();
+ return HasCapability(backendOption, capabilities);
+ }
+ return false;
+}
+
/// Convenience function to check a capability on a backend
bool IsCapabilitySupported(const armnn::BackendId& backend, armnn::BackendCapability capability)
{
@@ -35,7 +118,9 @@ bool IsCapabilitySupported(const armnn::BackendId& backend, armnn::BackendCapabi
{
auto factoryFunc = backendRegistry.GetFactory(backend);
auto backendObject = factoryFunc();
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
hasCapability = backendObject->HasCapability(capability);
+ ARMNN_NO_DEPRECATE_WARN_END
}
return hasCapability;
}
@@ -316,12 +401,12 @@ bool LayerSupportHandle::IsFullyConnectedSupported(const TensorInfo& input,
{
if(!descriptor.m_ConstantWeights && !m_BackendId.IsUndefined())
{
- bool result = false;
- result = IsCapabilitySupported(m_BackendId, BackendCapability::NonConstWeights);
- if (!result)
+ auto capability = GetCapability("NonConstWeights", m_BackendId);
+ if (capability.has_value() && capability.value().GetValue().AsBool() == true)
{
- return result;
+ return true;
}
+ return false;
}
return m_LayerSupport->IsFullyConnectedSupported(input,