aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
authorColm Donelan <colm.donelan@arm.com>2024-04-04 11:20:29 +0100
committerColm Donelan <colm.donelan@arm.com>2024-04-18 09:12:22 +0000
commit02300aa6460441891e54342286358afe42f432c8 (patch)
treeb36396e1b3052198af5c8c60fb31a8e5e2931dc5 /src/backends/backendsCommon/WorkloadData.cpp
parent4f1771ab4d321afba9f5a52411855b5dc33bf247 (diff)
downloadarmnn-02300aa6460441891e54342286358afe42f432c8.tar.gz
IVGCVSW-8314 Broadcast handling for Comparison layer is inconsistent.
* Added Comparison and LogicalBinary to AddBroadcastReshapeLayer optimization. Signed-off-by: Colm Donelan <colm.donelan@arm.com> Change-Id: I4f4bafb961daf63a733be9a1f17067fd246607ad
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp7
1 files changed, 5 insertions, 2 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index de985ec28d..7055092be2 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -234,6 +234,7 @@ void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
{
// Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
// broadcasted.
+ // NOTE: This check is dependent on the AddBroadcastReshapeLayerImpl optimization having been applied to the layer.
if (first.GetNumDimensions() != second.GetNumDimensions())
{
throw InvalidArgumentException(descName + ": Tensors "
@@ -269,7 +270,8 @@ void ValidateDataTypes(const TensorInfo& info,
auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
if (iterator == supportedTypes.end())
{
- throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
+ throw InvalidArgumentException(descName + ": " + " Tensor type " + GetDataTypeName(info.GetDataType()) +
+ " is not supported.");
}
}
@@ -710,7 +712,8 @@ void CastQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
DataType::QSymmS8,
DataType::QSymmS16,
DataType::Signed32,
- DataType::Signed64
+ DataType::Signed64,
+ DataType::Boolean
};
ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);