From 5236e1d6bcff6ebec7ec10d7d416cc6ead5482dd Mon Sep 17 00:00:00 2001 From: Keith Davis Date: Mon, 4 Nov 2019 08:58:33 +0000 Subject: IVGCVSW-3835 Create Encoder and Decoder for QSymm8PerAxis MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add QuantizedSymm8PerAxis to armnn DataType (types.hpp) and * Add Quantize and Dequantize template for int8 in TypeUtils to be able to compute QSymm8 of the weight * Create PerAxisIterator for per-axis quantization * Create QSymm8PerAxisDecoder * Create QSymm8PerAxisEncoder Signed-off-by: Keith Davis Change-Id: Ibcfe0288a197b7ee50b543bdbd77b7edb8a547c2 --- src/armnnUtils/TensorUtils.hpp | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) (limited to 'src/armnnUtils/TensorUtils.hpp') diff --git a/src/armnnUtils/TensorUtils.hpp b/src/armnnUtils/TensorUtils.hpp index 2b1f6a24f3..32af179bdc 100644 --- a/src/armnnUtils/TensorUtils.hpp +++ b/src/armnnUtils/TensorUtils.hpp @@ -7,6 +7,8 @@ #include +#include + namespace armnnUtils { armnn::TensorShape GetTensorShape(unsigned int numberOfBatches, @@ -32,4 +34,32 @@ unsigned int GetNumElementsBetween(const armnn::TensorShape& shape, unsigned int GetUnsignedAxis(const unsigned int inputDimension, const int axis); +inline unsigned int GetNumElementsAfter(const armnn::TensorShape& shape, + unsigned int axis) +{ + unsigned int numDim = shape.GetNumDimensions(); + BOOST_ASSERT(0 >= axis); + BOOST_ASSERT(axis < numDim - 1); + unsigned int count = 1; + for (unsigned int i = axis; i < numDim; i++) + { + count *= shape[i]; + } + return count; +} + +inline std::pair> GetPerAxisParams(const armnn::TensorInfo& info) +{ + const std::vector& scales = info.GetQuantizationScales(); + armnn::Optional quantizationDim = info.GetQuantizationDim(); + if (scales.size() < 1 || !quantizationDim.has_value()) + { + throw armnn::InvalidArgumentException( + "We currently support only per-axis symmetric quantization for QuantizedSymm8."); + } + unsigned int axisFactor = GetNumElementsAfter(info.GetShape(), quantizationDim.value()); + + return {axisFactor, scales}; +} + } // namespace armnnUtils -- cgit v1.2.1