diff options
Diffstat (limited to 'src/armnnUtils')
-rw-r--r-- | src/armnnUtils/TensorUtils.cpp | 25 | ||||
-rw-r--r-- | src/armnnUtils/test/TensorUtilsTest.cpp | 85 |
2 files changed, 98 insertions, 12 deletions
diff --git a/src/armnnUtils/TensorUtils.cpp b/src/armnnUtils/TensorUtils.cpp index 03109e0cee..cb73d92ef8 100644 --- a/src/armnnUtils/TensorUtils.cpp +++ b/src/armnnUtils/TensorUtils.cpp @@ -165,6 +165,31 @@ TensorShape ExpandDims(const TensorShape& tensorShape, int axis) return { outputDim, outputShape.data() }; } +TensorShape ExpandDimsToRank(const TensorShape& tensorShape, unsigned int rank) +{ + // Can't expand if rank is smaller than current shape + if (tensorShape.GetNumDimensions() >= rank) + { + return tensorShape; + } + + std::vector<unsigned int> newShape; + + // First add 1s to the beginning of the tensorInfo to fill in the space + for (unsigned int i = 0; i < rank - tensorShape.GetNumDimensions(); ++i) + { + newShape.push_back(1); + } + + // Then iterate through the original shape and append it to the new shape with the added 1s + for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); ++i) + { + newShape.push_back(tensorShape[i]); + } + + return TensorShape(static_cast<unsigned int>(newShape.size()), newShape.data()); +} + std::vector<unsigned int> SqueezeDims(const TensorShape& tensorShape) { std::vector<unsigned int> squeezedDims; diff --git a/src/armnnUtils/test/TensorUtilsTest.cpp b/src/armnnUtils/test/TensorUtilsTest.cpp index a69a0098ce..ed21bbe93c 100644 --- a/src/armnnUtils/test/TensorUtilsTest.cpp +++ b/src/armnnUtils/test/TensorUtilsTest.cpp @@ -126,11 +126,79 @@ TEST_CASE("ExpandDimsInvalidAxisTest") CHECK_THROWS_AS(ExpandDims(inputShape, 4), armnn::InvalidArgumentException); } +TEST_CASE("ExpandDimsInvalidNegativeAxisTest") +{ + armnn::TensorShape inputShape({ 2, 3, 4 }); + + // Invalid expand dimension -5 + CHECK_THROWS_AS(ExpandDims(inputShape, -5), armnn::InvalidArgumentException); +} + +TEST_CASE("ExpandDimsBy1Rank") +{ + armnn::TensorShape inputShape({ 2, 3, 4 }); + + // Expand by 1 dimension + armnn::TensorShape outputShape = ExpandDimsToRank(inputShape, 4); + CHECK(outputShape.GetNumDimensions() == 4); + CHECK(outputShape[0] == 1); + CHECK(outputShape[1] == 2); + CHECK(outputShape[2] == 3); + CHECK(outputShape[3] == 4); +} + +TEST_CASE("ExpandDimsBy2Ranks") +{ + armnn::TensorShape inputShape({ 3, 4 }); + + // Expand 2 dimensions + armnn::TensorShape outputShape = ExpandDimsToRank(inputShape, 4); + CHECK(outputShape.GetNumDimensions() == 4); + CHECK(outputShape[0] == 1); + CHECK(outputShape[1] == 1); + CHECK(outputShape[2] == 3); + CHECK(outputShape[3] == 4); +} + +TEST_CASE("ExpandDimsBy3Ranks") +{ + armnn::TensorShape inputShape({ 4 }); + + // Expand 3 dimensions + armnn::TensorShape outputShape = ExpandDimsToRank(inputShape, 4); + CHECK(outputShape.GetNumDimensions() == 4); + CHECK(outputShape[0] == 1); + CHECK(outputShape[1] == 1); + CHECK(outputShape[2] == 1); + CHECK(outputShape[3] == 4); +} + +TEST_CASE("ExpandDimsInvalidRankAmount") +{ + armnn::TensorShape inputShape({ 2, 3, 4 }); + + // Don't expand because target rank is smaller than current rank + armnn::TensorShape outputShape = ExpandDimsToRank(inputShape, 2); + CHECK(outputShape.GetNumDimensions() == 3); + CHECK(outputShape[0] == 2); + CHECK(outputShape[1] == 3); + CHECK(outputShape[2] == 4); +} + +TEST_CASE("ExpandDimsToRankInvalidTensorShape") +{ + armnn::TensorShape inputShape({ 2, 3, 4 }); + + // Throw exception because rank 6 tensors are unsupported by armnn + CHECK_THROWS_AS(ExpandDimsToRank(inputShape, 6), armnn::InvalidArgumentException); +} + + TEST_CASE("ReduceDimsShapeAll1s") { armnn::TensorShape inputShape({ 1, 1, 1 }); - // Invalid expand dimension 4 + // Reduce dimension 2 armnn::TensorShape outputShape = ReduceDims(inputShape, 2); CHECK(outputShape.GetNumDimensions() == 2); CHECK(outputShape[0] == 1); @@ -141,7 +209,7 @@ TEST_CASE("ReduceDimsShapeNotEnough1s") { armnn::TensorShape inputShape({ 1, 2, 1 }); - // Invalid expand dimension 4 + // Reduce dimension 1 armnn::TensorShape outputShape = ReduceDims(inputShape, 1); CHECK(outputShape.GetNumDimensions() == 2); CHECK(outputShape[0] == 2); @@ -152,7 +220,7 @@ TEST_CASE("ReduceDimsInfoAll1s") { armnn::TensorInfo inputInfo({ 1, 1, 1 }, DataType::Float32); - // Invalid expand dimension 4 + // Reduce dimension 2 armnn::TensorInfo outputInfo = ReduceDims(inputInfo, 2); CHECK(outputInfo.GetShape().GetNumDimensions() == 2); CHECK(outputInfo.GetShape()[0] == 1); @@ -163,7 +231,7 @@ TEST_CASE("ReduceDimsInfoNotEnough1s") { armnn::TensorInfo inputInfo({ 1, 2, 1 }, DataType::Float32); - // Invalid expand dimension 4 + // Reduce dimension 1 armnn::TensorInfo outputInfo = ReduceDims(inputInfo, 1); CHECK(outputInfo.GetNumDimensions() == 2); CHECK(outputInfo.GetShape()[0] == 2); @@ -174,7 +242,7 @@ TEST_CASE("ReduceDimsShapeDimensionGreaterThanSize") { armnn::TensorShape inputShape({ 1, 1, 1 }); - // Invalid expand dimension 4 + // Do not reduce because dimension does not exist armnn::TensorShape outputShape = ReduceDims(inputShape, 4); CHECK(outputShape.GetNumDimensions() == 3); CHECK(outputShape[0] == 1); @@ -182,13 +250,6 @@ TEST_CASE("ReduceDimsShapeDimensionGreaterThanSize") CHECK(outputShape[2] == 1); } -TEST_CASE("ExpandDimsInvalidNegativeAxisTest") -{ - armnn::TensorShape inputShape({ 2, 3, 4 }); - - // Invalid expand dimension -5 - CHECK_THROWS_AS(ExpandDims(inputShape, -5), armnn::InvalidArgumentException); -} TEST_CASE("ToFloatArrayInvalidDataType") { |