aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils/TensorUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnUtils/TensorUtils.cpp')
-rw-r--r--src/armnnUtils/TensorUtils.cpp25
1 files changed, 25 insertions, 0 deletions
diff --git a/src/armnnUtils/TensorUtils.cpp b/src/armnnUtils/TensorUtils.cpp
index 03109e0cee..cb73d92ef8 100644
--- a/src/armnnUtils/TensorUtils.cpp
+++ b/src/armnnUtils/TensorUtils.cpp
@@ -165,6 +165,31 @@ TensorShape ExpandDims(const TensorShape& tensorShape, int axis)
return { outputDim, outputShape.data() };
}
+TensorShape ExpandDimsToRank(const TensorShape& tensorShape, unsigned int rank)
+{
+ // Can't expand if rank is smaller than current shape
+ if (tensorShape.GetNumDimensions() >= rank)
+ {
+ return tensorShape;
+ }
+
+ std::vector<unsigned int> newShape;
+
+ // First add 1s to the beginning of the tensorInfo to fill in the space
+ for (unsigned int i = 0; i < rank - tensorShape.GetNumDimensions(); ++i)
+ {
+ newShape.push_back(1);
+ }
+
+ // Then iterate through the original shape and append it to the new shape with the added 1s
+ for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); ++i)
+ {
+ newShape.push_back(tensorShape[i]);
+ }
+
+ return TensorShape(static_cast<unsigned int>(newShape.size()), newShape.data());
+}
+
std::vector<unsigned int> SqueezeDims(const TensorShape& tensorShape)
{
std::vector<unsigned int> squeezedDims;