From 2542a267dc3dfe2b9148b3944977a6864ef3c558 Mon Sep 17 00:00:00 2001 From: Mike Kelly Date: Thu, 19 Jan 2023 18:29:40 +0000 Subject: IVGCVSW-7453 Comparison does not Calculate its shape properly * Fixed issue where ComparisonLayer wasn't calculating its output shape correctly. Signed-off-by: Mike Kelly Change-Id: I37fe437b598bde694e519d6792182924bd0197cd --- src/armnn/layers/ComparisonLayer.cpp | 27 ++++++++++++++++++++------- src/armnn/test/ShapeInferenceTests.cpp | 22 +++++++++++++++++++++- 2 files changed, 41 insertions(+), 8 deletions(-) diff --git a/src/armnn/layers/ComparisonLayer.cpp b/src/armnn/layers/ComparisonLayer.cpp index b6cd48b268..c097cddf4d 100644 --- a/src/armnn/layers/ComparisonLayer.cpp +++ b/src/armnn/layers/ComparisonLayer.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2019 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2019-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -36,24 +36,37 @@ ComparisonLayer* ComparisonLayer::Clone(Graph& graph) const std::vector ComparisonLayer::InferOutputShapes(const std::vector& inputShapes) const { ARMNN_ASSERT(inputShapes.size() == 2); - const TensorShape& input0 = inputShapes[0]; - const TensorShape& input1 = inputShapes[1]; + TensorShape input0 = inputShapes[0]; + TensorShape input1 = inputShapes[1]; - ARMNN_ASSERT(input0.GetNumDimensions() == input1.GetNumDimensions()); - unsigned int numDims = input0.GetNumDimensions(); + if (inputShapes[0].GetNumDimensions() < inputShapes[1].GetNumDimensions()) + { + input1 = inputShapes[0]; + input0 = inputShapes[1]; + } + unsigned int numDims = input0.GetNumDimensions(); + unsigned int shiftedDims = input0.GetNumDimensions() - input1.GetNumDimensions(); + // Get the max of the inputs. std::vector dims(numDims); - for (unsigned int i = 0; i < numDims; i++) + for (unsigned int i = shiftedDims; i < numDims; i++) { unsigned int dim0 = input0[i]; - unsigned int dim1 = input1[i]; + unsigned int dim1 = input1[i - shiftedDims]; + // Validate inputs are broadcast compatible. ARMNN_ASSERT_MSG(dim0 == dim1 || dim0 == 1 || dim1 == 1, "Dimensions should either match or one should be of size 1."); dims[i] = std::max(dim0, dim1); } + // Fill in the rest of the shifted dimensions. + for (unsigned int i = 0; i < shiftedDims; i++) + { + dims[i] = input0[i]; + } + return std::vector({ TensorShape(numDims, dims.data()) }); } diff --git a/src/armnn/test/ShapeInferenceTests.cpp b/src/armnn/test/ShapeInferenceTests.cpp index 333d12a3a2..7b5d73a4e5 100644 --- a/src/armnn/test/ShapeInferenceTests.cpp +++ b/src/armnn/test/ShapeInferenceTests.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -223,6 +223,26 @@ TEST_CASE("ComparisionTest") "comparision"); } +TEST_CASE("ComparisionTestSmallerRHS") +{ + ComparisonDescriptor descriptor; + descriptor.m_Operation = ComparisonOperation::Equal; + CreateGraphAndRunTest({{ 5, 7, 6, 2 }, { 1 }}, + {{ 5, 7, 6, 2 }}, + descriptor, + "comparision"); +} + +TEST_CASE("ComparisionTestSmallerLHS") +{ + ComparisonDescriptor descriptor; + descriptor.m_Operation = ComparisonOperation::Equal; + CreateGraphAndRunTest({{ 1 }, { 5, 7, 6, 2 }}, + {{ 5, 7, 6, 2 }}, + descriptor, + "comparision"); +} + TEST_CASE("ConcatTest") { ConcatDescriptor descriptor(2, 3); -- cgit v1.2.1