diff options
Diffstat (limited to 'src/backends/reference/RefLayerSupport.cpp')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 85 |
1 files changed, 39 insertions, 46 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 9342b29f47..c65886ba4d 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -308,6 +308,35 @@ bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input, return supported; } +bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + const ComparisonDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported) const +{ + boost::ignore_unused(descriptor); + + std::array<DataType, 4> supportedInputTypes = + { + DataType::Float32, + DataType::Float16, + DataType::QuantisedAsymm8, + DataType::QuantisedSymm16 + }; + + bool supported = true; + supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported, + "Reference comparison: input 0 is not a supported type"); + + supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported, + "Reference comparison: input 0 and Input 1 types are mismatched"); + + supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported, + "Reference comparison: output is not of type Boolean"); + + return supported; +} + bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs, const TensorInfo& output, const ConcatDescriptor& descriptor, @@ -644,29 +673,11 @@ bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0, const TensorInfo& output, Optional<std::string&> reasonIfUnsupported) const { - bool supported = true; - - std::array<DataType,4> supportedTypes = - { - DataType::Float32, - DataType::Float16, - DataType::QuantisedAsymm8, - DataType::QuantisedSymm16 - }; - - supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported, - "Reference equal: input 0 is not a supported type."); - - supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported, - "Reference equal: input 1 is not a supported type."); - - supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported, - "Reference equal: input 0 and Input 1 types are mismatched"); - - supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported, - "Reference equal: shapes are not suitable for implicit broadcast."); - - return supported; + return IsComparisonSupported(input0, + input1, + output, + ComparisonDescriptor(ComparisonOperation::Equal), + reasonIfUnsupported); } bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input, @@ -802,29 +813,11 @@ bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0, const TensorInfo& output, Optional<std::string&> reasonIfUnsupported) const { - bool supported = true; - - std::array<DataType,4> supportedTypes = - { - DataType::Float32, - DataType::Float16, - DataType::QuantisedAsymm8, - DataType::QuantisedSymm16 - }; - - supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported, - "Reference greater: input 0 is not a supported type."); - - supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported, - "Reference greater: input 1 is not a supported type."); - - supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported, - "Reference greater: input 0 and Input 1 types are mismatched"); - - supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported, - "Reference greater: shapes are not suitable for implicit broadcast."); - - return supported; + return IsComparisonSupported(input0, + input1, + output, + ComparisonDescriptor(ComparisonOperation::Greater), + reasonIfUnsupported); } bool RefLayerSupport::IsInputSupported(const TensorInfo& input, |