aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp3
-rw-r--r--src/backends/reference/RefLayerSupport.cpp5
-rw-r--r--src/backends/reference/test/RefCreateWorkloadTests.cpp10
3 files changed, 15 insertions, 3 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 9482136b59..e14283d823 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -614,7 +614,8 @@ void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co
{
DataType::Float16,
DataType::Float32,
- DataType::QuantisedAsymm8
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
};
ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 60536081be..39f521c120 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -962,11 +962,12 @@ bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
ignore_unused(descriptor);
// Define supported types
- std::array<DataType, 3> supportedTypes =
+ std::array<DataType, 4> supportedTypes =
{
DataType::Float16,
DataType::Float32,
- DataType::QuantisedAsymm8
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
};
bool supported = true;
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 3da9de9263..d550f00f15 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -392,6 +392,16 @@ BOOST_AUTO_TEST_CASE(CreateRefNormalizationUint8NhwcWorkload)
RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::QuantisedAsymm8>(DataLayout::NHWC);
}
+BOOST_AUTO_TEST_CASE(CreateRefNormalizationInt16NchwWorkload)
+{
+ RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::QuantisedSymm16>(DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreateRefNormalizationInt16NhwcWorkload)
+{
+ RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::QuantisedSymm16>(DataLayout::NHWC);
+}
+
template <typename Pooling2dWorkloadType, armnn::DataType DataType>
static void RefCreatePooling2dWorkloadTest(DataLayout dataLayout)
{