From 2542a267dc3dfe2b9148b3944977a6864ef3c558 Mon Sep 17 00:00:00 2001 From: Mike Kelly Date: Thu, 19 Jan 2023 18:29:40 +0000 Subject: IVGCVSW-7453 Comparison does not Calculate its shape properly * Fixed issue where ComparisonLayer wasn't calculating its output shape correctly. Signed-off-by: Mike Kelly Change-Id: I37fe437b598bde694e519d6792182924bd0197cd --- src/armnn/layers/ComparisonLayer.cpp | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) (limited to 'src/armnn/layers/ComparisonLayer.cpp') diff --git a/src/armnn/layers/ComparisonLayer.cpp b/src/armnn/layers/ComparisonLayer.cpp index b6cd48b268..c097cddf4d 100644 --- a/src/armnn/layers/ComparisonLayer.cpp +++ b/src/armnn/layers/ComparisonLayer.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2019 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2019-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -36,24 +36,37 @@ ComparisonLayer* ComparisonLayer::Clone(Graph& graph) const std::vector ComparisonLayer::InferOutputShapes(const std::vector& inputShapes) const { ARMNN_ASSERT(inputShapes.size() == 2); - const TensorShape& input0 = inputShapes[0]; - const TensorShape& input1 = inputShapes[1]; + TensorShape input0 = inputShapes[0]; + TensorShape input1 = inputShapes[1]; - ARMNN_ASSERT(input0.GetNumDimensions() == input1.GetNumDimensions()); - unsigned int numDims = input0.GetNumDimensions(); + if (inputShapes[0].GetNumDimensions() < inputShapes[1].GetNumDimensions()) + { + input1 = inputShapes[0]; + input0 = inputShapes[1]; + } + unsigned int numDims = input0.GetNumDimensions(); + unsigned int shiftedDims = input0.GetNumDimensions() - input1.GetNumDimensions(); + // Get the max of the inputs. std::vector dims(numDims); - for (unsigned int i = 0; i < numDims; i++) + for (unsigned int i = shiftedDims; i < numDims; i++) { unsigned int dim0 = input0[i]; - unsigned int dim1 = input1[i]; + unsigned int dim1 = input1[i - shiftedDims]; + // Validate inputs are broadcast compatible. ARMNN_ASSERT_MSG(dim0 == dim1 || dim0 == 1 || dim1 == 1, "Dimensions should either match or one should be of size 1."); dims[i] = std::max(dim0, dim1); } + // Fill in the rest of the shifted dimensions. + for (unsigned int i = 0; i < shiftedDims; i++) + { + dims[i] = input0[i]; + } + return std::vector({ TensorShape(numDims, dims.data()) }); } -- cgit v1.2.1