diff options
author | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-06-04 10:59:47 +0100 |
---|---|---|
committer | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2019-06-04 15:20:45 +0000 |
commit | f550713476f404a82e59bd68223a8a4955e753f2 (patch) | |
tree | 762b3eae41d412c1da7b45c9ead1ca2547b66f12 /src/backends/reference | |
parent | 3122bd574a3d29774c535ca2136de361da626e88 (diff) | |
download | armnn-f550713476f404a82e59bd68223a8a4955e753f2.tar.gz |
IVGCVSW-3213 Extend the Reference BatchNormalization workload to
support the new QSymm16 type
* Added QSymm16 to the range of supported types for batch
normalization ref workloads
* Added unit tests for QSymm16
Change-Id: I5b2fcfbd9cb5af149ebfe24e2d95f3affa2e3690
Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
Diffstat (limited to 'src/backends/reference')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 5 | ||||
-rw-r--r-- | src/backends/reference/test/RefCreateWorkloadTests.cpp | 12 | ||||
-rw-r--r-- | src/backends/reference/test/RefLayerTests.cpp | 2 |
3 files changed, 17 insertions, 2 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index aeff51d853..adc63e965b 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -288,10 +288,11 @@ bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input, { ignore_unused(descriptor); - std::array<DataType, 2> supportedTypes = + std::array<DataType, 3> supportedTypes = { DataType::Float32, - DataType::QuantisedAsymm8 + DataType::QuantisedAsymm8, + DataType::QuantisedSymm16 }; bool supported = true; diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp index a0c614564d..83e3f6c7bb 100644 --- a/src/backends/reference/test/RefCreateWorkloadTests.cpp +++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp @@ -229,6 +229,18 @@ BOOST_AUTO_TEST_CASE(CreateBatchNormalizationUint8WorkloadNhwc) (DataLayout::NHWC); } +BOOST_AUTO_TEST_CASE(CreateBatchNormalizationInt16Workload) +{ + RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::QuantisedSymm16> + (DataLayout::NCHW); +} + +BOOST_AUTO_TEST_CASE(CreateBatchNormalizationInt16WorkloadNhwc) +{ + RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::QuantisedSymm16> + (DataLayout::NHWC); +} + BOOST_AUTO_TEST_CASE(CreateConvertFp16ToFp32Float32Workload) { Graph graph; diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index 162027032e..afeadb9485 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -357,6 +357,8 @@ ARMNN_AUTO_TEST_CASE(BatchNorm, BatchNormTest) ARMNN_AUTO_TEST_CASE(BatchNormNhwc, BatchNormNhwcTest) ARMNN_AUTO_TEST_CASE(BatchNormUint8, BatchNormUint8Test) ARMNN_AUTO_TEST_CASE(BatchNormUint8Nhwc, BatchNormUint8NhwcTest) +ARMNN_AUTO_TEST_CASE(BatchNormInt16, BatchNormInt16Test) +ARMNN_AUTO_TEST_CASE(BatchNormInt16Nhwc, BatchNormInt16NhwcTest) // Resize Bilinear - NCHW ARMNN_AUTO_TEST_CASE(SimpleResizeBilinear, SimpleResizeBilinearTest, armnn::DataLayout::NCHW) |