diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index de985ec28d..7055092be2 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -234,6 +234,7 @@ void ValidateBroadcastTensorShapesMatch(const TensorInfo& first, { // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get // broadcasted. + // NOTE: This check is dependent on the AddBroadcastReshapeLayerImpl optimization having been applied to the layer. if (first.GetNumDimensions() != second.GetNumDimensions()) { throw InvalidArgumentException(descName + ": Tensors " @@ -269,7 +270,8 @@ void ValidateDataTypes(const TensorInfo& info, auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType()); if (iterator == supportedTypes.end()) { - throw InvalidArgumentException(descName + ": " + " Tensor type is not supported."); + throw InvalidArgumentException(descName + ": " + " Tensor type " + GetDataTypeName(info.GetDataType()) + + " is not supported."); } } @@ -710,7 +712,8 @@ void CastQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const DataType::QSymmS8, DataType::QSymmS16, DataType::Signed32, - DataType::Signed64 + DataType::Signed64, + DataType::Boolean }; ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName); |