diff options
Diffstat (limited to 'src/armnn/layers/ComparisonLayer.cpp')
-rw-r--r-- | src/armnn/layers/ComparisonLayer.cpp | 27 |
1 files changed, 20 insertions, 7 deletions
diff --git a/src/armnn/layers/ComparisonLayer.cpp b/src/armnn/layers/ComparisonLayer.cpp index b6cd48b268..c097cddf4d 100644 --- a/src/armnn/layers/ComparisonLayer.cpp +++ b/src/armnn/layers/ComparisonLayer.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2019 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2019-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -36,24 +36,37 @@ ComparisonLayer* ComparisonLayer::Clone(Graph& graph) const std::vector<TensorShape> ComparisonLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const { ARMNN_ASSERT(inputShapes.size() == 2); - const TensorShape& input0 = inputShapes[0]; - const TensorShape& input1 = inputShapes[1]; + TensorShape input0 = inputShapes[0]; + TensorShape input1 = inputShapes[1]; - ARMNN_ASSERT(input0.GetNumDimensions() == input1.GetNumDimensions()); - unsigned int numDims = input0.GetNumDimensions(); + if (inputShapes[0].GetNumDimensions() < inputShapes[1].GetNumDimensions()) + { + input1 = inputShapes[0]; + input0 = inputShapes[1]; + } + unsigned int numDims = input0.GetNumDimensions(); + unsigned int shiftedDims = input0.GetNumDimensions() - input1.GetNumDimensions(); + // Get the max of the inputs. std::vector<unsigned int> dims(numDims); - for (unsigned int i = 0; i < numDims; i++) + for (unsigned int i = shiftedDims; i < numDims; i++) { unsigned int dim0 = input0[i]; - unsigned int dim1 = input1[i]; + unsigned int dim1 = input1[i - shiftedDims]; + // Validate inputs are broadcast compatible. ARMNN_ASSERT_MSG(dim0 == dim1 || dim0 == 1 || dim1 == 1, "Dimensions should either match or one should be of size 1."); dims[i] = std::max(dim0, dim1); } + // Fill in the rest of the shifted dimensions. + for (unsigned int i = 0; i < shiftedDims; i++) + { + dims[i] = input0[i]; + } + return std::vector<TensorShape>({ TensorShape(numDims, dims.data()) }); } |