aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/BackendHelper.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/BackendHelper.cpp')
-rw-r--r--src/armnn/BackendHelper.cpp47
1 files changed, 37 insertions, 10 deletions
diff --git a/src/armnn/BackendHelper.cpp b/src/armnn/BackendHelper.cpp
index 13bde0aafa..9ab30f8fb2 100644
--- a/src/armnn/BackendHelper.cpp
+++ b/src/armnn/BackendHelper.cpp
@@ -5,6 +5,7 @@
#include <armnn/BackendHelper.hpp>
#include <armnn/BackendRegistry.hpp>
+#include <armnn/Logging.hpp>
#include <armnn/backends/IBackendInternal.hpp>
@@ -399,22 +400,48 @@ bool LayerSupportHandle::IsFullyConnectedSupported(const TensorInfo& input,
const FullyConnectedDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported)
{
- if(!descriptor.m_ConstantWeights && !m_BackendId.IsUndefined())
+ if(!m_BackendId.IsUndefined())
{
- auto capability = GetCapability("NonConstWeights", m_BackendId);
- if (capability.has_value() && capability.value().GetValue().AsBool() == true)
+ auto capability = GetCapability("ConstantTensorsAsInputs", m_BackendId);
+ if(!capability.has_value() || capability.value().GetValue().AsBool() == false)
{
- return true;
+ if(!weights.IsConstant())
+ {
+ return false;
+ }
+ if(descriptor.m_BiasEnabled)
+ {
+ if(!biases.IsConstant())
+ {
+ return false;
+ }
+ }
+
+ // At the first stage we will only print a warning. this is to give
+ // backend developers a chance to adopt and read weights from input slots.
+ ARMNN_LOG(warning) << "The backend makes use of a deprecated interface to read constant tensors. "
+ "If you are a backend developer please find more information in our "
+ "doxygen documentation on github https://github.com/ARM-software/armnn "
+ "under the keyword 'ConstTensorsAsInputs'.";
+ }
+
+ if(!descriptor.m_ConstantWeights)
+ {
+ auto capability = GetCapability("NonConstWeights", m_BackendId);
+ if (capability.has_value() && capability.value().GetValue().AsBool() == true)
+ {
+ return true;
+ }
+ return false;
}
- return false;
}
return m_LayerSupport->IsFullyConnectedSupported(input,
- output,
- weights,
- biases,
- descriptor,
- reasonIfUnsupported.value());
+ output,
+ weights,
+ biases,
+ descriptor,
+ reasonIfUnsupported.value());
}
bool LayerSupportHandle::IsGatherSupported(const TensorInfo& input0,