From ddb1d06dbcb5dc4a89a237ac1176279669817f46 Mon Sep 17 00:00:00 2001
From: Francis Murtagh
Date: Tue, 10 Mar 2020 13:51:45 +0000
Subject: MLCE-159 Add QAsymmS8 to ArmnnQuantizer

* Allow per layer quantization from Fp32 to Int8 (QAsymmS8) like TfLite

Signed-off-by: Francis Murtagh
Change-Id: I5bbf770aa29d81af3568c15b47d2b2c18e55bb28
---
 src/armnnDeserializer/Deserializer.cpp        |  3 +++
 src/armnnQuantizer/ArmNNQuantizerMain.cpp     | 16 ++++++++++---
 src/armnnQuantizer/CommandLineProcessor.cpp   | 12 ++++++----
 src/armnnSerializer/ArmnnSchema.fbs           |  3 ++-
 src/armnnSerializer/SerializerUtils.cpp       |  2 ++
 src/backends/backendsCommon/WorkloadData.cpp  |  3 ++-
 .../backendsCommon/test/WorkloadTestUtils.hpp |  2 ++
 src/backends/reference/RefLayerSupport.cpp    | 28 +++++++++++++++++++---
 8 files changed, 57 insertions(+), 12 deletions(-)

diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 1f7c360d51..bc6fbf0194 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -505,6 +505,9 @@ armnn::TensorInfo ToTensorInfo(Deserializer::TensorRawPtr tensorPtr)

     switch (tensorPtr->dataType())
     {
+        case DataType_QAsymmS8:
+            type = armnn::DataType::QAsymmS8;
+            break;
         case DataType_QuantisedAsymm8:
         case DataType_QAsymmU8:
             type = armnn::DataType::QAsymmU8;
diff --git a/src/armnnQuantizer/ArmNNQuantizerMain.cpp b/src/armnnQuantizer/ArmNNQuantizerMain.cpp
index 30167e73f2..219363edbb 100644
--- a/src/armnnQuantizer/ArmNNQuantizerMain.cpp
+++ b/src/armnnQuantizer/ArmNNQuantizerMain.cpp
@@ -36,9 +36,19 @@ int main(int argc, char* argv[])
     inputFileStream.close();

     armnn::QuantizerOptions quantizerOptions;
-    quantizerOptions.m_ActivationFormat = cmdline.GetQuantizationScheme() == "QSymm16"
-                                          ? armnn::DataType::QSymmS16
-                                          : armnn::DataType::QAsymmU8;
+
+    if (cmdline.GetQuantizationScheme() == "QAsymmS8")
+    {
+        quantizerOptions.m_ActivationFormat = armnn::DataType::QAsymmS8;
+    }
+    else if (cmdline.GetQuantizationScheme() == "QSymmS16")
+    {
+        quantizerOptions.m_ActivationFormat = armnn::DataType::QSymmS16;
+    }
+    else
+    {
+        quantizerOptions.m_ActivationFormat = armnn::DataType::QAsymmU8;
+    }

     quantizerOptions.m_PreserveType = cmdline.HasPreservedDataType();

diff --git a/src/armnnQuantizer/CommandLineProcessor.cpp b/src/armnnQuantizer/CommandLineProcessor.cpp
index d2163c0869..0cccb66f63 100644
--- a/src/armnnQuantizer/CommandLineProcessor.cpp
+++ b/src/armnnQuantizer/CommandLineProcessor.cpp
@@ -67,8 +67,10 @@ bool ValidateQuantizationScheme(const std::string& scheme)
         return false;
     }

-    std::vector<std::string> supportedSchemes = {
-        "QAsymm8",
+    std::vector<std::string> supportedSchemes =
+    {
+        "QAsymmS8",
+        "QAsymmU8",
         "QSymm16"
     };

@@ -93,8 +95,10 @@ bool CommandLineProcessor::ProcessCommandLine(int argc, char* argv[])
             ("help,h", "Display help messages")
             ("infile,f", po::value<std::string>(&m_InputFileName)->required(),
                          "Input file containing float 32 ArmNN Input Graph")
-            ("scheme,s", po::value<std::string>(&m_QuantizationScheme)->default_value("QAsymm8"),
-                         "Quantization scheme, \"QAsymm8\" or \"QSymm16\", default value QAsymm8")
+            ("scheme,s", po::value<std::string>(&m_QuantizationScheme)->default_value("QAsymmU8"),
+                         "Quantization scheme,"
+                         " \"QAsymmU8\" or \"QAsymmS8\" or \"QSymm16\","
+                         " default value QAsymmU8")
             ("csvfile,c", po::value<std::string>(&m_CsvFileName)->default_value(""),
                          "CSV file containing paths for RAW input tensors")
             ("preserve-data-type,p", po::bool_switch(&m_PreserveDataType)->default_value(false),
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index d7565a5b9a..ca3db5d542 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -37,7 +37,8 @@ enum DataType : byte {
     Boolean = 4,
     QuantisedSymm16 = 5, // deprecated
     QAsymmU8 = 6,
-    QSymmS16 = 7
+    QSymmS16 = 7,
+    QAsymmS8 = 8
 }

 enum DataLayout : byte {
diff --git a/src/armnnSerializer/SerializerUtils.cpp b/src/armnnSerializer/SerializerUtils.cpp
index 02a5ed3872..c1847715a0 100644
--- a/src/armnnSerializer/SerializerUtils.cpp
+++ b/src/armnnSerializer/SerializerUtils.cpp
@@ -58,6 +58,8 @@ armnnSerializer::DataType GetFlatBufferDataType(armnn::DataType dataType)
             return armnnSerializer::DataType::DataType_Signed32;
         case armnn::DataType::QSymmS16:
             return armnnSerializer::DataType::DataType_QSymmS16;
+        case armnn::DataType::QAsymmS8:
+            return armnnSerializer::DataType::DataType_QAsymmS8;
         case armnn::DataType::QAsymmU8:
             return armnnSerializer::DataType::DataType_QAsymmU8;
         case armnn::DataType::Boolean:
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index dbd1158380..bb0c21ffba 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -994,6 +994,7 @@ void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
     {
         DataType::Float32,
         DataType::Float16,
+        DataType::QAsymmS8,
         DataType::QAsymmU8,
         DataType::QSymmS16
     };
@@ -1183,8 +1184,8 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co
     std::vector<DataType> supportedTypes =
     {
         DataType::Float32,
-        DataType::QAsymmU8,
         DataType::QAsymmS8,
+        DataType::QAsymmU8,
         DataType::QSymmS16,
         DataType::QSymmS8,
         DataType::Float16
diff --git a/src/backends/backendsCommon/test/WorkloadTestUtils.hpp b/src/backends/backendsCommon/test/WorkloadTestUtils.hpp
index 0b0f265db4..51683335e1 100644
--- a/src/backends/backendsCommon/test/WorkloadTestUtils.hpp
+++ b/src/backends/backendsCommon/test/WorkloadTestUtils.hpp
@@ -98,6 +98,8 @@ inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Option
         case armnn::DataType::Float16:
         case armnn::DataType::Float32:
             return weightsType;
+        case armnn::DataType::QAsymmS8:
+            return armnn::DataType::Signed32;
         case armnn::DataType::QAsymmU8:
             return armnn::DataType::Signed32;
         case armnn::DataType::QSymmS16:
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index bd2e7289d8..cb94955e7a 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -815,11 +815,12 @@ bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
     bool supported = true;

     // Define supported types.
-    std::array<DataType,4> supportedTypes =
+    std::array<DataType,5> supportedTypes =
     {
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
+        DataType::QAsymmS8,
         DataType::QSymmS16
     };

@@ -835,8 +836,29 @@ bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
     supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                   "Reference Fully Connected: weights type not supported.");

-    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
-                                  "Reference Fully Connected: input and weight types mismatched.");
+    ARMNN_NO_DEPRECATE_WARN_BEGIN
+    std::array<DataType,3> supportedWeightTypes =
+    {
+        DataType::QAsymmU8,
+        DataType::QSymmS8,
+        DataType::QuantizedSymm8PerAxis // deprecated
+    };
+    ARMNN_NO_DEPRECATE_WARN_END
+
+    if (IsQuantized8BitType(input.GetDataType()))
+    {
+
+        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
+                                      "Reference Fully Connected: weights type not supported for quantized input.");
+    }
+    else
+    {
+        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
+                                      "Reference Fully Connected: weights is not a supported type.");
+
+        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
+                                      "Reference Fully Connected: input and weights types mismatched.");
+    }

     if (descriptor.m_BiasEnabled)
     {
-- 
cgit v1.2.1
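For readers who want to drive the new QAsymmS8 option from code rather than through the ArmnnQuantizer command line, here is a minimal sketch of how the pieces in this patch fit together. It is illustrative only: QuantizerOptions::m_ActivationFormat, m_PreserveType and armnn::DataType::QAsymmS8 come straight from the diff above, while the header paths and the IDeserializer::CreateNetworkFromBinary, INetworkQuantizer::Create and ExportNetwork calls are assumptions based on the ArmNN quantizer API at this revision and are not part of the patch.

// Illustrative sketch (not part of the patch): quantize a deserialized FP32
// network to QAsymmS8 activations, mirroring what ArmnnQuantizer now does for
// "--scheme QAsymmS8". API names outside QuantizerOptions are assumptions
// based on the ArmNN quantizer interface at this revision.
#include <armnn/INetwork.hpp>
#include <armnn/INetworkQuantizer.hpp>
#include <armnnDeserializer/IDeserializer.hpp>

#include <cstdint>
#include <vector>

armnn::INetworkPtr QuantizeToQAsymmS8(const std::vector<uint8_t>& fp32GraphBinary)
{
    // Deserialize the FP32 input graph, as ArmNNQuantizerMain does from file.
    auto deserializer = armnnDeserializer::IDeserializer::Create();
    armnn::INetworkPtr network = deserializer->CreateNetworkFromBinary(fp32GraphBinary);

    // Select the per-layer activation format added by this patch.
    armnn::QuantizerOptions quantizerOptions;
    quantizerOptions.m_ActivationFormat = armnn::DataType::QAsymmS8;
    quantizerOptions.m_PreserveType     = false;

    // Quantize and export the Int8 network.
    auto quantizer = armnn::INetworkQuantizer::Create(network.get(), quantizerOptions);
    return quantizer->ExportNetwork();
}

The command-line equivalent, going only by the options visible in CommandLineProcessor.cpp above, would be along the lines of "ArmnnQuantizer -f model.armnn -s QAsymmS8", plus whatever output options the tool already accepts.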