diff options
Diffstat (limited to 'src/backends/reference/workloads/Encoders.hpp')
-rw-r--r-- | src/backends/reference/workloads/Encoders.hpp | 24 |
1 files changed, 19 insertions, 5 deletions
diff --git a/src/backends/reference/workloads/Encoders.hpp b/src/backends/reference/workloads/Encoders.hpp index d6d611494d..8a702377b2 100644 --- a/src/backends/reference/workloads/Encoders.hpp +++ b/src/backends/reference/workloads/Encoders.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -56,10 +56,24 @@ inline std::unique_ptr<Encoder<float>> MakeEncoder(const TensorInfo& info, void* } case armnn::DataType::QSymmS16: { - return std::make_unique<QSymm16Encoder>( - static_cast<int16_t*>(data), - info.GetQuantizationScale(), - info.GetQuantizationOffset()); + if (info.HasPerAxisQuantization()) + { + unsigned int axis = info.GetQuantizationDim().value(); + auto axisDimensionality = info.GetShape()[axis]; + std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info); + return std::make_unique<QSymm16PerAxisEncoder>( + static_cast<int16_t*>(data), + params.second, + params.first, + axisDimensionality); + } + else + { + return std::make_unique<QSymm16Encoder>( + static_cast<int16_t *>(data), + info.GetQuantizationScale(), + info.GetQuantizationOffset()); + } } case armnn::DataType::Signed32: { |