aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnUtils')
-rw-r--r--src/armnnUtils/TensorUtils.cpp28
-rw-r--r--src/armnnUtils/TensorUtils.hpp6
2 files changed, 34 insertions, 0 deletions
diff --git a/src/armnnUtils/TensorUtils.cpp b/src/armnnUtils/TensorUtils.cpp
index 8baea78ab5..b4e8d5acda 100644
--- a/src/armnnUtils/TensorUtils.cpp
+++ b/src/armnnUtils/TensorUtils.cpp
@@ -107,4 +107,32 @@ armnn::TensorShape ExpandDims(const armnn::TensorShape& tensorShape, int axis)
return armnn::TensorShape(outputDim, outputShape.data());
}
+unsigned int GetNumElementsBetween(const armnn::TensorShape& shape,
+ const unsigned int firstAxisInclusive,
+ const unsigned int lastAxisExclusive)
+{
+ BOOST_ASSERT(0 <= firstAxisInclusive);
+ BOOST_ASSERT(firstAxisInclusive <= lastAxisExclusive);
+ BOOST_ASSERT(lastAxisExclusive <= shape.GetNumDimensions());
+ unsigned int count = 1;
+ for (unsigned int i = firstAxisInclusive; i < lastAxisExclusive; i++)
+ {
+ count *= shape[i];
+ }
+ return count;
+}
+
+unsigned int GetUnsignedAxis(const unsigned int inputDimension, const int axis)
+{
+ BOOST_ASSERT_MSG(axis < boost::numeric_cast<int>(inputDimension),
+ "Required axis index greater than number of dimensions.");
+ BOOST_ASSERT_MSG(axis >= -boost::numeric_cast<int>(inputDimension),
+ "Required axis index lower than negative of the number of dimensions");
+
+ unsigned int uAxis = axis < 0 ?
+ inputDimension - boost::numeric_cast<unsigned int>(abs(axis))
+ : boost::numeric_cast<unsigned int>(axis);
+ return uAxis;
+}
+
}
diff --git a/src/armnnUtils/TensorUtils.hpp b/src/armnnUtils/TensorUtils.hpp
index 03b1c8a2df..2b1f6a24f3 100644
--- a/src/armnnUtils/TensorUtils.hpp
+++ b/src/armnnUtils/TensorUtils.hpp
@@ -26,4 +26,10 @@ std::pair<float, float> FindMinMax(armnn::ITensorHandle* tensorHandle);
armnn::TensorShape ExpandDims(const armnn::TensorShape& tensorShape, int axis);
+unsigned int GetNumElementsBetween(const armnn::TensorShape& shape,
+ unsigned int firstAxisInclusive,
+ unsigned int lastAxisExclusive);
+
+unsigned int GetUnsignedAxis(const unsigned int inputDimension, const int axis);
+
} // namespace armnnUtils