diff options
Diffstat (limited to 'src/armnnUtils/TensorUtils.hpp')
-rw-r--r-- | src/armnnUtils/TensorUtils.hpp | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/src/armnnUtils/TensorUtils.hpp b/src/armnnUtils/TensorUtils.hpp index 2b1f6a24f3..32af179bdc 100644 --- a/src/armnnUtils/TensorUtils.hpp +++ b/src/armnnUtils/TensorUtils.hpp @@ -7,6 +7,8 @@ #include <armnn/TypesUtils.hpp> +#include <boost/assert.hpp> + namespace armnnUtils { armnn::TensorShape GetTensorShape(unsigned int numberOfBatches, @@ -32,4 +34,32 @@ unsigned int GetNumElementsBetween(const armnn::TensorShape& shape, unsigned int GetUnsignedAxis(const unsigned int inputDimension, const int axis); +inline unsigned int GetNumElementsAfter(const armnn::TensorShape& shape, + unsigned int axis) +{ + unsigned int numDim = shape.GetNumDimensions(); + BOOST_ASSERT(0 >= axis); + BOOST_ASSERT(axis < numDim - 1); + unsigned int count = 1; + for (unsigned int i = axis; i < numDim; i++) + { + count *= shape[i]; + } + return count; +} + +inline std::pair<unsigned int, std::vector<float>> GetPerAxisParams(const armnn::TensorInfo& info) +{ + const std::vector<float>& scales = info.GetQuantizationScales(); + armnn::Optional<unsigned int> quantizationDim = info.GetQuantizationDim(); + if (scales.size() < 1 || !quantizationDim.has_value()) + { + throw armnn::InvalidArgumentException( + "We currently support only per-axis symmetric quantization for QuantizedSymm8."); + } + unsigned int axisFactor = GetNumElementsAfter(info.GetShape(), quantizationDim.value()); + + return {axisFactor, scales}; +} + } // namespace armnnUtils |