diff options
Diffstat (limited to 'src/backends/reference')
-rw-r--r-- | src/backends/reference/workloads/BaseIterator.hpp | 29 | ||||
-rw-r--r-- | src/backends/reference/workloads/Encoders.hpp | 24 |
2 files changed, 47 insertions, 6 deletions
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp index 2d27951b73..1665c1ff46 100644 --- a/src/backends/reference/workloads/BaseIterator.hpp +++ b/src/backends/reference/workloads/BaseIterator.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -896,4 +896,31 @@ private: std::vector<float> m_Scales; }; +class QSymm16PerAxisEncoder : public PerAxisIterator<int16_t, Encoder<float>> +{ +public: + QSymm16PerAxisEncoder(int16_t* data, const std::vector<float>& scale, + unsigned int axisFactor, unsigned int axisDimensionality) + : PerAxisIterator(data, axisFactor, axisDimensionality), m_Scale(scale) {} + + void Set(float right) + { + *m_Iterator = armnn::Quantize<int16_t>(right, m_Scale[m_AxisIndex], 0); + } + + float Get() const + { + return armnn::Dequantize(*m_Iterator, m_Scale[m_AxisIndex], 0); + } + + // Get scale of the current value + float GetScale() const + { + return m_Scale[m_AxisIndex]; + } + +private: + std::vector<float> m_Scale; +}; + } // namespace armnn diff --git a/src/backends/reference/workloads/Encoders.hpp b/src/backends/reference/workloads/Encoders.hpp index d6d611494d..8a702377b2 100644 --- a/src/backends/reference/workloads/Encoders.hpp +++ b/src/backends/reference/workloads/Encoders.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -56,10 +56,24 @@ inline std::unique_ptr<Encoder<float>> MakeEncoder(const TensorInfo& info, void* } case armnn::DataType::QSymmS16: { - return std::make_unique<QSymm16Encoder>( - static_cast<int16_t*>(data), - info.GetQuantizationScale(), - info.GetQuantizationOffset()); + if (info.HasPerAxisQuantization()) + { + unsigned int axis = info.GetQuantizationDim().value(); + auto axisDimensionality = info.GetShape()[axis]; + std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info); + return std::make_unique<QSymm16PerAxisEncoder>( + static_cast<int16_t*>(data), + params.second, + params.first, + axisDimensionality); + } + else + { + return std::make_unique<QSymm16Encoder>( + static_cast<int16_t *>(data), + info.GetQuantizationScale(), + info.GetQuantizationOffset()); + } } case armnn::DataType::Signed32: { |