diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 3 | ||||
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 5 | ||||
-rw-r--r-- | src/backends/reference/test/RefCreateWorkloadTests.cpp | 5 | ||||
-rw-r--r-- | src/backends/reference/test/RefLayerTests.cpp | 2 |
4 files changed, 12 insertions, 3 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 098e8575b1..a1d00c6945 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1489,7 +1489,8 @@ void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { DataType::Float16, DataType::Float32, - DataType::QuantisedAsymm8 + DataType::QuantisedAsymm8, + DataType::QuantisedSymm16 }; ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index d8c942cd96..03c8633dce 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -1136,10 +1136,11 @@ bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input, Optional<std::string&> reasonIfUnsupported) const { bool supported = true; - std::array<DataType,2> supportedTypes = + std::array<DataType,3> supportedTypes = { DataType::Float32, - DataType::QuantisedAsymm8 + DataType::QuantisedAsymm8, + DataType::QuantisedSymm16 }; supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp index 8d15530c1e..a0fc7286a9 100644 --- a/src/backends/reference/test/RefCreateWorkloadTests.cpp +++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp @@ -682,6 +682,11 @@ BOOST_AUTO_TEST_CASE(CreateRsqrtUint8) RefCreateRsqrtTest<RefRsqrtWorkload, armnn::DataType::QuantisedAsymm8>(); } +BOOST_AUTO_TEST_CASE(CreateRsqrtQsymm16) +{ + RefCreateRsqrtTest<RefRsqrtWorkload, armnn::DataType::QuantisedSymm16>(); +} + template <typename L2NormalizationWorkloadType, armnn::DataType DataType> static void RefCreateL2NormalizationTest(DataLayout dataLayout) { diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index 5bab6bb605..7ff6d1b269 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -574,6 +574,8 @@ ARMNN_AUTO_TEST_CASE(RsqrtZero, RsqrtZeroTest<armnn::DataType::Float32>) ARMNN_AUTO_TEST_CASE(RsqrtNegative, RsqrtNegativeTest<armnn::DataType::Float32>) ARMNN_AUTO_TEST_CASE(Rsqrt2dQuantisedAsymm8, Rsqrt2dTest<armnn::DataType::QuantisedAsymm8>) ARMNN_AUTO_TEST_CASE(Rsqrt3dQuantisedAsymm8, Rsqrt3dTest<armnn::DataType::QuantisedAsymm8>) +ARMNN_AUTO_TEST_CASE(Rsqrt2dQuantisedSymm16, Rsqrt2dTest<armnn::DataType::QuantisedSymm16>) +ARMNN_AUTO_TEST_CASE(Rsqrt3dQuantisedSymm16, Rsqrt3dTest<armnn::DataType::QuantisedSymm16>) // Permute ARMNN_AUTO_TEST_CASE(SimplePermuteFloat32, SimplePermuteFloat32Test) |