diff options
Diffstat (limited to 'src/armnn/BackendHelper.cpp')
-rw-r--r-- | src/armnn/BackendHelper.cpp | 47 |
1 files changed, 37 insertions, 10 deletions
diff --git a/src/armnn/BackendHelper.cpp b/src/armnn/BackendHelper.cpp index 13bde0aafa..9ab30f8fb2 100644 --- a/src/armnn/BackendHelper.cpp +++ b/src/armnn/BackendHelper.cpp @@ -5,6 +5,7 @@ #include <armnn/BackendHelper.hpp> #include <armnn/BackendRegistry.hpp> +#include <armnn/Logging.hpp> #include <armnn/backends/IBackendInternal.hpp> @@ -399,22 +400,48 @@ bool LayerSupportHandle::IsFullyConnectedSupported(const TensorInfo& input, const FullyConnectedDescriptor& descriptor, Optional<std::string&> reasonIfUnsupported) { - if(!descriptor.m_ConstantWeights && !m_BackendId.IsUndefined()) + if(!m_BackendId.IsUndefined()) { - auto capability = GetCapability("NonConstWeights", m_BackendId); - if (capability.has_value() && capability.value().GetValue().AsBool() == true) + auto capability = GetCapability("ConstantTensorsAsInputs", m_BackendId); + if(!capability.has_value() || capability.value().GetValue().AsBool() == false) { - return true; + if(!weights.IsConstant()) + { + return false; + } + if(descriptor.m_BiasEnabled) + { + if(!biases.IsConstant()) + { + return false; + } + } + + // At the first stage we will only print a warning. this is to give + // backend developers a chance to adopt and read weights from input slots. + ARMNN_LOG(warning) << "The backend makes use of a deprecated interface to read constant tensors. " + "If you are a backend developer please find more information in our " + "doxygen documentation on github https://github.com/ARM-software/armnn " + "under the keyword 'ConstTensorsAsInputs'."; + } + + if(!descriptor.m_ConstantWeights) + { + auto capability = GetCapability("NonConstWeights", m_BackendId); + if (capability.has_value() && capability.value().GetValue().AsBool() == true) + { + return true; + } + return false; } - return false; } return m_LayerSupport->IsFullyConnectedSupported(input, - output, - weights, - biases, - descriptor, - reasonIfUnsupported.value()); + output, + weights, + biases, + descriptor, + reasonIfUnsupported.value()); } bool LayerSupportHandle::IsGatherSupported(const TensorInfo& input0, |