aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils/TensorUtils.cpp
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2019-09-16 17:00:22 +0100
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2019-09-16 17:00:54 +0100
commit4dc64a69ba383ece509d442598617445a3b4847f (patch)
treeb50cb259594aa0cf634a4c37657a2c7a50be0c6c /src/armnnUtils/TensorUtils.cpp
parenta0c7871cf140d1e9cf59a213626ee534c0122c7f (diff)
downloadarmnn-4dc64a69ba383ece509d442598617445a3b4847f.tar.gz
IVGCVSW-3694 Add ArgMinMax implementation for Ref
* Add ArgMinMax implementation * Add utility function to get number of elements between axis * Add utility function to get unsigned axis * Unit tests for ArgMinMax function Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: I7bc3d610dda9526190187eb87394a8ed7a4b5cdd
Diffstat (limited to 'src/armnnUtils/TensorUtils.cpp')
-rw-r--r--src/armnnUtils/TensorUtils.cpp28
1 files changed, 28 insertions, 0 deletions
diff --git a/src/armnnUtils/TensorUtils.cpp b/src/armnnUtils/TensorUtils.cpp
index 8baea78ab5..b4e8d5acda 100644
--- a/src/armnnUtils/TensorUtils.cpp
+++ b/src/armnnUtils/TensorUtils.cpp
@@ -107,4 +107,32 @@ armnn::TensorShape ExpandDims(const armnn::TensorShape& tensorShape, int axis)
return armnn::TensorShape(outputDim, outputShape.data());
}
+unsigned int GetNumElementsBetween(const armnn::TensorShape& shape,
+ const unsigned int firstAxisInclusive,
+ const unsigned int lastAxisExclusive)
+{
+ BOOST_ASSERT(0 <= firstAxisInclusive);
+ BOOST_ASSERT(firstAxisInclusive <= lastAxisExclusive);
+ BOOST_ASSERT(lastAxisExclusive <= shape.GetNumDimensions());
+ unsigned int count = 1;
+ for (unsigned int i = firstAxisInclusive; i < lastAxisExclusive; i++)
+ {
+ count *= shape[i];
+ }
+ return count;
+}
+
+unsigned int GetUnsignedAxis(const unsigned int inputDimension, const int axis)
+{
+ BOOST_ASSERT_MSG(axis < boost::numeric_cast<int>(inputDimension),
+ "Required axis index greater than number of dimensions.");
+ BOOST_ASSERT_MSG(axis >= -boost::numeric_cast<int>(inputDimension),
+ "Required axis index lower than negative of the number of dimensions");
+
+ unsigned int uAxis = axis < 0 ?
+ inputDimension - boost::numeric_cast<unsigned int>(abs(axis))
+ : boost::numeric_cast<unsigned int>(axis);
+ return uAxis;
+}
+
}