aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorColm Donelan <colm.donelan@arm.com>2024-04-04 11:20:29 +0100
committerColm Donelan <colm.donelan@arm.com>2024-04-18 09:12:22 +0000
commit02300aa6460441891e54342286358afe42f432c8 (patch)
treeb36396e1b3052198af5c8c60fb31a8e5e2931dc5
parent4f1771ab4d321afba9f5a52411855b5dc33bf247 (diff)
downloadarmnn-02300aa6460441891e54342286358afe42f432c8.tar.gz
IVGCVSW-8314 Broadcast handling for Comparison layer is inconsistent.
* Added Comparison and LogicalBinary to AddBroadcastReshapeLayer optimization. Signed-off-by: Colm Donelan <colm.donelan@arm.com> Change-Id: I4f4bafb961daf63a733be9a1f17067fd246607ad
-rw-r--r--src/armnn/optimizations/AddBroadcastReshapeLayer.hpp6
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp7
-rw-r--r--src/backends/reference/RefLayerSupport.cpp8
3 files changed, 15 insertions, 6 deletions
diff --git a/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp b/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp
index dbde72b917..b07ab54a3f 100644
--- a/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp
+++ b/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2020-2021,2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2020-2021,2023-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -7,7 +7,6 @@
#include "Optimization.hpp"
#include <armnn/backends/TensorHandle.hpp>
-#include <armnn/utility/IgnoreUnused.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>
namespace armnn
@@ -18,7 +17,8 @@ namespace optimizations
static const std::set<armnn::LayerType> broadcastOps{ LayerType::Addition, LayerType::Division,
LayerType::Maximum, LayerType::Minimum,
LayerType::Multiplication, LayerType::Prelu,
- LayerType::Subtraction, LayerType::ElementwiseBinary };
+ LayerType::Subtraction, LayerType::ElementwiseBinary,
+ LayerType::Comparison, LayerType::LogicalBinary};
class AddBroadcastReshapeLayerImpl
{
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index de985ec28d..7055092be2 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -234,6 +234,7 @@ void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
{
// Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
// broadcasted.
+ // NOTE: This check is dependent on the AddBroadcastReshapeLayerImpl optimization having been applied to the layer.
if (first.GetNumDimensions() != second.GetNumDimensions())
{
throw InvalidArgumentException(descName + ": Tensors "
@@ -269,7 +270,8 @@ void ValidateDataTypes(const TensorInfo& info,
auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
if (iterator == supportedTypes.end())
{
- throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
+ throw InvalidArgumentException(descName + ": " + " Tensor type " + GetDataTypeName(info.GetDataType()) +
+ " is not supported.");
}
}
@@ -710,7 +712,8 @@ void CastQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
DataType::QSymmS8,
DataType::QSymmS16,
DataType::Signed32,
- DataType::Signed64
+ DataType::Signed64,
+ DataType::Boolean
};
ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 654aeb55dc..3e04a19df4 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -14,8 +14,8 @@
#include <LayerSupportCommon.hpp>
#include <backendsCommon/LayerSupportRules.hpp>
-#include <vector>
#include <array>
+#include <vector>
namespace armnn
{
@@ -940,6 +940,9 @@ bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
"Reference comparison: output is not of type Boolean");
+ supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
+ "Reference comparison: shapes are not suitable for implicit broadcast.");
+
return supported;
}
@@ -1751,6 +1754,9 @@ bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
"Reference LogicalBinary: input and output types do not match");
+ supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
+ "Reference LogicalBinary: shapes are not suitable for implicit broadcast.");
+
return supported;
}