aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils/TensorUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnUtils/TensorUtils.hpp')
-rw-r--r--src/armnnUtils/TensorUtils.hpp30
1 files changed, 30 insertions, 0 deletions
diff --git a/src/armnnUtils/TensorUtils.hpp b/src/armnnUtils/TensorUtils.hpp
index 2b1f6a24f3..32af179bdc 100644
--- a/src/armnnUtils/TensorUtils.hpp
+++ b/src/armnnUtils/TensorUtils.hpp
@@ -7,6 +7,8 @@
#include <armnn/TypesUtils.hpp>
+#include <boost/assert.hpp>
+
namespace armnnUtils
{
armnn::TensorShape GetTensorShape(unsigned int numberOfBatches,
@@ -32,4 +34,32 @@ unsigned int GetNumElementsBetween(const armnn::TensorShape& shape,
unsigned int GetUnsignedAxis(const unsigned int inputDimension, const int axis);
+inline unsigned int GetNumElementsAfter(const armnn::TensorShape& shape,
+ unsigned int axis)
+{
+ unsigned int numDim = shape.GetNumDimensions();
+ BOOST_ASSERT(0 >= axis);
+ BOOST_ASSERT(axis < numDim - 1);
+ unsigned int count = 1;
+ for (unsigned int i = axis; i < numDim; i++)
+ {
+ count *= shape[i];
+ }
+ return count;
+}
+
+inline std::pair<unsigned int, std::vector<float>> GetPerAxisParams(const armnn::TensorInfo& info)
+{
+ const std::vector<float>& scales = info.GetQuantizationScales();
+ armnn::Optional<unsigned int> quantizationDim = info.GetQuantizationDim();
+ if (scales.size() < 1 || !quantizationDim.has_value())
+ {
+ throw armnn::InvalidArgumentException(
+ "We currently support only per-axis symmetric quantization for QuantizedSymm8.");
+ }
+ unsigned int axisFactor = GetNumElementsAfter(info.GetShape(), quantizationDim.value());
+
+ return {axisFactor, scales};
+}
+
} // namespace armnnUtils