aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/RefLayerSupport.cpp
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2019-07-03 14:55:57 +0100
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2019-07-03 18:11:20 +0100
commitf9ac3fd5676565b1065698158f8d54a27f24981c (patch)
tree60c3549217ad4257ecd2fad9e6c0e0116995bb12 /src/backends/reference/RefLayerSupport.cpp
parent6133cc31e4f4638494b663243543d8564a450ff1 (diff)
downloadarmnn-f9ac3fd5676565b1065698158f8d54a27f24981c.tar.gz
IVGCVSW-3399 Add QSymm16 IsLayerSupportedTest to reference backend
* Refactor and add QSymm16 to IsDepthwiseConvolutionSupported * Refactor and add QSymm16 to IsEqualSupported * Refactor and add QSymm16 to IsGreaterSupported * Refactor and add QSymm16 to IsSplitterSupported * Refactor and add QSymm16 to IsStridedSliceSupported * Refactor and add QSymm16 to IsMemCopySupported * Refactor IsFakeQuantizationSupported Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: I4f115a5535748fc22df8bc90b24b537fd5dd95b8
Diffstat (limited to 'src/backends/reference/RefLayerSupport.cpp')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp220
1 files changed, 168 insertions, 52 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 3d260c5abd..26070a5328 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -547,14 +547,45 @@ bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
const Optional<TensorInfo>& biases,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(output);
+ bool supported = true;
+
+ // Define supported types.
+ std::array<DataType,3> supportedTypes =
+ {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+ "Reference DepthwiseConvolution2d: input is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference DepthwiseConvolution2d: output is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
+ "Reference DepthwiseConvolution2d: weights is not a supported type.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+ "Reference DepthwiseConvolution2d: input and output types mismatched.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
+ "Reference DepthwiseConvolution2d: input and weights types mismatched.");
+
+ if (biases.has_value())
+ {
+ std::array<DataType,2> biasesSupportedTypes =
+ {
+ DataType::Float32,
+ DataType::Signed32
+ };
+ supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
+ "Reference DepthwiseConvolution2d: biases is not a supported type.");
+ }
ignore_unused(descriptor);
- ignore_unused(weights);
- ignore_unused(biases);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+
+ return supported;
+
}
bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
@@ -656,14 +687,28 @@ bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(input0);
- ignore_unused(input1);
- ignore_unused(output);
- ignore_unused(reasonIfUnsupported);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input0.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ bool supported = true;
+
+ std::array<DataType,3> supportedTypes =
+ {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
+ "Reference equal: input 0 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
+ "Reference equal: input 1 is not a supported type.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
+ "Reference equal: input 0 and Input 1 types are mismatched");
+
+ supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
+ "Reference equal: shapes are not suitable for implicit broadcast.");
+
+ return supported;
}
bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
@@ -671,10 +716,17 @@ bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
Optional<std::string&> reasonIfUnsupported) const
{
ignore_unused(descriptor);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input.GetDataType(),
- &TrueFunc<>,
- &FalseFuncU8<>);
+ bool supported = true;
+
+ std::array<DataType,1> supportedTypes =
+ {
+ DataType::Float32
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+ "Reference fake quantization: input type not supported.");
+
+ return supported;
}
bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
@@ -763,9 +815,9 @@ bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
bool supported = true;
std::array<DataType,3> supportedTypes =
{
- DataType::Float32,
- DataType::QuantisedAsymm8,
- DataType::QuantisedSymm16
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
};
supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
@@ -788,14 +840,28 @@ bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(input0);
- ignore_unused(input1);
- ignore_unused(output);
- ignore_unused(reasonIfUnsupported);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input0.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ bool supported = true;
+
+ std::array<DataType,3> supportedTypes =
+ {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
+ "Reference greater: input 0 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
+ "Reference greater: input 1 is not a supported type.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
+ "Reference greater: input 0 and Input 1 types are mismatched");
+
+ supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
+ "Reference greater: shapes are not suitable for implicit broadcast.");
+
+ return supported;
}
bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
@@ -1027,14 +1093,27 @@ bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
const TensorInfo &output,
Optional<std::string &> reasonIfUnsupported) const
{
- ignore_unused(output);
- return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
- input.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>,
- &TrueFunc<>,
- &FalseFuncI32<>,
- &TrueFunc<>);
+ bool supported = true;
+
+ std::array<DataType,5> supportedTypes =
+ {
+ DataType::Float32,
+ DataType::Float16,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16,
+ DataType::Boolean
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+ "Reference MemCopy: input type not supported");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference MemCopy: output type not supported");
+
+ supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+ "Reference MemCopy: input and output types are mismatched");
+
+ return supported;
}
bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
@@ -1401,10 +1480,18 @@ bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
Optional<std::string&> reasonIfUnsupported) const
{
ignore_unused(descriptor);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ bool supported = true;
+ std::array<DataType,3> supportedTypes =
+ {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+ "Reference splitter: input type not supported");
+
+ return supported;
}
bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
@@ -1413,11 +1500,26 @@ bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
Optional<std::string&> reasonIfUnsupported) const
{
ignore_unused(descriptor);
- ignore_unused(outputs);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ bool supported = true;
+ std::array<DataType,3> supportedTypes =
+ {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+ "Reference splitter: output type not supported");
+ for (const TensorInfo output : outputs)
+ {
+ supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+ "Reference splitter: input type not supported");
+
+ supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+ "Reference splitter: input and output types mismatched.");
+ }
+
+ return supported;
}
bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
@@ -1425,12 +1527,26 @@ bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
const StridedSliceDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(output);
ignore_unused(descriptor);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ bool supported = true;
+
+ std::array<DataType,3> supportedTypes =
+ {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+ "Reference StridedSlice: input type not supported");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference StridedSlice: output type not supported");
+
+ supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+ "Reference StridedSlice: input and output types are mismatched");
+
+ return supported;
}
bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,