aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils/TensorUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnUtils/TensorUtils.cpp')
-rw-r--r--src/armnnUtils/TensorUtils.cpp28
1 files changed, 28 insertions, 0 deletions
diff --git a/src/armnnUtils/TensorUtils.cpp b/src/armnnUtils/TensorUtils.cpp
index 8baea78ab5..b4e8d5acda 100644
--- a/src/armnnUtils/TensorUtils.cpp
+++ b/src/armnnUtils/TensorUtils.cpp
@@ -107,4 +107,32 @@ armnn::TensorShape ExpandDims(const armnn::TensorShape& tensorShape, int axis)
return armnn::TensorShape(outputDim, outputShape.data());
}
+unsigned int GetNumElementsBetween(const armnn::TensorShape& shape,
+ const unsigned int firstAxisInclusive,
+ const unsigned int lastAxisExclusive)
+{
+ BOOST_ASSERT(0 <= firstAxisInclusive);
+ BOOST_ASSERT(firstAxisInclusive <= lastAxisExclusive);
+ BOOST_ASSERT(lastAxisExclusive <= shape.GetNumDimensions());
+ unsigned int count = 1;
+ for (unsigned int i = firstAxisInclusive; i < lastAxisExclusive; i++)
+ {
+ count *= shape[i];
+ }
+ return count;
+}
+
+unsigned int GetUnsignedAxis(const unsigned int inputDimension, const int axis)
+{
+ BOOST_ASSERT_MSG(axis < boost::numeric_cast<int>(inputDimension),
+ "Required axis index greater than number of dimensions.");
+ BOOST_ASSERT_MSG(axis >= -boost::numeric_cast<int>(inputDimension),
+ "Required axis index lower than negative of the number of dimensions");
+
+ unsigned int uAxis = axis < 0 ?
+ inputDimension - boost::numeric_cast<unsigned int>(abs(axis))
+ : boost::numeric_cast<unsigned int>(axis);
+ return uAxis;
+}
+
}