aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/RefLayerSupport.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/RefLayerSupport.cpp')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp85
1 files changed, 41 insertions, 44 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 26a61d45d5..491081dbac 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -70,28 +70,10 @@ std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- bool supported = true;
- std::array<DataType,4> supportedTypes =
- {
- DataType::Float32,
- DataType::Float16,
- DataType::QAsymmU8,
- DataType::QSymmS16
- };
-
- supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
- "Reference abs: input type not supported");
-
- supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
- "Reference abs: output type not supported");
-
- supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
- "Reference abs: input and output types not matching");
-
- supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
- "Reference abs: input and output shapes have different number of total elements");
-
- return supported;
+ return IsElementwiseUnarySupported(input,
+ output,
+ ElementwiseUnaryDescriptor(UnaryOperation::Abs),
+ reasonIfUnsupported);
}
bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
@@ -714,6 +696,39 @@ bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
return supported;
}
+bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const ElementwiseUnaryDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ boost::ignore_unused(descriptor);
+
+ std::array<DataType, 4> supportedTypes =
+ {
+ DataType::Float32,
+ DataType::Float16,
+ DataType::QAsymmU8,
+ DataType::QSymmS16
+ };
+
+ bool supported = true;
+
+ supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+ "Reference elementwise unary: input type not supported");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference elementwise unary: output type not supported");
+
+ supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+ "Reference elementwise unary: input and output types not matching");
+
+ supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
+ "Reference elementwise unary: input and output shapes"
+ "have different number of total elements");
+
+ return supported;
+}
+
bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
@@ -1499,28 +1514,10 @@ bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- bool supported = true;
- std::array<DataType,4> supportedTypes =
- {
- DataType::Float32,
- DataType::Float16,
- DataType::QAsymmU8,
- DataType::QSymmS16
- };
-
- supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
- "Reference rsqrt: input type not supported");
-
- supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
- "Reference rsqrt: output type not supported");
-
- supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
- "Reference rsqrt: input and output types not matching");
-
- supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
- "Reference Rsqrt: input and output shapes have different number of total elements");
-
- return supported;
+ return IsElementwiseUnarySupported(input,
+ output,
+ ElementwiseUnaryDescriptor(UnaryOperation::Rsqrt),
+ reasonIfUnsupported);
}
bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,