From ddb1d06dbcb5dc4a89a237ac1176279669817f46 Mon Sep 17 00:00:00 2001 From: Francis Murtagh Date: Tue, 10 Mar 2020 13:51:45 +0000 Subject: MLCE-159 Add QAsymmS8 to ArmnnQuantizer * Allow per layer quantization from Fp32 to Int8 (QAsymmS8) like TfLite Signed-off-by: Francis Murtagh Change-Id: I5bbf770aa29d81af3568c15b47d2b2c18e55bb28 --- src/backends/reference/RefLayerSupport.cpp | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) (limited to 'src/backends/reference') diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index bd2e7289d8..cb94955e7a 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -815,11 +815,12 @@ bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input, bool supported = true; // Define supported types. - std::array supportedTypes = + std::array supportedTypes = { DataType::Float32, DataType::Float16, DataType::QAsymmU8, + DataType::QAsymmS8, DataType::QSymmS16 }; @@ -835,8 +836,29 @@ bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input, supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported, "Reference Fully Connected: weights type not supported."); - supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported, - "Reference Fully Connected: input and weight types mismatched."); + ARMNN_NO_DEPRECATE_WARN_BEGIN + std::array supportedWeightTypes = + { + DataType::QAsymmU8, + DataType::QSymmS8, + DataType::QuantizedSymm8PerAxis // deprecated + }; + ARMNN_NO_DEPRECATE_WARN_END + + if (IsQuantized8BitType(input.GetDataType())) + { + + supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported, + "Reference Fully Connected: weights type not supported for quantized input."); + } + else + { + supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported, + "Reference Fully Connected: weights is not a supported type."); + + supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported, + "Reference Fully Connected: input and weights types mismatched."); + } if (descriptor.m_BiasEnabled) { -- cgit v1.2.1