aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils
diff options
context:
space:
mode:
authorKeith Davis <keith.davis@arm.com>2019-11-04 08:58:33 +0000
committerKeith Davis <keith.davis@arm.com>2019-11-04 16:46:35 +0000
commit5236e1d6bcff6ebec7ec10d7d416cc6ead5482dd (patch)
tree4152c5fcd6b9c11848a02dfa4ff8705a2cfae0a5 /src/armnnUtils
parentf71079328ae72a65c91e410b2bd35eabb67cb6d1 (diff)
downloadarmnn-5236e1d6bcff6ebec7ec10d7d416cc6ead5482dd.tar.gz
IVGCVSW-3835 Create Encoder and Decoder for QSymm8PerAxis
* Add QuantizedSymm8PerAxis to armnn DataType (types.hpp) and * Add Quantize and Dequantize template for int8 in TypeUtils to be able to compute QSymm8 of the weight * Create PerAxisIterator for per-axis quantization * Create QSymm8PerAxisDecoder * Create QSymm8PerAxisEncoder Signed-off-by: Keith Davis <keith.davis@arm.com> Change-Id: Ibcfe0288a197b7ee50b543bdbd77b7edb8a547c2
Diffstat (limited to 'src/armnnUtils')
-rw-r--r--src/armnnUtils/TensorUtils.hpp30
1 files changed, 30 insertions, 0 deletions
diff --git a/src/armnnUtils/TensorUtils.hpp b/src/armnnUtils/TensorUtils.hpp
index 2b1f6a24f3..32af179bdc 100644
--- a/src/armnnUtils/TensorUtils.hpp
+++ b/src/armnnUtils/TensorUtils.hpp
@@ -7,6 +7,8 @@
#include <armnn/TypesUtils.hpp>
+#include <boost/assert.hpp>
+
namespace armnnUtils
{
armnn::TensorShape GetTensorShape(unsigned int numberOfBatches,
@@ -32,4 +34,32 @@ unsigned int GetNumElementsBetween(const armnn::TensorShape& shape,
unsigned int GetUnsignedAxis(const unsigned int inputDimension, const int axis);
+inline unsigned int GetNumElementsAfter(const armnn::TensorShape& shape,
+ unsigned int axis)
+{
+ unsigned int numDim = shape.GetNumDimensions();
+ BOOST_ASSERT(0 >= axis);
+ BOOST_ASSERT(axis < numDim - 1);
+ unsigned int count = 1;
+ for (unsigned int i = axis; i < numDim; i++)
+ {
+ count *= shape[i];
+ }
+ return count;
+}
+
+inline std::pair<unsigned int, std::vector<float>> GetPerAxisParams(const armnn::TensorInfo& info)
+{
+ const std::vector<float>& scales = info.GetQuantizationScales();
+ armnn::Optional<unsigned int> quantizationDim = info.GetQuantizationDim();
+ if (scales.size() < 1 || !quantizationDim.has_value())
+ {
+ throw armnn::InvalidArgumentException(
+ "We currently support only per-axis symmetric quantization for QuantizedSymm8.");
+ }
+ unsigned int axisFactor = GetNumElementsAfter(info.GetShape(), quantizationDim.value());
+
+ return {axisFactor, scales};
+}
+
} // namespace armnnUtils