aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/RefLayerSupport.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/RefLayerSupport.cpp')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp82
1 files changed, 76 insertions, 6 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index cf1814e06a..402bd66f02 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -47,6 +47,21 @@ bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
} // anonymous namespace
+namespace
+{
+
+std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
+ unsigned int actual,
+ std::string& layerStr,
+ std::string& tensorName)
+{
+ std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
+ " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";
+
+ return errorMsg;
+}
+
+} // anonymous namespace
namespace
{
@@ -177,6 +192,15 @@ struct ShapesAreBroadcastCompatible : public Rule
}
}
};
+
+struct TensorNumDimensionsAreCorrect : public Rule
+{
+ TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
+ {
+ m_Res = info.GetNumDimensions() == expectedNumDimensions;
+ }
+};
+
} // namespace
@@ -874,12 +898,58 @@ bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
const MeanDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(output);
- ignore_unused(descriptor);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ bool supported = true;
+ std::string meanLayerStr = "Mean";
+ std::string outputTensorStr = "output";
+
+ std::array<DataType,2> supportedTypes =
+ {
+ DataType::Float32,
+ DataType::QuantisedAsymm8
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+ "Reference Mean: input type not supported.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+ "Reference Mean: input and output types are mismatched");
+
+ if (descriptor.m_KeepDims)
+ {
+ supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
+ reasonIfUnsupported,
+ CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
+ output.GetNumDimensions(),
+ meanLayerStr, outputTensorStr).data());
+ }
+ else if (descriptor.m_Axis.empty())
+ {
+ supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
+ reasonIfUnsupported,
+ CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
+ meanLayerStr, outputTensorStr).data());
+ }
+ else
+ {
+ auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());
+
+ if (outputDim > 0)
+ {
+ supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
+ reasonIfUnsupported,
+ CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
+ meanLayerStr, outputTensorStr).data());
+ }
+ else
+ {
+ supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
+ reasonIfUnsupported,
+ CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
+ meanLayerStr, outputTensorStr).data());
+ }
+ }
+
+ return supported;
}
bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,