aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/armnnUtils/TensorUtils.cpp25
-rw-r--r--src/armnnUtils/test/TensorUtilsTest.cpp85
2 files changed, 98 insertions, 12 deletions
diff --git a/src/armnnUtils/TensorUtils.cpp b/src/armnnUtils/TensorUtils.cpp
index 03109e0cee..cb73d92ef8 100644
--- a/src/armnnUtils/TensorUtils.cpp
+++ b/src/armnnUtils/TensorUtils.cpp
@@ -165,6 +165,31 @@ TensorShape ExpandDims(const TensorShape& tensorShape, int axis)
return { outputDim, outputShape.data() };
}
+TensorShape ExpandDimsToRank(const TensorShape& tensorShape, unsigned int rank)
+{
+ // Can't expand if rank is smaller than current shape
+ if (tensorShape.GetNumDimensions() >= rank)
+ {
+ return tensorShape;
+ }
+
+ std::vector<unsigned int> newShape;
+
+ // First add 1s to the beginning of the tensorInfo to fill in the space
+ for (unsigned int i = 0; i < rank - tensorShape.GetNumDimensions(); ++i)
+ {
+ newShape.push_back(1);
+ }
+
+ // Then iterate through the original shape and append it to the new shape with the added 1s
+ for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); ++i)
+ {
+ newShape.push_back(tensorShape[i]);
+ }
+
+ return TensorShape(static_cast<unsigned int>(newShape.size()), newShape.data());
+}
+
std::vector<unsigned int> SqueezeDims(const TensorShape& tensorShape)
{
std::vector<unsigned int> squeezedDims;
diff --git a/src/armnnUtils/test/TensorUtilsTest.cpp b/src/armnnUtils/test/TensorUtilsTest.cpp
index a69a0098ce..ed21bbe93c 100644
--- a/src/armnnUtils/test/TensorUtilsTest.cpp
+++ b/src/armnnUtils/test/TensorUtilsTest.cpp
@@ -126,11 +126,79 @@ TEST_CASE("ExpandDimsInvalidAxisTest")
CHECK_THROWS_AS(ExpandDims(inputShape, 4), armnn::InvalidArgumentException);
}
+TEST_CASE("ExpandDimsInvalidNegativeAxisTest")
+{
+ armnn::TensorShape inputShape({ 2, 3, 4 });
+
+ // Invalid expand dimension -5
+ CHECK_THROWS_AS(ExpandDims(inputShape, -5), armnn::InvalidArgumentException);
+}
+
+TEST_CASE("ExpandDimsBy1Rank")
+{
+ armnn::TensorShape inputShape({ 2, 3, 4 });
+
+ // Expand by 1 dimension
+ armnn::TensorShape outputShape = ExpandDimsToRank(inputShape, 4);
+ CHECK(outputShape.GetNumDimensions() == 4);
+ CHECK(outputShape[0] == 1);
+ CHECK(outputShape[1] == 2);
+ CHECK(outputShape[2] == 3);
+ CHECK(outputShape[3] == 4);
+}
+
+TEST_CASE("ExpandDimsBy2Ranks")
+{
+ armnn::TensorShape inputShape({ 3, 4 });
+
+ // Expand 2 dimensions
+ armnn::TensorShape outputShape = ExpandDimsToRank(inputShape, 4);
+ CHECK(outputShape.GetNumDimensions() == 4);
+ CHECK(outputShape[0] == 1);
+ CHECK(outputShape[1] == 1);
+ CHECK(outputShape[2] == 3);
+ CHECK(outputShape[3] == 4);
+}
+
+TEST_CASE("ExpandDimsBy3Ranks")
+{
+ armnn::TensorShape inputShape({ 4 });
+
+ // Expand 3 dimensions
+ armnn::TensorShape outputShape = ExpandDimsToRank(inputShape, 4);
+ CHECK(outputShape.GetNumDimensions() == 4);
+ CHECK(outputShape[0] == 1);
+ CHECK(outputShape[1] == 1);
+ CHECK(outputShape[2] == 1);
+ CHECK(outputShape[3] == 4);
+}
+
+TEST_CASE("ExpandDimsInvalidRankAmount")
+{
+ armnn::TensorShape inputShape({ 2, 3, 4 });
+
+ // Don't expand because target rank is smaller than current rank
+ armnn::TensorShape outputShape = ExpandDimsToRank(inputShape, 2);
+ CHECK(outputShape.GetNumDimensions() == 3);
+ CHECK(outputShape[0] == 2);
+ CHECK(outputShape[1] == 3);
+ CHECK(outputShape[2] == 4);
+}
+
+TEST_CASE("ExpandDimsToRankInvalidTensorShape")
+{
+ armnn::TensorShape inputShape({ 2, 3, 4 });
+
+ // Throw exception because rank 6 tensors are unsupported by armnn
+ CHECK_THROWS_AS(ExpandDimsToRank(inputShape, 6), armnn::InvalidArgumentException);
+}
+
+
TEST_CASE("ReduceDimsShapeAll1s")
{
armnn::TensorShape inputShape({ 1, 1, 1 });
- // Invalid expand dimension 4
+ // Reduce dimension 2
armnn::TensorShape outputShape = ReduceDims(inputShape, 2);
CHECK(outputShape.GetNumDimensions() == 2);
CHECK(outputShape[0] == 1);
@@ -141,7 +209,7 @@ TEST_CASE("ReduceDimsShapeNotEnough1s")
{
armnn::TensorShape inputShape({ 1, 2, 1 });
- // Invalid expand dimension 4
+ // Reduce dimension 1
armnn::TensorShape outputShape = ReduceDims(inputShape, 1);
CHECK(outputShape.GetNumDimensions() == 2);
CHECK(outputShape[0] == 2);
@@ -152,7 +220,7 @@ TEST_CASE("ReduceDimsInfoAll1s")
{
armnn::TensorInfo inputInfo({ 1, 1, 1 }, DataType::Float32);
- // Invalid expand dimension 4
+ // Reduce dimension 2
armnn::TensorInfo outputInfo = ReduceDims(inputInfo, 2);
CHECK(outputInfo.GetShape().GetNumDimensions() == 2);
CHECK(outputInfo.GetShape()[0] == 1);
@@ -163,7 +231,7 @@ TEST_CASE("ReduceDimsInfoNotEnough1s")
{
armnn::TensorInfo inputInfo({ 1, 2, 1 }, DataType::Float32);
- // Invalid expand dimension 4
+ // Reduce dimension 1
armnn::TensorInfo outputInfo = ReduceDims(inputInfo, 1);
CHECK(outputInfo.GetNumDimensions() == 2);
CHECK(outputInfo.GetShape()[0] == 2);
@@ -174,7 +242,7 @@ TEST_CASE("ReduceDimsShapeDimensionGreaterThanSize")
{
armnn::TensorShape inputShape({ 1, 1, 1 });
- // Invalid expand dimension 4
+ // Do not reduce because dimension does not exist
armnn::TensorShape outputShape = ReduceDims(inputShape, 4);
CHECK(outputShape.GetNumDimensions() == 3);
CHECK(outputShape[0] == 1);
@@ -182,13 +250,6 @@ TEST_CASE("ReduceDimsShapeDimensionGreaterThanSize")
CHECK(outputShape[2] == 1);
}
-TEST_CASE("ExpandDimsInvalidNegativeAxisTest")
-{
- armnn::TensorShape inputShape({ 2, 3, 4 });
-
- // Invalid expand dimension -5
- CHECK_THROWS_AS(ExpandDims(inputShape, -5), armnn::InvalidArgumentException);
-}
TEST_CASE("ToFloatArrayInvalidDataType")
{