From 24d7321ad7897e8836d4a38039a73a0ec419cf43 Mon Sep 17 00:00:00 2001 From: nikraj01 Date: Fri, 14 Jun 2019 14:20:40 +0100 Subject: IVGCVSW-3225 Add QSymm16 support for Rsqrt workload Change-Id: I83b8494af24ff271dc4cd609944a1c5c55c405e0 Signed-off-by: nikraj01 --- src/backends/backendsCommon/WorkloadData.cpp | 3 ++- src/backends/reference/RefLayerSupport.cpp | 5 +++-- src/backends/reference/test/RefCreateWorkloadTests.cpp | 5 +++++ src/backends/reference/test/RefLayerTests.cpp | 2 ++ 4 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 098e8575b1..a1d00c6945 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1489,7 +1489,8 @@ void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { DataType::Float16, DataType::Float32, - DataType::QuantisedAsymm8 + DataType::QuantisedAsymm8, + DataType::QuantisedSymm16 }; ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index d8c942cd96..03c8633dce 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -1136,10 +1136,11 @@ bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input, Optional reasonIfUnsupported) const { bool supported = true; - std::array supportedTypes = + std::array supportedTypes = { DataType::Float32, - DataType::QuantisedAsymm8 + DataType::QuantisedAsymm8, + DataType::QuantisedSymm16 }; supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp index 8d15530c1e..a0fc7286a9 100644 --- a/src/backends/reference/test/RefCreateWorkloadTests.cpp +++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp @@ -682,6 +682,11 @@ BOOST_AUTO_TEST_CASE(CreateRsqrtUint8) RefCreateRsqrtTest(); } +BOOST_AUTO_TEST_CASE(CreateRsqrtQsymm16) +{ + RefCreateRsqrtTest(); +} + template static void RefCreateL2NormalizationTest(DataLayout dataLayout) { diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index 5bab6bb605..7ff6d1b269 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -574,6 +574,8 @@ ARMNN_AUTO_TEST_CASE(RsqrtZero, RsqrtZeroTest) ARMNN_AUTO_TEST_CASE(RsqrtNegative, RsqrtNegativeTest) ARMNN_AUTO_TEST_CASE(Rsqrt2dQuantisedAsymm8, Rsqrt2dTest) ARMNN_AUTO_TEST_CASE(Rsqrt3dQuantisedAsymm8, Rsqrt3dTest) +ARMNN_AUTO_TEST_CASE(Rsqrt2dQuantisedSymm16, Rsqrt2dTest) +ARMNN_AUTO_TEST_CASE(Rsqrt3dQuantisedSymm16, Rsqrt3dTest) // Permute ARMNN_AUTO_TEST_CASE(SimplePermuteFloat32, SimplePermuteFloat32Test) -- cgit v1.2.1