aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/BackendHelper.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/BackendHelper.cpp')
-rw-r--r--src/armnn/BackendHelper.cpp26
1 files changed, 25 insertions, 1 deletions
diff --git a/src/armnn/BackendHelper.cpp b/src/armnn/BackendHelper.cpp
index 1467366323..1c926f4d30 100644
--- a/src/armnn/BackendHelper.cpp
+++ b/src/armnn/BackendHelper.cpp
@@ -23,7 +23,21 @@ LayerSupportHandle GetILayerSupportByBackendId(const armnn::BackendId& backend)
auto factoryFunc = backendRegistry.GetFactory(backend);
auto backendObject = factoryFunc();
- return LayerSupportHandle(backendObject->GetLayerSupport());
+ return LayerSupportHandle(backendObject->GetLayerSupport(), backend);
+}
+
+/// Convenience function to check a capability on a backend
+bool IsCapabilitySupported(const armnn::BackendId& backend, armnn::BackendCapability capability)
+{
+ bool hasCapability = false;
+ auto const& backendRegistry = armnn::BackendRegistryInstance();
+ if (backendRegistry.IsBackendRegistered(backend))
+ {
+ auto factoryFunc = backendRegistry.GetFactory(backend);
+ auto backendObject = factoryFunc();
+ hasCapability = backendObject->HasCapability(capability);
+ }
+ return hasCapability;
}
bool LayerSupportHandle::IsBackendRegistered() const
@@ -293,6 +307,16 @@ bool LayerSupportHandle::IsFullyConnectedSupported(const TensorInfo& input,
const FullyConnectedDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported)
{
+ if(!descriptor.m_ConstantWeights && !m_BackendId.IsUndefined())
+ {
+ bool result = false;
+ result = IsCapabilitySupported(m_BackendId, BackendCapability::NonConstWeights);
+ if (!result)
+ {
+ return result;
+ }
+ }
+
return m_LayerSupport->IsFullyConnectedSupported(input,
output,
weights,