diff options
Diffstat (limited to 'src/armnnUtils/TensorUtils.cpp')
-rw-r--r-- | src/armnnUtils/TensorUtils.cpp | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/src/armnnUtils/TensorUtils.cpp b/src/armnnUtils/TensorUtils.cpp index 03109e0cee..cb73d92ef8 100644 --- a/src/armnnUtils/TensorUtils.cpp +++ b/src/armnnUtils/TensorUtils.cpp @@ -165,6 +165,31 @@ TensorShape ExpandDims(const TensorShape& tensorShape, int axis) return { outputDim, outputShape.data() }; } +TensorShape ExpandDimsToRank(const TensorShape& tensorShape, unsigned int rank) +{ + // Can't expand if rank is smaller than current shape + if (tensorShape.GetNumDimensions() >= rank) + { + return tensorShape; + } + + std::vector<unsigned int> newShape; + + // First add 1s to the beginning of the tensorInfo to fill in the space + for (unsigned int i = 0; i < rank - tensorShape.GetNumDimensions(); ++i) + { + newShape.push_back(1); + } + + // Then iterate through the original shape and append it to the new shape with the added 1s + for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); ++i) + { + newShape.push_back(tensorShape[i]); + } + + return TensorShape(static_cast<unsigned int>(newShape.size()), newShape.data()); +} + std::vector<unsigned int> SqueezeDims(const TensorShape& tensorShape) { std::vector<unsigned int> squeezedDims; |