diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 18 |
1 files changed, 12 insertions, 6 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 2000ce4a57..6667eabdc5 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1806,10 +1806,13 @@ void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0]; const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0]; - if (inputTensorInfo.GetDataType() != DataType::Float32) + std::vector<DataType> supportedTypes = { - throw InvalidArgumentException(descriptorName + ": Quantize only accepts Float32 inputs."); - } + DataType::Float32, + DataType::Float16 + }; + + ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName); if (outputTensorInfo.GetDataType() != DataType::QuantisedAsymm8 && outputTensorInfo.GetDataType() != DataType::QuantisedSymm16) @@ -2117,10 +2120,13 @@ void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const throw InvalidArgumentException(descriptorName + ": Input to dequantize layer must be quantized type."); } - if (outputTensorInfo.GetDataType() != DataType::Float32) + std::vector<DataType> supportedTypes = { - throw InvalidArgumentException(descriptorName + ": Output of dequantize layer must be Float32 type."); - } + DataType::Float32, + DataType::Float16 + }; + + ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName); } void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const |