aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp')
-rw-r--r--src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp29
1 files changed, 29 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
index fa6ec1000b..ff632fc701 100644
--- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
+++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
@@ -633,5 +633,34 @@ bool IsMeanLayerSupportedTests(std::string& reasonIfUnsupported)
return result;
}
+// Tests that IsMeanSupported fails when input tensor dimensions
+// do not match output tensor dimensions when keepDims == true
+template<typename FactoryType, armnn::DataType InputDataType , armnn::DataType OutputDataType>
+bool IsMeanLayerNotSupportedTests(std::string& reasonIfUnsupported)
+{
+ armnn::Graph graph;
+ static const std::vector<unsigned> axes = {};
+ // Set keepDims == true
+ armnn::MeanDescriptor desc(axes, true);
+
+ armnn::Layer* const layer = graph.AddLayer<armnn::MeanLayer>(desc, "LayerName");
+
+ armnn::Layer* const input = graph.AddLayer<armnn::InputLayer>(0, "input");
+ armnn::Layer* const output = graph.AddLayer<armnn::OutputLayer>(0, "output");
+
+ // Mismatching number of tensor dimensions
+ armnn::TensorInfo inputTensorInfo({1, 1, 1, 1}, InputDataType);
+ armnn::TensorInfo outputTensorInfo({1, 1}, OutputDataType);
+
+ input->GetOutputSlot(0).Connect(layer->GetInputSlot(0));
+ input->GetOutputHandler(0).SetTensorInfo(inputTensorInfo);
+ layer->GetOutputSlot(0).Connect(output->GetInputSlot(0));
+ layer->GetOutputHandler(0).SetTensorInfo(outputTensorInfo);
+
+ bool result = FactoryType::IsLayerSupported(*layer, InputDataType, reasonIfUnsupported);
+
+ return result;
+}
+
} //namespace