diff options
author | Orlaith Monahan <orlaith.monahan@arm.com> | 2024-05-09 10:48:00 +0100 |
---|---|---|
committer | Orlaith Monahan <orlaith.monahan@arm.com> | 2024-05-09 10:48:00 +0100 |
commit | 76556c21ba546f136026f02a30f5d809fe86e51f (patch) | |
tree | 934a4e64c430686158a800acc80b88ea7ac98b3a | |
parent | 21bda1405d2cb49fc873583b41a48836b33d285e (diff) | |
download | armnn-76556c21ba546f136026f02a30f5d809fe86e51f.tar.gz |
IVGCVSW-8300 Fix for CTS Float16 tests
* Fix for Neon IsLayerSupported to properly check for multiple Quantization Scales
Signed-off-by: Orlaith Monahan <orlaith.monahan@arm.com>
Change-Id: I9f4558cbd62ce1657adb5025ac16c2b5d69d12b1
-rw-r--r-- | src/armnn/Tensor.cpp | 15 | ||||
-rw-r--r-- | src/backends/neon/NeonLayerSupport.cpp | 21 |
2 files changed, 22 insertions, 14 deletions
diff --git a/src/armnn/Tensor.cpp b/src/armnn/Tensor.cpp index f75fc60ef7..650b93835f 100644 --- a/src/armnn/Tensor.cpp +++ b/src/armnn/Tensor.cpp @@ -465,15 +465,12 @@ float TensorInfo::GetQuantizationScale() const // NOTE: old default for backward compatibility return 1.0f; } - // If this tensor includes multiples scales then you should be calling GetQuantizationScales. - // This should be an exception not an assert but unfortunately it breaks many tests. - // ToDo: IVGCVSW-8323 - ARMNN_ASSERT(!HasMultipleQuantizationScales()); -// if (HasMultipleQuantizationScales()) -// { -// throw RuntimeException("Invalid call to GetQuantizationScale on a tensor with multiple scale values. Use " -// "GetQuantizationScales instead."); -// } + + if (HasMultipleQuantizationScales()) + { + throw RuntimeException("Invalid call to GetQuantizationScale on a tensor with multiple scale values. Use " + "GetQuantizationScales instead."); + } return m_Quantization.m_Scales[0]; } diff --git a/src/backends/neon/NeonLayerSupport.cpp b/src/backends/neon/NeonLayerSupport.cpp index 0298c7c552..b6db52342e 100644 --- a/src/backends/neon/NeonLayerSupport.cpp +++ b/src/backends/neon/NeonLayerSupport.cpp @@ -102,11 +102,22 @@ const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> typ { return info; } - return TensorInfo(info.GetShape(), - type.value(), - info.GetQuantizationScale(), - info.GetQuantizationOffset(), - info.IsConstant()); + if (info.HasMultipleQuantizationScales()) + { + return TensorInfo(info.GetShape(), + type.value(), + info.GetQuantizationScales(), + info.GetQuantizationDim().value(), + info.IsConstant()); + } + else + { + return TensorInfo(info.GetShape(), + type.value(), + info.GetQuantizationScale(), + info.GetQuantizationOffset(), + info.IsConstant()); + } } template< typename ... Args> |