diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 3 | ||||
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 5 | ||||
-rw-r--r-- | src/backends/reference/test/RefCreateWorkloadTests.cpp | 10 |
3 files changed, 15 insertions, 3 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 9482136b59..e14283d823 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -614,7 +614,8 @@ void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co { DataType::Float16, DataType::Float32, - DataType::QuantisedAsymm8 + DataType::QuantisedAsymm8, + DataType::QuantisedSymm16 }; ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 60536081be..39f521c120 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -962,11 +962,12 @@ bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input, ignore_unused(descriptor); // Define supported types - std::array<DataType, 3> supportedTypes = + std::array<DataType, 4> supportedTypes = { DataType::Float16, DataType::Float32, - DataType::QuantisedAsymm8 + DataType::QuantisedAsymm8, + DataType::QuantisedSymm16 }; bool supported = true; diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp index 3da9de9263..d550f00f15 100644 --- a/src/backends/reference/test/RefCreateWorkloadTests.cpp +++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp @@ -392,6 +392,16 @@ BOOST_AUTO_TEST_CASE(CreateRefNormalizationUint8NhwcWorkload) RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::QuantisedAsymm8>(DataLayout::NHWC); } +BOOST_AUTO_TEST_CASE(CreateRefNormalizationInt16NchwWorkload) +{ + RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::QuantisedSymm16>(DataLayout::NCHW); +} + +BOOST_AUTO_TEST_CASE(CreateRefNormalizationInt16NhwcWorkload) +{ + RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::QuantisedSymm16>(DataLayout::NHWC); +} + template <typename Pooling2dWorkloadType, armnn::DataType DataType> static void RefCreatePooling2dWorkloadTest(DataLayout dataLayout) { |