aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornikraj01 <nikhil.raj@arm.com>2019-06-14 14:20:40 +0100
committernikraj01 <nikhil.raj@arm.com>2019-06-14 14:20:40 +0100
commit24d7321ad7897e8836d4a38039a73a0ec419cf43 (patch)
tree9401dcc97d4843e66473b8fbf8d07c40c561750e
parentc6138d8a8af334fad5230d73e456f303f9665bae (diff)
downloadarmnn-24d7321ad7897e8836d4a38039a73a0ec419cf43.tar.gz
IVGCVSW-3225 Add QSymm16 support for Rsqrt workload
Change-Id: I83b8494af24ff271dc4cd609944a1c5c55c405e0 Signed-off-by: nikraj01 <nikhil.raj@arm.com>
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp3
-rw-r--r--src/backends/reference/RefLayerSupport.cpp5
-rw-r--r--src/backends/reference/test/RefCreateWorkloadTests.cpp5
-rw-r--r--src/backends/reference/test/RefLayerTests.cpp2
4 files changed, 12 insertions, 3 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 098e8575b1..a1d00c6945 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1489,7 +1489,8 @@ void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
DataType::Float16,
DataType::Float32,
- DataType::QuantisedAsymm8
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
};
ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index d8c942cd96..03c8633dce 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -1136,10 +1136,11 @@ bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
Optional<std::string&> reasonIfUnsupported) const
{
bool supported = true;
- std::array<DataType,2> supportedTypes =
+ std::array<DataType,3> supportedTypes =
{
DataType::Float32,
- DataType::QuantisedAsymm8
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
};
supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 8d15530c1e..a0fc7286a9 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -682,6 +682,11 @@ BOOST_AUTO_TEST_CASE(CreateRsqrtUint8)
RefCreateRsqrtTest<RefRsqrtWorkload, armnn::DataType::QuantisedAsymm8>();
}
+BOOST_AUTO_TEST_CASE(CreateRsqrtQsymm16)
+{
+ RefCreateRsqrtTest<RefRsqrtWorkload, armnn::DataType::QuantisedSymm16>();
+}
+
template <typename L2NormalizationWorkloadType, armnn::DataType DataType>
static void RefCreateL2NormalizationTest(DataLayout dataLayout)
{
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 5bab6bb605..7ff6d1b269 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -574,6 +574,8 @@ ARMNN_AUTO_TEST_CASE(RsqrtZero, RsqrtZeroTest<armnn::DataType::Float32>)
ARMNN_AUTO_TEST_CASE(RsqrtNegative, RsqrtNegativeTest<armnn::DataType::Float32>)
ARMNN_AUTO_TEST_CASE(Rsqrt2dQuantisedAsymm8, Rsqrt2dTest<armnn::DataType::QuantisedAsymm8>)
ARMNN_AUTO_TEST_CASE(Rsqrt3dQuantisedAsymm8, Rsqrt3dTest<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(Rsqrt2dQuantisedSymm16, Rsqrt2dTest<armnn::DataType::QuantisedSymm16>)
+ARMNN_AUTO_TEST_CASE(Rsqrt3dQuantisedSymm16, Rsqrt3dTest<armnn::DataType::QuantisedSymm16>)
// Permute
ARMNN_AUTO_TEST_CASE(SimplePermuteFloat32, SimplePermuteFloat32Test)