aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp8
1 files changed, 5 insertions, 3 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 07ce14b763..ff97fc7f41 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -623,9 +623,10 @@ void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
- if (outputTensorInfo.GetDataType() != DataType::Signed32)
+ if (outputTensorInfo.GetDataType() != DataType::Signed32 &&
+ outputTensorInfo.GetDataType() != DataType::Signed64)
{
- throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32.");
+ throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64.");
}
std::vector<DataType> supportedInputTypes =
@@ -636,7 +637,8 @@ void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
DataType::QAsymmS8,
DataType::QAsymmU8,
DataType::QSymmS16,
- DataType::Signed32
+ DataType::Signed32,
+ DataType::Signed64
};
ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);