aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/RefLayerSupport.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/RefLayerSupport.cpp')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp170
1 files changed, 138 insertions, 32 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index d2cf6f904a..3512d52acf 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -228,9 +228,10 @@ bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
{
bool supported = true;
- std::array<DataType,2> supportedTypes = {
+ std::array<DataType,3> supportedTypes = {
DataType::Float32,
- DataType::QuantisedAsymm8
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
};
supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
@@ -432,12 +433,33 @@ bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(input1);
- ignore_unused(output);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input0.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ bool supported = true;
+
+ std::array<DataType,3> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
+ "Reference division: input 0 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
+ "Reference division: input 1 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference division: output is not a supported type.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
+ "Reference division: input 0 and Input 1 types are mismatched");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
+ "Reference division: input and output types are mismatched");
+
+ supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
+ "Reference division: shapes are not suitable for implicit broadcast.");
+
+ return supported;
}
bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
@@ -606,12 +628,33 @@ bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(input1);
- ignore_unused(output);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input0.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ bool supported = true;
+
+ std::array<DataType,3> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
+ "Reference maximum: input 0 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
+ "Reference maximum: input 1 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference maximum: output is not a supported type.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
+ "Reference maximum: input 0 and Input 1 types are mismatched");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
+ "Reference maximum: input and output types are mismatched");
+
+ supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
+ "Reference maximum: shapes are not suitable for implicit broadcast.");
+
+ return supported;
}
bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
@@ -659,12 +702,33 @@ bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(input1);
- ignore_unused(output);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input0.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ bool supported = true;
+
+ std::array<DataType,3> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
+ "Reference minimum: input 0 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
+ "Reference minimum: input 1 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference minimum: output is not a supported type.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
+ "Reference minimum: input 0 and Input 1 types are mismatched");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
+ "Reference minimum: input and output types are mismatched");
+
+ supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
+ "Reference minimum: shapes are not suitable for implicit broadcast.");
+
+ return supported;
}
bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
@@ -672,12 +736,33 @@ bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(input1);
- ignore_unused(output);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input0.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ bool supported = true;
+
+ std::array<DataType,3> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
+ "Reference multiplication: input 0 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
+ "Reference multiplication: input 1 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference multiplication: output is not a supported type.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
+ "Reference multiplication: input 0 and Input 1 types are mismatched");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
+ "Reference multiplication: input and output types are mismatched");
+
+ supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
+ "Reference multiplication: shapes are not suitable for implicit broadcast.");
+
+ return supported;
}
bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
@@ -860,12 +945,33 @@ bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(input1);
- ignore_unused(output);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input0.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ bool supported = true;
+
+ std::array<DataType,3> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
+ "Reference subtraction: input 0 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
+ "Reference subtraction: input 1 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference subtraction: output is not a supported type.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
+ "Reference subtraction: input 0 and Input 1 types are mismatched");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
+ "Reference subtraction: input and output types are mismatched");
+
+ supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
+ "Reference subtraction: shapes are not suitable for implicit broadcast.");
+
+ return supported;
}
} // namespace armnn