diff options
author | Mike Kelly <mike.kelly@arm.com> | 2023-01-19 18:29:40 +0000 |
---|---|---|
committer | mike.kelly <mike.kelly@arm.com> | 2023-01-23 22:05:22 +0000 |
commit | 2542a267dc3dfe2b9148b3944977a6864ef3c558 (patch) | |
tree | 1548da9d99abec215e80c5e6b2e25181a092b8b3 /src/armnn | |
parent | c01da459e63461dc0ecfe855fd3254c938689386 (diff) | |
download | armnn-2542a267dc3dfe2b9148b3944977a6864ef3c558.tar.gz |
IVGCVSW-7453 Comparison does not Calculate its shape properly
* Fixed issue where ComparisonLayer wasn't calculating its output shape correctly.
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I37fe437b598bde694e519d6792182924bd0197cd
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/layers/ComparisonLayer.cpp | 27 | ||||
-rw-r--r-- | src/armnn/test/ShapeInferenceTests.cpp | 22 |
2 files changed, 41 insertions, 8 deletions
diff --git a/src/armnn/layers/ComparisonLayer.cpp b/src/armnn/layers/ComparisonLayer.cpp index b6cd48b268..c097cddf4d 100644 --- a/src/armnn/layers/ComparisonLayer.cpp +++ b/src/armnn/layers/ComparisonLayer.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2019 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2019-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -36,24 +36,37 @@ ComparisonLayer* ComparisonLayer::Clone(Graph& graph) const std::vector<TensorShape> ComparisonLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const { ARMNN_ASSERT(inputShapes.size() == 2); - const TensorShape& input0 = inputShapes[0]; - const TensorShape& input1 = inputShapes[1]; + TensorShape input0 = inputShapes[0]; + TensorShape input1 = inputShapes[1]; - ARMNN_ASSERT(input0.GetNumDimensions() == input1.GetNumDimensions()); - unsigned int numDims = input0.GetNumDimensions(); + if (inputShapes[0].GetNumDimensions() < inputShapes[1].GetNumDimensions()) + { + input1 = inputShapes[0]; + input0 = inputShapes[1]; + } + unsigned int numDims = input0.GetNumDimensions(); + unsigned int shiftedDims = input0.GetNumDimensions() - input1.GetNumDimensions(); + // Get the max of the inputs. std::vector<unsigned int> dims(numDims); - for (unsigned int i = 0; i < numDims; i++) + for (unsigned int i = shiftedDims; i < numDims; i++) { unsigned int dim0 = input0[i]; - unsigned int dim1 = input1[i]; + unsigned int dim1 = input1[i - shiftedDims]; + // Validate inputs are broadcast compatible. ARMNN_ASSERT_MSG(dim0 == dim1 || dim0 == 1 || dim1 == 1, "Dimensions should either match or one should be of size 1."); dims[i] = std::max(dim0, dim1); } + // Fill in the rest of the shifted dimensions. + for (unsigned int i = 0; i < shiftedDims; i++) + { + dims[i] = input0[i]; + } + return std::vector<TensorShape>({ TensorShape(numDims, dims.data()) }); } diff --git a/src/armnn/test/ShapeInferenceTests.cpp b/src/armnn/test/ShapeInferenceTests.cpp index 333d12a3a2..7b5d73a4e5 100644 --- a/src/armnn/test/ShapeInferenceTests.cpp +++ b/src/armnn/test/ShapeInferenceTests.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -223,6 +223,26 @@ TEST_CASE("ComparisionTest") "comparision"); } +TEST_CASE("ComparisionTestSmallerRHS") +{ + ComparisonDescriptor descriptor; + descriptor.m_Operation = ComparisonOperation::Equal; + CreateGraphAndRunTest<ComparisonLayer>({{ 5, 7, 6, 2 }, { 1 }}, + {{ 5, 7, 6, 2 }}, + descriptor, + "comparision"); +} + +TEST_CASE("ComparisionTestSmallerLHS") +{ + ComparisonDescriptor descriptor; + descriptor.m_Operation = ComparisonOperation::Equal; + CreateGraphAndRunTest<ComparisonLayer>({{ 1 }, { 5, 7, 6, 2 }}, + {{ 5, 7, 6, 2 }}, + descriptor, + "comparision"); +} + TEST_CASE("ConcatTest") { ConcatDescriptor descriptor(2, 3); |