From 0c2eeac6347533a1d3d456aebea492f5123388f3 Mon Sep 17 00:00:00 2001 From: Keith Davis Date: Tue, 11 Feb 2020 16:51:50 +0000 Subject: IVGCVSW-4436 Add ExecuteNetwork test for mobilenet_v2_int8 * Add QAsymmS8 to QueueDescriptor supportedTypes * Add QSymmS8/QAsymmS8 to RefLayerSupport supportedTypes * Some additional comments and refactoring Change-Id: I8567314452e6e8f6f69cb6e458ee147d3fc92fab Signed-off-by: Keith Davis --- src/backends/backendsCommon/WorkloadData.cpp | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) (limited to 'src/backends/backendsCommon/WorkloadData.cpp') diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index ebaf961fe8..fea72256a1 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -30,6 +30,8 @@ DataType GetBiasDataType(DataType inputDataType) return DataType::Float16; case DataType::Float32: return DataType::Float32; + case DataType::QAsymmS8: + return DataType::Signed32; case DataType::QAsymmU8: return DataType::Signed32; case DataType::QSymmS8: @@ -357,12 +359,13 @@ void ValidateWeightDataType(const TensorInfo& inputInfo, const std::string& descName) { const DataType inputType = inputInfo.GetDataType(); - if (inputType == DataType::QAsymmU8) + if (IsQuantized8BitType(inputType)) { ARMNN_NO_DEPRECATE_WARN_BEGIN const std::vector validTypes = { DataType::QAsymmU8, + DataType::QAsymmS8, DataType::QSymmS8, DataType::QuantizedSymm8PerAxis // deprecated }; @@ -420,8 +423,7 @@ void ValidatePerAxisQuantization(const TensorInfo& inputInfo, const DataType inputDataType = inputInfo.GetDataType(); const DataType outputDataType = outputInfo.GetDataType(); - const bool canHavePerAxisQuantization = (inputDataType == DataType::QSymmS8 || - inputDataType == DataType::QAsymmU8) && inputDataType == outputDataType; + const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType; if (!canHavePerAxisQuantization) { @@ -599,6 +601,7 @@ void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { DataType::Float16, DataType::Float32, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -684,6 +687,7 @@ void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { DataType::Float16, DataType::Float32, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -1038,10 +1042,11 @@ void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector supportedTypes = { DataType::Float32, + DataType::Float16, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16, - DataType::QSymmS8, - DataType::Float16 + DataType::QSymmS8 }; ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName); @@ -1181,6 +1186,7 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co { DataType::Float32, DataType::QAsymmU8, + DataType::QAsymmS8, DataType::QSymmS16, DataType::QSymmS8, DataType::Float16 @@ -1255,6 +1261,7 @@ void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa { DataType::Float32, DataType::QAsymmU8, + DataType::QAsymmS8, DataType::QSymmS16, DataType::Float16 }; @@ -1309,6 +1316,7 @@ void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { DataType::Float32, DataType::Float16, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -1560,9 +1568,10 @@ void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const DataType::Float32, DataType::Float16, DataType::Signed32, + DataType::QSymmS16, + DataType::QAsymmS8, DataType::QAsymmU8, - DataType::QSymmS8, - DataType::QSymmS16 + DataType::QSymmS8 }; ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName); @@ -2208,10 +2217,7 @@ void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName); - if (outputTensorInfo.GetDataType() != DataType::QAsymmS8 && - outputTensorInfo.GetDataType() != DataType::QAsymmU8 && - outputTensorInfo.GetDataType() != DataType::QSymmS8 && - outputTensorInfo.GetDataType() != DataType::QSymmS16) + if (!IsQuantizedType(outputTensorInfo.GetDataType())) { throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type."); } -- cgit v1.2.1