aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJames Conroy <james.conroy@arm.com>2019-06-11 11:25:30 +0100
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-06-17 11:17:15 +0000
commitb80775f7d19b8535383f96a00cde85feec338741 (patch)
treeb9c4466dd52a6613257410f48096e1fee7916c16
parent4d1ff588288b1a7a98dd2fd7f2ba5717b8ecf102 (diff)
downloadarmnn-b80775f7d19b8535383f96a00cde85feec338741.tar.gz
IVGCVSW-3222 Extend Mean ref workload to support QSymm16
* Added support for QSymm16 in Mean ref workload * Added unit tests for QSymm16 Mean Signed-off-by: James Conroy <james.conroy@arm.com> Change-Id: I600b15069ff4a4531666c6bc7fb73187dcebf0ee
-rw-r--r--src/backends/reference/RefLayerSupport.cpp5
-rw-r--r--src/backends/reference/test/RefLayerTests.cpp8
2 files changed, 11 insertions, 2 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 402bd66f02..a25338f906 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -902,10 +902,11 @@ bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
std::string meanLayerStr = "Mean";
std::string outputTensorStr = "output";
- std::array<DataType,2> supportedTypes =
+ std::array<DataType,3> supportedTypes =
{
DataType::Float32,
- DataType::QuantisedAsymm8
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
};
supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index c2cda8ec6b..155da246bd 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -623,6 +623,14 @@ ARMNN_AUTO_TEST_CASE(MeanVts1QuantisedAsymm8, MeanVts1Test<armnn::DataType::Quan
ARMNN_AUTO_TEST_CASE(MeanVts2QuantisedAsymm8, MeanVts2Test<armnn::DataType::QuantisedAsymm8>)
ARMNN_AUTO_TEST_CASE(MeanVts3QuantisedAsymm8, MeanVts3Test<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(MeanSimpleQuantisedSymm16, MeanSimpleTest<armnn::DataType::QuantisedSymm16>)
+ARMNN_AUTO_TEST_CASE(MeanSimpleAxisQuantisedSymm16, MeanSimpleAxisTest<armnn::DataType::QuantisedSymm16>)
+ARMNN_AUTO_TEST_CASE(MeanKeepDimsQuantisedSymm16, MeanKeepDimsTest<armnn::DataType::QuantisedSymm16>)
+ARMNN_AUTO_TEST_CASE(MeanMultipleDimsQuantisedSymm16, MeanMultipleDimsTest<armnn::DataType::QuantisedSymm16>)
+ARMNN_AUTO_TEST_CASE(MeanVts1QuantisedSymm16, MeanVts1Test<armnn::DataType::QuantisedSymm16>)
+ARMNN_AUTO_TEST_CASE(MeanVts2QuantisedSymm16, MeanVts2Test<armnn::DataType::QuantisedSymm16>)
+ARMNN_AUTO_TEST_CASE(MeanVts3QuantisedSymm16, MeanVts3Test<armnn::DataType::QuantisedSymm16>)
+
ARMNN_AUTO_TEST_CASE(AdditionAfterMaxPool, AdditionAfterMaxPoolTest)
// Space To Batch Nd