diff options
Diffstat (limited to 'src/backends/reference')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 5 | ||||
-rw-r--r-- | src/backends/reference/test/RefCreateWorkloadTests.cpp | 12 | ||||
-rw-r--r-- | src/backends/reference/test/RefLayerTests.cpp | 2 |
3 files changed, 17 insertions, 2 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index aeff51d853..adc63e965b 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -288,10 +288,11 @@ bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input, { ignore_unused(descriptor); - std::array<DataType, 2> supportedTypes = + std::array<DataType, 3> supportedTypes = { DataType::Float32, - DataType::QuantisedAsymm8 + DataType::QuantisedAsymm8, + DataType::QuantisedSymm16 }; bool supported = true; diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp index a0c614564d..83e3f6c7bb 100644 --- a/src/backends/reference/test/RefCreateWorkloadTests.cpp +++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp @@ -229,6 +229,18 @@ BOOST_AUTO_TEST_CASE(CreateBatchNormalizationUint8WorkloadNhwc) (DataLayout::NHWC); } +BOOST_AUTO_TEST_CASE(CreateBatchNormalizationInt16Workload) +{ + RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::QuantisedSymm16> + (DataLayout::NCHW); +} + +BOOST_AUTO_TEST_CASE(CreateBatchNormalizationInt16WorkloadNhwc) +{ + RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::QuantisedSymm16> + (DataLayout::NHWC); +} + BOOST_AUTO_TEST_CASE(CreateConvertFp16ToFp32Float32Workload) { Graph graph; diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index 162027032e..afeadb9485 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -357,6 +357,8 @@ ARMNN_AUTO_TEST_CASE(BatchNorm, BatchNormTest) ARMNN_AUTO_TEST_CASE(BatchNormNhwc, BatchNormNhwcTest) ARMNN_AUTO_TEST_CASE(BatchNormUint8, BatchNormUint8Test) ARMNN_AUTO_TEST_CASE(BatchNormUint8Nhwc, BatchNormUint8NhwcTest) +ARMNN_AUTO_TEST_CASE(BatchNormInt16, BatchNormInt16Test) +ARMNN_AUTO_TEST_CASE(BatchNormInt16Nhwc, BatchNormInt16NhwcTest) // Resize Bilinear - NCHW ARMNN_AUTO_TEST_CASE(SimpleResizeBilinear, SimpleResizeBilinearTest, armnn::DataLayout::NCHW) |