diff options
Diffstat (limited to 'src/backends/reference/RefLayerSupport.cpp')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 82 |
1 files changed, 76 insertions, 6 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index cf1814e06a..402bd66f02 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -47,6 +47,21 @@ bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported, } // anonymous namespace +namespace +{ + +std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected, + unsigned int actual, + std::string& layerStr, + std::string& tensorName) +{ + std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" + + " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor."; + + return errorMsg; +} + +} // anonymous namespace namespace { @@ -177,6 +192,15 @@ struct ShapesAreBroadcastCompatible : public Rule } } }; + +struct TensorNumDimensionsAreCorrect : public Rule +{ + TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions) + { + m_Res = info.GetNumDimensions() == expectedNumDimensions; + } +}; + } // namespace @@ -874,12 +898,58 @@ bool RefLayerSupport::IsMeanSupported(const TensorInfo& input, const MeanDescriptor& descriptor, Optional<std::string&> reasonIfUnsupported) const { - ignore_unused(output); - ignore_unused(descriptor); - return IsSupportedForDataTypeRef(reasonIfUnsupported, - input.GetDataType(), - &TrueFunc<>, - &TrueFunc<>); + bool supported = true; + std::string meanLayerStr = "Mean"; + std::string outputTensorStr = "output"; + + std::array<DataType,2> supportedTypes = + { + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, + "Reference Mean: input type not supported."); + + supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, + "Reference Mean: input and output types are mismatched"); + + if (descriptor.m_KeepDims) + { + supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()), + reasonIfUnsupported, + CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(), + output.GetNumDimensions(), + meanLayerStr, outputTensorStr).data()); + } + else if (descriptor.m_Axis.empty()) + { + supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1), + reasonIfUnsupported, + CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(), + meanLayerStr, outputTensorStr).data()); + } + else + { + auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size()); + + if (outputDim > 0) + { + supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim), + reasonIfUnsupported, + CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(), + meanLayerStr, outputTensorStr).data()); + } + else + { + supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1), + reasonIfUnsupported, + CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(), + meanLayerStr, outputTensorStr).data()); + } + } + + return supported; } bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs, |