about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
author    Francis Murtagh <francis.murtagh@arm.com>  2020-03-10 13:51:45 +0000
committer Jim Flynn <jim.flynn@arm.com>              2020-03-10 14:02:00 +0000
commit    ddb1d06dbcb5dc4a89a237ac1176279669817f46 (patch)
tree      e2d1104bea631f8f366f529201041e701d09b998 /src
parent    6445cfff7519effd1df04eac88ae17d6e4e6693b (diff)
download  armnn-ddb1d06dbcb5dc4a89a237ac1176279669817f46.tar.gz
MLCE-159 Add QAsymmS8 to ArmnnQuantizer
* Allow per layer quantization from Fp32 to Int8 (QAsymmS8) like TfLite

Signed-off-by: Francis Murtagh <francis.murtagh@arm.com>
Change-Id: I5bbf770aa29d81af3568c15b47d2b2c18e55bb28
Diffstat (limited to 'src')
-rw-r--r--  src/armnnDeserializer/Deserializer.cpp             |  3
-rw-r--r--  src/armnnQuantizer/ArmNNQuantizerMain.cpp          | 16
-rw-r--r--  src/armnnQuantizer/CommandLineProcessor.cpp        | 12
-rw-r--r--  src/armnnSerializer/ArmnnSchema.fbs                |  3
-rw-r--r--  src/armnnSerializer/SerializerUtils.cpp            |  2
-rw-r--r--  src/backends/backendsCommon/WorkloadData.cpp       |  3
-rw-r--r--  src/backends/backendsCommon/test/WorkloadTestUtils.hpp | 2
-rw-r--r--  src/backends/reference/RefLayerSupport.cpp         | 28
8 files changed, 57 insertions, 12 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 1f7c360d51..bc6fbf0194 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -505,6 +505,9 @@ armnn::TensorInfo ToTensorInfo(Deserializer::TensorRawPtr tensorPtr)
switch (tensorPtr->dataType())
{
+ case DataType_QAsymmS8:
+ type = armnn::DataType::QAsymmS8;
+ break;
case DataType_QuantisedAsymm8:
case DataType_QAsymmU8:
type = armnn::DataType::QAsymmU8;
diff --git a/src/armnnQuantizer/ArmNNQuantizerMain.cpp b/src/armnnQuantizer/ArmNNQuantizerMain.cpp
index 30167e73f2..219363edbb 100644
--- a/src/armnnQuantizer/ArmNNQuantizerMain.cpp
+++ b/src/armnnQuantizer/ArmNNQuantizerMain.cpp
@@ -36,9 +36,19 @@ int main(int argc, char* argv[])
inputFileStream.close();
armnn::QuantizerOptions quantizerOptions;
- quantizerOptions.m_ActivationFormat = cmdline.GetQuantizationScheme() == "QSymm16"
- ? armnn::DataType::QSymmS16
- : armnn::DataType::QAsymmU8;
+
+ if (cmdline.GetQuantizationScheme() == "QAsymmS8")
+ {
+ quantizerOptions.m_ActivationFormat = armnn::DataType::QAsymmS8;
+ }
+ else if (cmdline.GetQuantizationScheme() == "QSymmS16")
+ {
+ quantizerOptions.m_ActivationFormat = armnn::DataType::QSymmS16;
+ }
+ else
+ {
+ quantizerOptions.m_ActivationFormat = armnn::DataType::QAsymmU8;
+ }
quantizerOptions.m_PreserveType = cmdline.HasPreservedDataType();
diff --git a/src/armnnQuantizer/CommandLineProcessor.cpp b/src/armnnQuantizer/CommandLineProcessor.cpp
index d2163c0869..0cccb66f63 100644
--- a/src/armnnQuantizer/CommandLineProcessor.cpp
+++ b/src/armnnQuantizer/CommandLineProcessor.cpp
@@ -67,8 +67,10 @@ bool ValidateQuantizationScheme(const std::string& scheme)
return false;
}
- std::vector<std::string> supportedSchemes = {
- "QAsymm8",
+ std::vector<std::string> supportedSchemes =
+ {
+ "QAsymmS8",
+ "QAsymmU8",
"QSymm16"
};
@@ -93,8 +95,10 @@ bool CommandLineProcessor::ProcessCommandLine(int argc, char* argv[])
("help,h", "Display help messages")
("infile,f", po::value<std::string>(&m_InputFileName)->required(),
"Input file containing float 32 ArmNN Input Graph")
- ("scheme,s", po::value<std::string>(&m_QuantizationScheme)->default_value("QAsymm8"),
- "Quantization scheme, \"QAsymm8\" or \"QSymm16\", default value QAsymm8")
+ ("scheme,s", po::value<std::string>(&m_QuantizationScheme)->default_value("QAsymmU8"),
+ "Quantization scheme,"
+ " \"QAsymmU8\" or \"QAsymmS8\" or \"QSymm16\","
+ " default value QAsymmU8")
("csvfile,c", po::value<std::string>(&m_CsvFileName)->default_value(""),
"CSV file containing paths for RAW input tensors")
("preserve-data-type,p", po::bool_switch(&m_PreserveDataType)->default_value(false),
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index d7565a5b9a..ca3db5d542 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -37,7 +37,8 @@ enum DataType : byte {
Boolean = 4,
QuantisedSymm16 = 5, // deprecated
QAsymmU8 = 6,
- QSymmS16 = 7
+ QSymmS16 = 7,
+ QAsymmS8 = 8
}
enum DataLayout : byte {
diff --git a/src/armnnSerializer/SerializerUtils.cpp b/src/armnnSerializer/SerializerUtils.cpp
index 02a5ed3872..c1847715a0 100644
--- a/src/armnnSerializer/SerializerUtils.cpp
+++ b/src/armnnSerializer/SerializerUtils.cpp
@@ -58,6 +58,8 @@ armnnSerializer::DataType GetFlatBufferDataType(armnn::DataType dataType)
return armnnSerializer::DataType::DataType_Signed32;
case armnn::DataType::QSymmS16:
return armnnSerializer::DataType::DataType_QSymmS16;
+ case armnn::DataType::QAsymmS8:
+ return armnnSerializer::DataType::DataType_QAsymmS8;
case armnn::DataType::QAsymmU8:
return armnnSerializer::DataType::DataType_QAsymmU8;
case armnn::DataType::Boolean:
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index dbd1158380..bb0c21ffba 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -994,6 +994,7 @@ void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
{
DataType::Float32,
DataType::Float16,
+ DataType::QAsymmS8,
DataType::QAsymmU8,
DataType::QSymmS16
};
@@ -1183,8 +1184,8 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co
std::vector<DataType> supportedTypes =
{
DataType::Float32,
- DataType::QAsymmU8,
DataType::QAsymmS8,
+ DataType::QAsymmU8,
DataType::QSymmS16,
DataType::QSymmS8,
DataType::Float16
diff --git a/src/backends/backendsCommon/test/WorkloadTestUtils.hpp b/src/backends/backendsCommon/test/WorkloadTestUtils.hpp
index 0b0f265db4..51683335e1 100644
--- a/src/backends/backendsCommon/test/WorkloadTestUtils.hpp
+++ b/src/backends/backendsCommon/test/WorkloadTestUtils.hpp
@@ -98,6 +98,8 @@ inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Option
case armnn::DataType::Float16:
case armnn::DataType::Float32:
return weightsType;
+ case armnn::DataType::QAsymmS8:
+ return armnn::DataType::Signed32;
case armnn::DataType::QAsymmU8:
return armnn::DataType::Signed32;
case armnn::DataType::QSymmS16:
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index bd2e7289d8..cb94955e7a 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -815,11 +815,12 @@ bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
bool supported = true;
// Define supported types.
- std::array<DataType,4> supportedTypes =
+ std::array<DataType,5> supportedTypes =
{
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
+ DataType::QAsymmS8,
DataType::QSymmS16
};
@@ -835,8 +836,29 @@ bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
"Reference Fully Connected: weights type not supported.");
- supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
- "Reference Fully Connected: input and weight types mismatched.");
+ ARMNN_NO_DEPRECATE_WARN_BEGIN
+ std::array<DataType, 3> supportedWeightTypes =
+ {
+ DataType::QAsymmU8,
+ DataType::QSymmS8,
+ DataType::QuantizedSymm8PerAxis // deprecated
+ };
+ ARMNN_NO_DEPRECATE_WARN_END
+
+ if (IsQuantized8BitType(input.GetDataType()))
+ {
+
+ supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
+ "Reference Fully Connected: weights type not supported for quantized input.");
+ }
+ else
+ {
+ supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
+ "Reference Fully Connected: weights is not a supported type.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
+ "Reference Fully Connected: input and weights types mismatched.");
+ }
if (descriptor.m_BiasEnabled)
{