From ddb1d06dbcb5dc4a89a237ac1176279669817f46 Mon Sep 17 00:00:00 2001
From: Francis Murtagh
Date: Tue, 10 Mar 2020 13:51:45 +0000
Subject: MLCE-159 Add QAsymmS8 to ArmnnQuantizer

* Allow per layer quantization from Fp32 to Int8 (QAsymmS8) like TfLite

Signed-off-by: Francis Murtagh
Change-Id: I5bbf770aa29d81af3568c15b47d2b2c18e55bb28
---
 src/backends/backendsCommon/WorkloadData.cpp       |  3 ++-
 .../backendsCommon/test/WorkloadTestUtils.hpp      |  2 ++
 src/backends/reference/RefLayerSupport.cpp         | 28 +++++++++++++++++++---
 3 files changed, 29 insertions(+), 4 deletions(-)

diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index dbd1158380..bb0c21ffba 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -994,6 +994,7 @@ void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
     {
         DataType::Float32,
         DataType::Float16,
+        DataType::QAsymmS8,
         DataType::QAsymmU8,
         DataType::QSymmS16
     };
@@ -1183,8 +1184,8 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
     std::vector<DataType> supportedTypes =
     {
         DataType::Float32,
-        DataType::QAsymmU8,
         DataType::QAsymmS8,
+        DataType::QAsymmU8,
         DataType::QSymmS16,
         DataType::QSymmS8,
         DataType::Float16
diff --git a/src/backends/backendsCommon/test/WorkloadTestUtils.hpp b/src/backends/backendsCommon/test/WorkloadTestUtils.hpp
index 0b0f265db4..51683335e1 100644
--- a/src/backends/backendsCommon/test/WorkloadTestUtils.hpp
+++ b/src/backends/backendsCommon/test/WorkloadTestUtils.hpp
@@ -98,6 +98,8 @@ inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
         case armnn::DataType::Float16:
         case armnn::DataType::Float32:
             return weightsType;
+        case armnn::DataType::QAsymmS8:
+            return armnn::DataType::Signed32;
         case armnn::DataType::QAsymmU8:
             return armnn::DataType::Signed32;
         case armnn::DataType::QSymmS16:
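The WorkloadTestUtils.hpp hunk above extends the weights-to-bias-type mapping: quantized weight types map to a Signed32 bias, because the bias is stored at scale input_scale * weights_scale and an 8-bit type at that scale would saturate almost immediately. Below is a minimal standalone sketch of that rule with the new QAsymmS8 case included; the DataType enum is a simplified stand-in for armnn::DataType, not the actual Arm NN header.

    // Standalone sketch of the bias-type rule the patched helper implements.
    // DataType is a local stand-in for armnn::DataType (assumed, for brevity).
    #include <cassert>
    #include <optional>

    enum class DataType { Float16, Float32, QAsymmS8, QAsymmU8, QSymmS16, Signed32 };

    std::optional<DataType> GetBiasTypeFromWeightsType(std::optional<DataType> weightsType)
    {
        if (!weightsType)
        {
            return weightsType; // no weights, nothing to derive
        }
        switch (*weightsType)
        {
            // Float networks keep float biases.
            case DataType::Float16:
            case DataType::Float32:
                return weightsType;
            // Quantized weights take a 32-bit integer bias: the bias scale is
            // input_scale * weights_scale, so narrow storage would overflow.
            // QAsymmS8 is the case this patch adds.
            case DataType::QAsymmS8:
            case DataType::QAsymmU8:
            case DataType::QSymmS16:
                return DataType::Signed32;
            default:
                assert(false && "Unsupported weights type");
                return std::nullopt;
        }
    }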
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index bd2e7289d8..cb94955e7a 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -815,11 +815,12 @@ bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
     bool supported = true;
 
     // Define supported types.
-    std::array<DataType,4> supportedTypes =
+    std::array<DataType,5> supportedTypes =
     {
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
+        DataType::QAsymmS8,
         DataType::QSymmS16
     };
 
@@ -835,8 +836,29 @@ bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
     supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                   "Reference Fully Connected: weights type not supported.");
 
-    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
-                                  "Reference Fully Connected: input and weight types mismatched.");
+    ARMNN_NO_DEPRECATE_WARN_BEGIN
+    std::array<DataType, 3> supportedWeightTypes =
+    {
+        DataType::QAsymmU8,
+        DataType::QSymmS8,
+        DataType::QuantizedSymm8PerAxis // deprecated
+    };
+    ARMNN_NO_DEPRECATE_WARN_END
+
+    if (IsQuantized8BitType(input.GetDataType()))
+    {
+
+        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
+                                      "Reference Fully Connected: weights type not supported for quantized input.");
+    }
+    else
+    {
+        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
+                                      "Reference Fully Connected: weights is not a supported type.");
+
+        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
+                                      "Reference Fully Connected: input and weights types mismatched.");
+    }
 
     if (descriptor.m_BiasEnabled)
     {
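A note on what this enables end to end: with QAsymmS8 accepted by the reference backend's FullyConnected support checks, the quantizer can emit signed asymmetric 8-bit networks per layer, matching TfLite's int8 scheme. A sketch of how a caller might request that target follows; the QuantizerOptions and INetworkQuantizer signatures are assumptions based on the Arm NN API of this era, not part of this patch, so verify them against your checkout.

    // Illustrative only: quantize an Fp32 network to per-layer QAsymmS8.
    // Header paths and signatures are assumed from contemporaneous Arm NN.
    #include <armnn/INetwork.hpp>
    #include <armnn/INetworkQuantizer.hpp>
    #include <armnn/Types.hpp>

    armnn::INetworkPtr QuantizeToQAsymmS8(armnn::INetwork& fp32Network)
    {
        using namespace armnn;

        // Selecting QAsymmS8 as the activation format is the option this
        // patch wires through the backend support checks.
        QuantizerOptions options(DataType::QAsymmS8);

        INetworkQuantizerPtr quantizer = INetworkQuantizer::Create(&fp32Network, options);
        return quantizer->ExportNetwork();
    }

A real flow would typically also feed representative inputs to refine the quantization ranges before exporting; that calibration step is omitted here.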