diff options
author | Ryan OShea <ryan.oshea3@arm.com> | 2023-01-25 18:10:20 +0000 |
---|---|---|
committer | ryan.oshea3 <ryan.oshea3@arm.com> | 2023-02-21 14:36:56 +0000 |
commit | a544f0f5d01ea980ca86e1e13e2530fea4fddcd2 (patch) | |
tree | dead6db771d8d78f1e797d3a556586bd9f5129af /src | |
parent | b2293702c16d107ac1ad80cfac9bd84d804f55d4 (diff) | |
download | armnn-a544f0f5d01ea980ca86e1e13e2530fea4fddcd2.tar.gz |
MLCE-753 Expand Tensorshape for relevent layers before verifying support
Previously we were adding a reshape layer to "broadcast" tensors
for elementwise operations. This broadcast was happening too late
and was really just an expand dims. This was breaking the constant
attributes of tensors and layer support of certain backends.
* Remove addition of reshape layer when expanding dimensions
* Replace broadcast function with expand dims to equal rank function
* Fix some error status checks in various layers
* Add new TensorUtil function that expands dims to a defined rank
* Add unit tests to new TensorUtil function
Signed-off-by: Ryan OShea <ryan.oshea3@arm.com>
Change-Id: I31aca47c98075fef4f86864a15470f5faa55ab8d
Diffstat (limited to 'src')
-rw-r--r-- | src/armnnUtils/TensorUtils.cpp | 25 | ||||
-rw-r--r-- | src/armnnUtils/test/TensorUtilsTest.cpp | 85 |
2 files changed, 98 insertions, 12 deletions
diff --git a/src/armnnUtils/TensorUtils.cpp b/src/armnnUtils/TensorUtils.cpp index 03109e0cee..cb73d92ef8 100644 --- a/src/armnnUtils/TensorUtils.cpp +++ b/src/armnnUtils/TensorUtils.cpp @@ -165,6 +165,31 @@ TensorShape ExpandDims(const TensorShape& tensorShape, int axis) return { outputDim, outputShape.data() }; } +TensorShape ExpandDimsToRank(const TensorShape& tensorShape, unsigned int rank) +{ + // Can't expand if rank is smaller than current shape + if (tensorShape.GetNumDimensions() >= rank) + { + return tensorShape; + } + + std::vector<unsigned int> newShape; + + // First add 1s to the beginning of the tensorInfo to fill in the space + for (unsigned int i = 0; i < rank - tensorShape.GetNumDimensions(); ++i) + { + newShape.push_back(1); + } + + // Then iterate through the original shape and append it to the new shape with the added 1s + for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); ++i) + { + newShape.push_back(tensorShape[i]); + } + + return TensorShape(static_cast<unsigned int>(newShape.size()), newShape.data()); +} + std::vector<unsigned int> SqueezeDims(const TensorShape& tensorShape) { std::vector<unsigned int> squeezedDims; diff --git a/src/armnnUtils/test/TensorUtilsTest.cpp b/src/armnnUtils/test/TensorUtilsTest.cpp index a69a0098ce..ed21bbe93c 100644 --- a/src/armnnUtils/test/TensorUtilsTest.cpp +++ b/src/armnnUtils/test/TensorUtilsTest.cpp @@ -126,11 +126,79 @@ TEST_CASE("ExpandDimsInvalidAxisTest") CHECK_THROWS_AS(ExpandDims(inputShape, 4), armnn::InvalidArgumentException); } +TEST_CASE("ExpandDimsInvalidNegativeAxisTest") +{ + armnn::TensorShape inputShape({ 2, 3, 4 }); + + // Invalid expand dimension -5 + CHECK_THROWS_AS(ExpandDims(inputShape, -5), armnn::InvalidArgumentException); +} + +TEST_CASE("ExpandDimsBy1Rank") +{ + armnn::TensorShape inputShape({ 2, 3, 4 }); + + // Expand by 1 dimension + armnn::TensorShape outputShape = ExpandDimsToRank(inputShape, 4); + CHECK(outputShape.GetNumDimensions() == 4); + CHECK(outputShape[0] == 1); + CHECK(outputShape[1] == 2); + CHECK(outputShape[2] == 3); + CHECK(outputShape[3] == 4); +} + +TEST_CASE("ExpandDimsBy2Ranks") +{ + armnn::TensorShape inputShape({ 3, 4 }); + + // Expand 2 dimensions + armnn::TensorShape outputShape = ExpandDimsToRank(inputShape, 4); + CHECK(outputShape.GetNumDimensions() == 4); + CHECK(outputShape[0] == 1); + CHECK(outputShape[1] == 1); + CHECK(outputShape[2] == 3); + CHECK(outputShape[3] == 4); +} + +TEST_CASE("ExpandDimsBy3Ranks") +{ + armnn::TensorShape inputShape({ 4 }); + + // Expand 3 dimensions + armnn::TensorShape outputShape = ExpandDimsToRank(inputShape, 4); + CHECK(outputShape.GetNumDimensions() == 4); + CHECK(outputShape[0] == 1); + CHECK(outputShape[1] == 1); + CHECK(outputShape[2] == 1); + CHECK(outputShape[3] == 4); +} + +TEST_CASE("ExpandDimsInvalidRankAmount") +{ + armnn::TensorShape inputShape({ 2, 3, 4 }); + + // Don't expand because target rank is smaller than current rank + armnn::TensorShape outputShape = ExpandDimsToRank(inputShape, 2); + CHECK(outputShape.GetNumDimensions() == 3); + CHECK(outputShape[0] == 2); + CHECK(outputShape[1] == 3); + CHECK(outputShape[2] == 4); +} + +TEST_CASE("ExpandDimsToRankInvalidTensorShape") +{ + armnn::TensorShape inputShape({ 2, 3, 4 }); + + // Throw exception because rank 6 tensors are unsupported by armnn + CHECK_THROWS_AS(ExpandDimsToRank(inputShape, 6), armnn::InvalidArgumentException); +} + + TEST_CASE("ReduceDimsShapeAll1s") { armnn::TensorShape inputShape({ 1, 1, 1 }); - // Invalid expand dimension 4 + // Reduce dimension 2 armnn::TensorShape outputShape = ReduceDims(inputShape, 2); CHECK(outputShape.GetNumDimensions() == 2); CHECK(outputShape[0] == 1); @@ -141,7 +209,7 @@ TEST_CASE("ReduceDimsShapeNotEnough1s") { armnn::TensorShape inputShape({ 1, 2, 1 }); - // Invalid expand dimension 4 + // Reduce dimension 1 armnn::TensorShape outputShape = ReduceDims(inputShape, 1); CHECK(outputShape.GetNumDimensions() == 2); CHECK(outputShape[0] == 2); @@ -152,7 +220,7 @@ TEST_CASE("ReduceDimsInfoAll1s") { armnn::TensorInfo inputInfo({ 1, 1, 1 }, DataType::Float32); - // Invalid expand dimension 4 + // Reduce dimension 2 armnn::TensorInfo outputInfo = ReduceDims(inputInfo, 2); CHECK(outputInfo.GetShape().GetNumDimensions() == 2); CHECK(outputInfo.GetShape()[0] == 1); @@ -163,7 +231,7 @@ TEST_CASE("ReduceDimsInfoNotEnough1s") { armnn::TensorInfo inputInfo({ 1, 2, 1 }, DataType::Float32); - // Invalid expand dimension 4 + // Reduce dimension 1 armnn::TensorInfo outputInfo = ReduceDims(inputInfo, 1); CHECK(outputInfo.GetNumDimensions() == 2); CHECK(outputInfo.GetShape()[0] == 2); @@ -174,7 +242,7 @@ TEST_CASE("ReduceDimsShapeDimensionGreaterThanSize") { armnn::TensorShape inputShape({ 1, 1, 1 }); - // Invalid expand dimension 4 + // Do not reduce because dimension does not exist armnn::TensorShape outputShape = ReduceDims(inputShape, 4); CHECK(outputShape.GetNumDimensions() == 3); CHECK(outputShape[0] == 1); @@ -182,13 +250,6 @@ TEST_CASE("ReduceDimsShapeDimensionGreaterThanSize") CHECK(outputShape[2] == 1); } -TEST_CASE("ExpandDimsInvalidNegativeAxisTest") -{ - armnn::TensorShape inputShape({ 2, 3, 4 }); - - // Invalid expand dimension -5 - CHECK_THROWS_AS(ExpandDims(inputShape, -5), armnn::InvalidArgumentException); -} TEST_CASE("ToFloatArrayInvalidDataType") { |