aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/test
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/test')
-rw-r--r--src/backends/reference/test/RefLayerSupportTests.cpp25
-rw-r--r--src/backends/reference/test/RefLayerTests.cpp28
2 files changed, 40 insertions, 13 deletions
diff --git a/src/backends/reference/test/RefLayerSupportTests.cpp b/src/backends/reference/test/RefLayerSupportTests.cpp
index 2c7e17da43..0d99b3e66f 100644
--- a/src/backends/reference/test/RefLayerSupportTests.cpp
+++ b/src/backends/reference/test/RefLayerSupportTests.cpp
@@ -14,6 +14,7 @@
#include <backendsCommon/test/IsLayerSupportedTestImpl.hpp>
#include <boost/test/unit_test.hpp>
+#include <boost/algorithm/string/trim.hpp>
#include <string>
@@ -130,4 +131,28 @@ BOOST_AUTO_TEST_CASE(IsConvertFp32ToFp16SupportedFp32OutputReference)
BOOST_CHECK_EQUAL(reasonIfUnsupported, "Layer is not supported with float32 data type output");
}
+BOOST_AUTO_TEST_CASE(IsLayerSupportedMeanDimensionsReference)
+{
+ std::string reasonIfUnsupported;
+
+ bool result = IsMeanLayerSupportedTests<armnn::RefWorkloadFactory,
+ armnn::DataType::Float32, armnn::DataType::Float32>(reasonIfUnsupported);
+
+ BOOST_CHECK(result);
+}
+
+BOOST_AUTO_TEST_CASE(IsLayerNotSupportedMeanDimensionsReference)
+{
+ std::string reasonIfUnsupported;
+
+ bool result = IsMeanLayerNotSupportedTests<armnn::RefWorkloadFactory,
+ armnn::DataType::Float32, armnn::DataType::Float32>(reasonIfUnsupported);
+
+ BOOST_CHECK(!result);
+
+ boost::algorithm::trim(reasonIfUnsupported);
+ BOOST_CHECK_EQUAL(reasonIfUnsupported,
+ "Reference Mean: Expected 4 dimensions but got 2 dimensions instead, for the 'output' tensor.");
+}
+
BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 7ff6d1b269..c2cda8ec6b 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -607,19 +607,21 @@ ARMNN_AUTO_TEST_CASE(SimpleConvertFp16ToFp32, SimpleConvertFp16ToFp32Test)
ARMNN_AUTO_TEST_CASE(SimpleConvertFp32ToFp16, SimpleConvertFp32ToFp16Test)
// Mean
-ARMNN_AUTO_TEST_CASE(MeanUint8Simple, MeanUint8SimpleTest)
-ARMNN_AUTO_TEST_CASE(MeanUint8SimpleAxis, MeanUint8SimpleAxisTest)
-ARMNN_AUTO_TEST_CASE(MeanUint8KeepDims, MeanUint8KeepDimsTest)
-ARMNN_AUTO_TEST_CASE(MeanUint8MultipleDims, MeanUint8MultipleDimsTest)
-ARMNN_AUTO_TEST_CASE(MeanVtsUint8, MeanVtsUint8Test)
-
-ARMNN_AUTO_TEST_CASE(MeanFloatSimple, MeanFloatSimpleTest)
-ARMNN_AUTO_TEST_CASE(MeanFloatSimpleAxis, MeanFloatSimpleAxisTest)
-ARMNN_AUTO_TEST_CASE(MeanFloatKeepDims, MeanFloatKeepDimsTest)
-ARMNN_AUTO_TEST_CASE(MeanFloatMultipleDims, MeanFloatMultipleDimsTest)
-ARMNN_AUTO_TEST_CASE(MeanVtsFloat1, MeanVtsFloat1Test)
-ARMNN_AUTO_TEST_CASE(MeanVtsFloat2, MeanVtsFloat2Test)
-ARMNN_AUTO_TEST_CASE(MeanVtsFloat3, MeanVtsFloat3Test)
+ARMNN_AUTO_TEST_CASE(MeanSimpleFloat32, MeanSimpleTest<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(MeanSimpleAxisFloat32, MeanSimpleAxisTest<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(MeanKeepDimsFloat32, MeanKeepDimsTest<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(MeanMultipleDimsFloat32, MeanMultipleDimsTest<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(MeanVts1Float32, MeanVts1Test<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(MeanVts2Float32, MeanVts2Test<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(MeanVts3Float32, MeanVts3Test<armnn::DataType::Float32>)
+
+ARMNN_AUTO_TEST_CASE(MeanSimpleQuantisedAsymm8, MeanSimpleTest<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(MeanSimpleAxisQuantisedAsymm8, MeanSimpleAxisTest<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(MeanKeepDimsQuantisedAsymm8, MeanKeepDimsTest<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(MeanMultipleDimsQuantisedAsymm8, MeanMultipleDimsTest<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(MeanVts1QuantisedAsymm8, MeanVts1Test<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(MeanVts2QuantisedAsymm8, MeanVts2Test<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(MeanVts3QuantisedAsymm8, MeanVts3Test<armnn::DataType::QuantisedAsymm8>)
ARMNN_AUTO_TEST_CASE(AdditionAfterMaxPool, AdditionAfterMaxPoolTest)