From 02300aa6460441891e54342286358afe42f432c8 Mon Sep 17 00:00:00 2001 From: Colm Donelan Date: Thu, 4 Apr 2024 11:20:29 +0100 Subject: IVGCVSW-8314 Broadcast handling for Comparison layer is inconsistent. * Added Comparison and LogicalBinary to AddBroadcastReshapeLayer optimization. Signed-off-by: Colm Donelan Change-Id: I4f4bafb961daf63a733be9a1f17067fd246607ad --- src/armnn/optimizations/AddBroadcastReshapeLayer.hpp | 6 +++--- src/backends/backendsCommon/WorkloadData.cpp | 7 +++++-- src/backends/reference/RefLayerSupport.cpp | 8 +++++++- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp b/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp index dbde72b917..b07ab54a3f 100644 --- a/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp +++ b/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2020-2021,2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2020-2021,2023-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -7,7 +7,6 @@ #include "Optimization.hpp" #include -#include #include namespace armnn @@ -18,7 +17,8 @@ namespace optimizations static const std::set broadcastOps{ LayerType::Addition, LayerType::Division, LayerType::Maximum, LayerType::Minimum, LayerType::Multiplication, LayerType::Prelu, - LayerType::Subtraction, LayerType::ElementwiseBinary }; + LayerType::Subtraction, LayerType::ElementwiseBinary, + LayerType::Comparison, LayerType::LogicalBinary}; class AddBroadcastReshapeLayerImpl { diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index de985ec28d..7055092be2 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -234,6 +234,7 @@ void ValidateBroadcastTensorShapesMatch(const TensorInfo& first, { // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get // broadcasted. + // NOTE: This check is dependent on the AddBroadcastReshapeLayerImpl optimization having been applied to the layer. if (first.GetNumDimensions() != second.GetNumDimensions()) { throw InvalidArgumentException(descName + ": Tensors " @@ -269,7 +270,8 @@ void ValidateDataTypes(const TensorInfo& info, auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType()); if (iterator == supportedTypes.end()) { - throw InvalidArgumentException(descName + ": " + " Tensor type is not supported."); + throw InvalidArgumentException(descName + ": " + " Tensor type " + GetDataTypeName(info.GetDataType()) + + " is not supported."); } } @@ -710,7 +712,8 @@ void CastQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const DataType::QSymmS8, DataType::QSymmS16, DataType::Signed32, - DataType::Signed64 + DataType::Signed64, + DataType::Boolean }; ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName); diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 654aeb55dc..3e04a19df4 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -14,8 +14,8 @@ #include #include -#include #include +#include namespace armnn { @@ -940,6 +940,9 @@ bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0, supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported, "Reference comparison: output is not of type Boolean"); + supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported, + "Reference comparison: shapes are not suitable for implicit broadcast."); + return supported; } @@ -1751,6 +1754,9 @@ bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0, supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported, "Reference LogicalBinary: input and output types do not match"); + supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported, + "Reference LogicalBinary: shapes are not suitable for implicit broadcast."); + return supported; } -- cgit v1.2.1