aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils/TensorUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnUtils/TensorUtils.cpp')
-rw-r--r--src/armnnUtils/TensorUtils.cpp32
1 files changed, 32 insertions, 0 deletions
diff --git a/src/armnnUtils/TensorUtils.cpp b/src/armnnUtils/TensorUtils.cpp
index c2fbbe0bcc..8baea78ab5 100644
--- a/src/armnnUtils/TensorUtils.cpp
+++ b/src/armnnUtils/TensorUtils.cpp
@@ -6,6 +6,10 @@
#include "TensorUtils.hpp"
#include <backendsCommon/ITensorHandle.hpp>
+#include <boost/assert.hpp>
+#include <boost/format.hpp>
+#include <boost/numeric/conversion/cast.hpp>
+
namespace armnnUtils
{
@@ -75,4 +79,32 @@ std::pair<float, float> FindMinMax(armnn::ITensorHandle* tensorHandle)
return std::make_pair(min, max);
}
+armnn::TensorShape ExpandDims(const armnn::TensorShape& tensorShape, int axis)
+{
+ unsigned int outputDim = tensorShape.GetNumDimensions() + 1;
+
+ if (axis < -boost::numeric_cast<int>(outputDim) || axis > boost::numeric_cast<int>(tensorShape.GetNumDimensions()))
+ {
+ throw armnn::InvalidArgumentException(
+ boost::str(boost::format("Invalid expansion axis %1% for %2%D input tensor. %3%") %
+ axis %
+ tensorShape.GetNumDimensions() %
+ CHECK_LOCATION().AsString()));
+ }
+
+ if (axis < 0)
+ {
+ axis = boost::numeric_cast<int>(outputDim) + axis;
+ }
+
+ std::vector<unsigned int> outputShape;
+ for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); ++i)
+ {
+ outputShape.push_back(tensorShape[i]);
+ }
+ outputShape.insert(outputShape.begin() + axis, 1);
+
+ return armnn::TensorShape(outputDim, outputShape.data());
+}
+
}