aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/RefLayerSupport.cpp
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2019-04-09 14:20:12 +0100
committerSadik Armagan <sadik.armagan@arm.com>2019-04-09 14:24:05 +0100
commit2999a02f0c6a6f290ce45f28c998a1c000d48f67 (patch)
treef9d13cec08ab8c6c47e68df512cddc613552a7d2 /src/backends/reference/RefLayerSupport.cpp
parent998517647d699d602e36f06b40d3f1d1ddaae7be (diff)
downloadarmnn-2999a02f0c6a6f290ce45f28c998a1c000d48f67.tar.gz
IVGCVSW-2862 Extend the Elementwise Workload to support QSymm16 Data Type
IVGCVSW-2863 Unit test per Elementwise operator with QSymm16 Data Type * Added QSymm16 support for Elementwise Operators * Added QSymm16 unit tests for Elementwise Operators Change-Id: I4e4e2938f9ed2cbbb1f05fb0f7dc476768550277 Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Diffstat (limited to 'src/backends/reference/RefLayerSupport.cpp')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp170
1 files changed, 138 insertions, 32 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index d2cf6f904a..3512d52acf 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -228,9 +228,10 @@ bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
{
bool supported = true;
- std::array<DataType,2> supportedTypes = {
+ std::array<DataType,3> supportedTypes = {
DataType::Float32,
- DataType::QuantisedAsymm8
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
};
supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
@@ -432,12 +433,33 @@ bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(input1);
- ignore_unused(output);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input0.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ bool supported = true;
+
+ std::array<DataType,3> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
+ "Reference division: input 0 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
+ "Reference division: input 1 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference division: output is not a supported type.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
+ "Reference division: input 0 and Input 1 types are mismatched");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
+ "Reference division: input and output types are mismatched");
+
+ supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
+ "Reference division: shapes are not suitable for implicit broadcast.");
+
+ return supported;
}
bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
@@ -606,12 +628,33 @@ bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(input1);
- ignore_unused(output);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input0.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ bool supported = true;
+
+ std::array<DataType,3> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
+ "Reference maximum: input 0 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
+ "Reference maximum: input 1 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference maximum: output is not a supported type.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
+ "Reference maximum: input 0 and Input 1 types are mismatched");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
+ "Reference maximum: input and output types are mismatched");
+
+ supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
+ "Reference maximum: shapes are not suitable for implicit broadcast.");
+
+ return supported;
}
bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
@@ -659,12 +702,33 @@ bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(input1);
- ignore_unused(output);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input0.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ bool supported = true;
+
+ std::array<DataType,3> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
+ "Reference minimum: input 0 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
+ "Reference minimum: input 1 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference minimum: output is not a supported type.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
+ "Reference minimum: input 0 and Input 1 types are mismatched");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
+ "Reference minimum: input and output types are mismatched");
+
+ supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
+ "Reference minimum: shapes are not suitable for implicit broadcast.");
+
+ return supported;
}
bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
@@ -672,12 +736,33 @@ bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(input1);
- ignore_unused(output);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input0.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ bool supported = true;
+
+ std::array<DataType,3> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
+ "Reference multiplication: input 0 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
+ "Reference multiplication: input 1 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference multiplication: output is not a supported type.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
+ "Reference multiplication: input 0 and Input 1 types are mismatched");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
+ "Reference multiplication: input and output types are mismatched");
+
+ supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
+ "Reference multiplication: shapes are not suitable for implicit broadcast.");
+
+ return supported;
}
bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
@@ -860,12 +945,33 @@ bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(input1);
- ignore_unused(output);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input0.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ bool supported = true;
+
+ std::array<DataType,3> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
+ "Reference subtraction: input 0 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
+ "Reference subtraction: input 1 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference subtraction: output is not a supported type.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
+ "Reference subtraction: input 0 and Input 1 types are mismatched");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
+ "Reference subtraction: input and output types are mismatched");
+
+ supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
+ "Reference subtraction: shapes are not suitable for implicit broadcast.");
+
+ return supported;
}
} // namespace armnn