diff options
author | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-11-05 18:00:21 +0000 |
---|---|---|
committer | Francis Murtagh <francis.murtagh@arm.com> | 2019-11-06 12:10:02 +0000 |
commit | 5edc8816118fcddb2681379db04c978041ce8b46 (patch) | |
tree | 22e4382138e9963d0ed3dacefda4fb142877e1fc /src/backends/backendsCommon/WorkloadData.cpp | |
parent | ec33a91ec1557b78b2d01975ec4c5eaf24aa058c (diff) | |
download | armnn-5edc8816118fcddb2681379db04c978041ce8b46.tar.gz |
IVGCVSW-3837 Add support for per-axis quantization to reference Convolution2d workload
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: I0ac08ba4864d48e6f64c4ac645dad8ea850be112
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 109 |
1 file changed, 106 insertions(+), 3 deletions(-)
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index e1a369af7c..201cc7d1ec 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -338,6 +338,102 @@ void ValidateTensorNumElementsMatch(const TensorInfo& first, } } +void ValidateWeightDataType(const TensorInfo& inputInfo, + const TensorInfo& weightInfo, + const std::string& descName) +{ + const DataType inputType = inputInfo.GetDataType(); + if (inputType == DataType::QuantisedAsymm8) + { + const std::vector<DataType> validTypes = + { + DataType::QuantisedAsymm8, + DataType::QuantizedSymm8PerAxis + }; + + ValidateDataTypes(weightInfo, validTypes, descName); + } + else + { + ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight"); + } +} + +void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo, + const std::string& descName, + const std::string& tensorName) +{ + const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim(); + if (!quantizationDim.has_value()) + { + throw InvalidArgumentException(boost::str( + boost::format("%1%: Quantization dimension for per-axis quantization not set on tensor %2%.") + % descName % tensorName)); + } + + if (quantizationDim.value() != 0) + { + throw InvalidArgumentException(boost::str( + boost::format("%1%: Quantization dimension for per-axis quantization expected to be 0 on tensor %2%, " + "but got: %3%") % descName % tensorName % quantizationDim.value())); + } +} + +void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo, + const std::string& descName, + const std::string& tensorName) +{ + int32_t quantizationOffset = tensorInfo.GetQuantizationOffset(); + if (quantizationOffset != 0) + { + throw InvalidArgumentException(boost::str( + boost::format("%1%: Quantization offset for per-axis quantization expected to be 0 on tensor %2%, " + "but got: %3%") % descName % tensorName % quantizationOffset)); 
+ } +} + +void ValidatePerAxisQuantization(const TensorInfo& inputInfo, + const TensorInfo& outputInfo, + const TensorInfo& weightInfo, + const Optional<TensorInfo>& optionalBiasInfo, + const std::string& descName) +{ + if (weightInfo.HasPerAxisQuantization()) + { + const DataType inputDataType = inputInfo.GetDataType(); + const DataType outputDataType = outputInfo.GetDataType(); + + const bool canHavePerAxisQuantization = + inputDataType == DataType::QuantisedAsymm8 && inputDataType == outputDataType; + + if (!canHavePerAxisQuantization) + { + throw InvalidArgumentException(boost::str( + boost::format("%1%: Per-axis quantization parameters set on tensor %2%, " + "but data type does not support per-axis quantization.") % descName % "weight")); + } + + ValidateTensorDataType(weightInfo, DataType::QuantizedSymm8PerAxis, descName, "weight"); + ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight"); + ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight"); + + if (optionalBiasInfo.has_value()) + { + const TensorInfo& biasInfo = optionalBiasInfo.value(); + if (!biasInfo.HasPerAxisQuantization()) + { + throw InvalidArgumentException(boost::str( + boost::format("%1%: Per-axis quantization parameters not set on bias tensor, despite being set on " + "weight tensor.") % descName)); + } + + ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias"); + ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias"); + ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias"); + } + } +} + } // anonymous namespace void QueueDescriptor::ValidateInputsOutputs(const std::string& descName, @@ -1040,19 +1136,26 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo(); ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight"); - ValidateTensorDataTypesMatch(inputTensorInfo, weightTensorInfo, descriptorName, "input", "weight"); + 
ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName); + Optional<TensorInfo> optionalBiasTensorInfo; if (m_Parameters.m_BiasEnabled) { ValidatePointer(m_Bias, descriptorName, "bias"); - const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo(); - ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias"); + optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo()); + const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value(); ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias"); ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName); } + ValidatePerAxisQuantization(inputTensorInfo, + outputTensorInfo, + weightTensorInfo, + optionalBiasTensorInfo, + descriptorName); + std::vector<DataType> supportedTypes = { DataType::Float32, |