diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 26 |
1 files changed, 22 insertions, 4 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 5ca492888f..cd40097e74 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1597,23 +1597,41 @@ void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateNumInputs(workloadInfo, "GatherQueueDescriptor", 2); - ValidateNumOutputs(workloadInfo, "GatherQueueDescriptor", 1); + const std::string GatherQueueDescriptorStr = "GatherQueueDescriptor"; + + ValidateNumInputs(workloadInfo, GatherQueueDescriptorStr, 2); + ValidateNumOutputs(workloadInfo, GatherQueueDescriptorStr, 1); const TensorInfo& indices = workloadInfo.m_InputTensorInfos[1]; if (indices.GetDataType() != DataType::Signed32) { - throw InvalidArgumentException("GatherQueueDescriptor: Indices tensor type must be int32."); + throw InvalidArgumentException(GatherQueueDescriptorStr + ": Indices tensor type must be int32."); } + std::vector<DataType> supportedTypes = + { + DataType::Float16, + DataType::Float32, + DataType::QuantisedAsymm8, + DataType::QuantisedSymm16 + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], + supportedTypes, + GatherQueueDescriptorStr); + + ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0], + workloadInfo.m_OutputTensorInfos[0], + GatherQueueDescriptorStr, "Input", "Output"); + const TensorInfo& params = workloadInfo.m_InputTensorInfos[0]; const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0]; unsigned int paramsDim = params.GetNumDimensions(); unsigned int indicesDim = indices.GetNumDimensions(); unsigned int outputDim = paramsDim - 1 + indicesDim; - ValidateTensorNumDimensions(output, "GatherQueueDescriptor", outputDim, "output"); + ValidateTensorNumDimensions(output, GatherQueueDescriptorStr, outputDim, "output"); } void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const |