From 303980c502c721f13d65e7087be6c0758df65044 Mon Sep 17 00:00:00 2001 From: Sadik Armagan Date: Fri, 17 Apr 2020 12:45:14 +0100 Subject: IVGCVSW-4668 Add TENSOR_QUANT8_ASYMM_SIGNED data type support to CpuRef operators Signed-off-by: Teresa Charlin Signed-off-by: Sadik Armagan Change-Id: I094125ba80699cc3cf5226bda6662a54e6caa988 --- src/backends/backendsCommon/WorkloadData.cpp | 75 ++++++++++++++++++---------- 1 file changed, 50 insertions(+), 25 deletions(-) (limited to 'src/backends/backendsCommon/WorkloadData.cpp') diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 5fe056e669..d1249a492f 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -365,8 +365,8 @@ void ValidateWeightDataType(const TensorInfo& inputInfo, ARMNN_NO_DEPRECATE_WARN_BEGIN const std::vector validTypes = { - DataType::QAsymmU8, DataType::QAsymmS8, + DataType::QAsymmU8, DataType::QSymmS8, DataType::QuantizedSymm8PerAxis // deprecated }; @@ -633,6 +633,7 @@ void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const DataType::BFloat16, DataType::Float16, DataType::Float32, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16, DataType::Signed32 @@ -715,6 +716,7 @@ void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const DataType::Float16, DataType::Boolean, DataType::Signed32, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -852,6 +854,7 @@ void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const DataType::Float16, DataType::Boolean, DataType::Signed32, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -940,6 +943,7 @@ void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const DataType::Float16, DataType::Boolean, DataType::Signed32, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -1040,6 +1044,7 @@ void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co DataType::BFloat16, DataType::Float16, DataType::Float32, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -1101,11 +1106,11 @@ void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c std::vector supportedTypes = { DataType::BFloat16, + DataType::Float16, DataType::Float32, - DataType::QAsymmU8, DataType::QAsymmS8, - DataType::QSymmS16, - DataType::Float16 + DataType::QAsymmU8, + DataType::QSymmS16 }; ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName); @@ -1138,6 +1143,7 @@ void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInf DataType::BFloat16, DataType::Float16, DataType::Float32, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -1209,12 +1215,12 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co std::vector supportedTypes = { DataType::BFloat16, + DataType::Float16, DataType::Float32, DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16, - DataType::QSymmS8, - DataType::Float16 + DataType::QSymmS8 }; ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName); @@ -1298,11 +1304,11 @@ void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa std::vector supportedTypes = { DataType::BFloat16, + DataType::Float16, DataType::Float32, - DataType::QAsymmU8, DataType::QAsymmS8, - DataType::QSymmS16, - DataType::Float16 + DataType::QAsymmU8, + DataType::QSymmS16 }; ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName); @@ -1383,6 +1389,7 @@ void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c DataType::BFloat16, DataType::Float16, DataType::Float32, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -1535,6 +1542,7 @@ void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) DataType::BFloat16, DataType::Float32, DataType::Float16, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -1587,11 +1595,11 @@ void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const DataType::BFloat16, DataType::Float32, DataType::Float16, - DataType::Signed32, - DataType::QAsymmU8, DataType::QAsymmS8, + DataType::QAsymmU8, DataType::QSymmS8, - DataType::QSymmS16 + DataType::QSymmS16, + DataType::Signed32 }; ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName); @@ -1615,10 +1623,10 @@ void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const DataType::BFloat16, DataType::Float32, DataType::Float16, - DataType::Signed32, - DataType::QSymmS16, DataType::QAsymmS8, - DataType::QAsymmU8 + DataType::QAsymmU8, + DataType::QSymmS16, + DataType::Signed32 }; ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName); @@ -1683,6 +1691,7 @@ void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c DataType::BFloat16, DataType::Float16, DataType::Float32, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -1709,6 +1718,7 @@ void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con DataType::BFloat16, DataType::Float32, DataType::Float16, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -2146,11 +2156,12 @@ void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const std::vector supportedTypes = { + DataType::BFloat16, + DataType::Float16, DataType::Float32, + DataType::QAsymmS8, DataType::QAsymmU8, - DataType::QSymmS16, - DataType::Float16, - DataType::BFloat16 + DataType::QSymmS16 }; ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName); @@ -2178,11 +2189,12 @@ void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons std::vector supportedTypes = { + DataType::BFloat16, + DataType::Float16, DataType::Float32, + DataType::QAsymmS8, DataType::QAsymmU8, - DataType::QSymmS16, - DataType::Float16, - DataType::BFloat16 + DataType::QSymmS16 }; ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName); @@ -2213,10 +2225,10 @@ void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const DataType::BFloat16, DataType::Float16, DataType::Float32, - DataType::Signed32, DataType::QAsymmS8, DataType::QAsymmU8, - DataType::QSymmS16 + DataType::QSymmS16, + DataType::Signed32 }; ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName); @@ -2246,6 +2258,7 @@ void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const DataType::BFloat16, DataType::Float32, DataType::Float16, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -2340,6 +2353,7 @@ void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c DataType::BFloat16, DataType::Float32, DataType::Float16, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -2363,6 +2377,7 @@ void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con DataType::BFloat16, DataType::Float16, DataType::Float32, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -2420,9 +2435,10 @@ void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const DataType::BFloat16, DataType::Float16, DataType::Float32, - DataType::Signed32, + DataType::QAsymmS8, DataType::QAsymmU8, - DataType::QSymmS16 + DataType::QSymmS16, + DataType::Signed32 }; ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName); @@ -2510,6 +2526,7 @@ void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const DataType::BFloat16, DataType::Float16, DataType::Float32, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -2539,6 +2556,7 @@ void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const DataType::BFloat16, DataType::Float16, DataType::Float32, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -2586,6 +2604,7 @@ void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadI DataType::BFloat16, DataType::Float32, DataType::Float16, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -2678,6 +2697,7 @@ void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { DataType::BFloat16, DataType::Float32, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -2722,6 +2742,7 @@ void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const DataType::BFloat16, DataType::Float16, DataType::Float32, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -2785,6 +2806,7 @@ void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa DataType::BFloat16, DataType::Float32, DataType::Float16, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -3010,6 +3032,7 @@ void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const DataType::BFloat16, DataType::Float16, DataType::Float32, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16, DataType::Signed32 @@ -3092,6 +3115,7 @@ void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con DataType::BFloat16, DataType::Float32, DataType::Float16, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16 }; @@ -3167,6 +3191,7 @@ void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) DataType::BFloat16, DataType::Float16, DataType::Float32, + DataType::QAsymmS8, DataType::QAsymmU8, DataType::QSymmS16, DataType::Signed32 -- cgit v1.2.1