aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/RefLayerSupport.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/RefLayerSupport.cpp')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp42
1 files changed, 35 insertions, 7 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 6ad6816474..b6da628be3 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -364,14 +364,42 @@ bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
const Optional<TensorInfo>& biases,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(output);
+ bool supported = true;
+
+ // Define supported types.
+ std::array<DataType,3> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+ "Reference addition: input is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference addition: output is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
+ "Reference addition: weights is not a supported type.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+ "Reference activation: input and output types mismatched.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
+ "Reference activation: input and weights types mismatched.");
+
+ if (biases.has_value())
+ {
+ std::array<DataType,3> biasesSupportedTypes = {
+ DataType::Float32,
+ DataType::Signed32
+ };
+ supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
+ "Reference addition: biases is not a supported type.");
+ }
ignore_unused(descriptor);
- ignore_unused(weights);
- ignore_unused(biases);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+
+ return supported;
}
bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,