aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/Encoders.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/Encoders.hpp')
-rw-r--r--src/backends/reference/workloads/Encoders.hpp24
1 files changed, 19 insertions, 5 deletions
diff --git a/src/backends/reference/workloads/Encoders.hpp b/src/backends/reference/workloads/Encoders.hpp
index d6d611494d..8a702377b2 100644
--- a/src/backends/reference/workloads/Encoders.hpp
+++ b/src/backends/reference/workloads/Encoders.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -56,10 +56,24 @@ inline std::unique_ptr<Encoder<float>> MakeEncoder(const TensorInfo& info, void*
}
case armnn::DataType::QSymmS16:
{
- return std::make_unique<QSymm16Encoder>(
- static_cast<int16_t*>(data),
- info.GetQuantizationScale(),
- info.GetQuantizationOffset());
+ if (info.HasPerAxisQuantization())
+ {
+ unsigned int axis = info.GetQuantizationDim().value();
+ auto axisDimensionality = info.GetShape()[axis];
+ std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info);
+ return std::make_unique<QSymm16PerAxisEncoder>(
+ static_cast<int16_t*>(data),
+ params.second,
+ params.first,
+ axisDimensionality);
+ }
+ else
+ {
+ return std::make_unique<QSymm16Encoder>(
+ static_cast<int16_t *>(data),
+ info.GetQuantizationScale(),
+ info.GetQuantizationOffset());
+ }
}
case armnn::DataType::Signed32:
{