aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/BaseIterator.hpp
diff options
context:
space:
mode:
authorJohn Mcloughlin <john.mcloughlin@arm.com>2023-10-16 10:28:40 +0100
committerjohn.mcloughlin <john.mcloughlin@arm.com>2023-10-16 16:16:59 +0000
commitb41793a9f9afc43fb04a991ca819818fca8faab8 (patch)
tree4206314ed348eeeeebc3f7747712b14f1d26e90d /src/backends/reference/workloads/BaseIterator.hpp
parent363b572b61f7a32e92cde51478d7556ce43db56f (diff)
downloadarmnn-b41793a9f9afc43fb04a991ca819818fca8faab8.tar.gz
IVGCVSW-7752 DTS: Fix QuantizePerChannel tests
* Added validation for scale on all Quantized types * Added Encoder for Per Axis UINT16 Symmetrical Quantized type * Added error for Per Axis Asymmetrical Quantized type not supported Signed-off-by: John Mcloughlin <john.mcloughlin@arm.com> Change-Id: I433519ccacd71219a92bde2b81955d6abf9219c5
Diffstat (limited to 'src/backends/reference/workloads/BaseIterator.hpp')
-rw-r--r--src/backends/reference/workloads/BaseIterator.hpp29
1 files changed, 28 insertions, 1 deletions
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp
index 2d27951b73..1665c1ff46 100644
--- a/src/backends/reference/workloads/BaseIterator.hpp
+++ b/src/backends/reference/workloads/BaseIterator.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -896,4 +896,31 @@ private:
std::vector<float> m_Scales;
};
+class QSymm16PerAxisEncoder : public PerAxisIterator<int16_t, Encoder<float>>
+{
+public:
+ QSymm16PerAxisEncoder(int16_t* data, const std::vector<float>& scale,
+ unsigned int axisFactor, unsigned int axisDimensionality)
+ : PerAxisIterator(data, axisFactor, axisDimensionality), m_Scale(scale) {}
+
+ void Set(float right)
+ {
+ *m_Iterator = armnn::Quantize<int16_t>(right, m_Scale[m_AxisIndex], 0);
+ }
+
+ float Get() const
+ {
+ return armnn::Dequantize(*m_Iterator, m_Scale[m_AxisIndex], 0);
+ }
+
+ // Get scale of the current value
+ float GetScale() const
+ {
+ return m_Scale[m_AxisIndex];
+ }
+
+private:
+ std::vector<float> m_Scale;
+};
+
} // namespace armnn