diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 072b9a9934..05f4e317a9 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1055,6 +1055,21 @@ void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { ValidateTwoInputs(workloadInfo, "GatherQueueDescriptor"); ValidateSingleOutput(workloadInfo, "GatherQueueDescriptor"); + + const TensorInfo& indices = workloadInfo.m_InputTensorInfos[1]; + + if (indices.GetDataType() != DataType::Signed32) + { + throw InvalidArgumentException("GatherQueueDescriptor: Indices tensor type must be int32."); + } + + const TensorInfo& params = workloadInfo.m_InputTensorInfos[0]; + const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0]; + unsigned int paramsDim = params.GetNumDimensions(); + unsigned int indicesDim = indices.GetNumDimensions(); + unsigned int outputDim = paramsDim - 1 + indicesDim; + + ValidateTensorNumDimensions(output, "GatherQueueDescriptor", outputDim, "output"); } void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const |