From 6aeb771e854ed45f0392d6c17000a6a039b6256a Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Wed, 5 Jun 2019 17:23:29 +0100 Subject: IVGCVSW-3227 Extend the reference normalization workload to support QSymm16 * Added support for QSymm16 * Added unit tests Change-Id: I7ba57793830bed7958ac9a94e9ac39d6dbe708b5 Signed-off-by: Matteo Martincigh --- src/backends/backendsCommon/WorkloadData.cpp | 3 ++- src/backends/reference/RefLayerSupport.cpp | 5 +++-- src/backends/reference/test/RefCreateWorkloadTests.cpp | 10 ++++++++++ 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 9482136b59..e14283d823 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -614,7 +614,8 @@ void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co { DataType::Float16, DataType::Float32, - DataType::QuantisedAsymm8 + DataType::QuantisedAsymm8, + DataType::QuantisedSymm16 }; ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 60536081be..39f521c120 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -962,11 +962,12 @@ bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input, ignore_unused(descriptor); // Define supported types - std::array supportedTypes = + std::array supportedTypes = { DataType::Float16, DataType::Float32, - DataType::QuantisedAsymm8 + DataType::QuantisedAsymm8, + DataType::QuantisedSymm16 }; bool supported = true; diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp index 3da9de9263..d550f00f15 100644 --- a/src/backends/reference/test/RefCreateWorkloadTests.cpp +++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp @@ -392,6 +392,16 @@ BOOST_AUTO_TEST_CASE(CreateRefNormalizationUint8NhwcWorkload) RefCreateNormalizationWorkloadTest(DataLayout::NHWC); } +BOOST_AUTO_TEST_CASE(CreateRefNormalizationInt16NchwWorkload) +{ + RefCreateNormalizationWorkloadTest(DataLayout::NCHW); +} + +BOOST_AUTO_TEST_CASE(CreateRefNormalizationInt16NhwcWorkload) +{ + RefCreateNormalizationWorkloadTest(DataLayout::NHWC); +} + template static void RefCreatePooling2dWorkloadTest(DataLayout dataLayout) { -- cgit v1.2.1