diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 07ce14b763..ff97fc7f41 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -623,9 +623,10 @@ void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0]; const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0]; - if (outputTensorInfo.GetDataType() != DataType::Signed32) + if (outputTensorInfo.GetDataType() != DataType::Signed32 && + outputTensorInfo.GetDataType() != DataType::Signed64) { - throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32."); + throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64."); } std::vector<DataType> supportedInputTypes = @@ -636,7 +637,8 @@ void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16, - DataType::Signed32 + DataType::Signed32, + DataType::Signed64 }; ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName); |