diff options
Diffstat (limited to 'src/armnnUtils/TensorUtils.cpp')
-rw-r--r-- | src/armnnUtils/TensorUtils.cpp | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/src/armnnUtils/TensorUtils.cpp b/src/armnnUtils/TensorUtils.cpp index c2fbbe0bcc..8baea78ab5 100644 --- a/src/armnnUtils/TensorUtils.cpp +++ b/src/armnnUtils/TensorUtils.cpp @@ -6,6 +6,10 @@ #include "TensorUtils.hpp" #include <backendsCommon/ITensorHandle.hpp> +#include <boost/assert.hpp> +#include <boost/format.hpp> +#include <boost/numeric/conversion/cast.hpp> + namespace armnnUtils { @@ -75,4 +79,32 @@ std::pair<float, float> FindMinMax(armnn::ITensorHandle* tensorHandle) return std::make_pair(min, max); } +armnn::TensorShape ExpandDims(const armnn::TensorShape& tensorShape, int axis) +{ + unsigned int outputDim = tensorShape.GetNumDimensions() + 1; + + if (axis < -boost::numeric_cast<int>(outputDim) || axis > boost::numeric_cast<int>(tensorShape.GetNumDimensions())) + { + throw armnn::InvalidArgumentException( + boost::str(boost::format("Invalid expansion axis %1% for %2%D input tensor. %3%") % + axis % + tensorShape.GetNumDimensions() % + CHECK_LOCATION().AsString())); + } + + if (axis < 0) + { + axis = boost::numeric_cast<int>(outputDim) + axis; + } + + std::vector<unsigned int> outputShape; + for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); ++i) + { + outputShape.push_back(tensorShape[i]); + } + outputShape.insert(outputShape.begin() + axis, 1); + + return armnn::TensorShape(outputDim, outputShape.data()); +} + } |