From 9358296a51b9186335304a53bd7ea5dfbe5322d8 Mon Sep 17 00:00:00 2001
From: Tim Hall
Date: Wed, 9 Sep 2020 21:58:15 +0100
Subject: vela: Improve the scaling is equal check

 - Fixed and documented both tensor and quant params scaling checks
 - Added quant params validity check and tensor quantisation check
 - Added valid tensor checks to some graph optimisation functions

Signed-off-by: Tim Hall
Change-Id: I8d6e8f03a603d28886dde511672c8399c85b794c
---
 ethosu/vela/graph_optimiser.py     | 30 ++++++++++++++++++--------
 ethosu/vela/supported_operators.py | 21 ++++++++++++-------
 ethosu/vela/tensor.py              | 43 ++++++++++++++++++++++++++++++++++----
 3 files changed, 74 insertions(+), 20 deletions(-)

diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index f6209ed2..4696446c 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -35,6 +35,7 @@ from .operation import NpuBlockType
 from .operation import Op
 from .operation import Operation
 from .softmax import SoftMax
+from .tensor import check_quantized_tens_scaling_equal
 from .tensor import create_const_tensor
 from .tensor import create_reshape_tensor
 from .tensor import QuantizationParameters
@@ -341,7 +342,7 @@ def convert_batched_fc_to_conv(op, arch, nng):
                 # There is a preceding Reshape
                 # Compare input of prev_op and input of op, to see if prev_op can be removed
                 ifm_prev_op = prev_op.inputs[0]
-                if ifm_prev_op.shape == ifm.shape and ifm_prev_op.quantization.is_scaling_equal(ifm.quantization):
+                if ifm_prev_op.shape == ifm.shape and check_quantized_tens_scaling_equal(ifm_prev_op, ifm):
                     # prev_op can be removed
                     op.set_input_tensor(ifm_prev_op, 0)
                 else:
@@ -369,7 +370,7 @@ def convert_batched_fc_to_conv(op, arch, nng):
                 # There is a subsequent Reshape
                 # Compare desired shape and output of consumer op, to see if consumer op can be removed
                 ofm_cons_op = ofm.consumer_list[0].outputs[0]
-                if desired_shape == ofm_cons_op.shape and ofm.quantization.is_scaling_equal(ofm_cons_op.quantization):
+                if desired_shape == ofm_cons_op.shape and check_quantized_tens_scaling_equal(ofm, ofm_cons_op):
                     op.outputs[0] = ofm_cons_op
                     op.outputs[0].ops = [op]
                 else:
@@ -613,7 +614,7 @@ def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
         ofm = op.outputs[0]
         # Relu with differing IFM and OFM scaling cannot be fused with another primary op
         # and requires its own to be inserted
-        if not ifm.is_scaling_equal(ofm):
+        if not check_quantized_tens_scaling_equal(ifm, ofm):
             # Override this op with its own primary op (avgpool)
             relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
             # And fuse the original activation function to it
@@ -727,9 +728,12 @@ def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
         if mul.activation:
             return op
         ifm, ofm = op.get_ifm_ofm()
+        if ifm is None or ofm is None:
+            return op
+
         if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
             return op
-        if not ifm.is_scaling_equal(ofm) or not ifm.is_scaling_equal(mul_ofm):
+        if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
             # rewrite to LeakyRelu currently only makes sense if the quantization is identical
             return op

@@ -780,6 +784,8 @@ def convert_lrelu_to_mul_max(op, arch):
     # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
     # (the opposite of convert_mul_max_to_abs_or_lrelu)
     ifm, ofm = op.get_ifm_ofm()
+    if ifm is None or ofm is None:
+        return op

     # Add multiplication with alpha
     mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
@@ -796,7 +802,7 @@ def convert_lrelu_to_mul_max(op, arch):
     fm_alpha = ofm.clone(op.name + "_alpha")
     mul_alpha.set_output_tensor(fm_alpha)

-    if ifm.is_scaling_equal(ofm):
+    if check_quantized_tens_scaling_equal(ifm, ofm):
         # No identity multiplication is needed
         fm_id = ifm
     else:
@@ -829,6 +835,8 @@ def convert_lrelu_to_mul_max(op, arch):
 def convert_to_lut(op, lut_values, lut_name):
     # Rewrite the operation by Add with scalar 0 + LUT activation
     ifm = op.inputs[0]
+    if ifm is None:
+        return op
     assert ifm.dtype.size_in_bytes() == 1
     op.type = Op.Add
     op.name = op.name + "_lut_" + lut_name
@@ -908,10 +916,12 @@ def convert_lrelu(op, arch, nng):
     if op.type != Op.LeakyRelu:
         return op
     ifm, ofm = op.get_ifm_ofm()
+    if ifm is None or ofm is None:
+        return op
     if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
         # use LUT for int8/uint8
         return convert_lrelu_to_lut(op, arch)
-    if ifm.is_scaling_equal(ofm) and ifm.dtype == ofm.dtype and ifm.dtype == DataType.int16:
+    if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16:
         # use LeakyRelu unmodified for int16 with equal input/output scaling
         return op
     return convert_lrelu_to_mul_max(op, arch)
@@ -953,9 +963,9 @@ def remove_unwanted_reshapes(op, arch, nng):
             cons_op_ofm = cons_op.outputs[0]
             if len(prev_op_ifm.shape) == len(cons_op_ofm.shape):
                 # Check if quantization is the same in the input and output for the reshape ops
-                if prev_op_ifm.quantization.is_scaling_equal(
-                    prev_op_ofm.quantization
-                ) and cons_op_ifm.quantization.is_scaling_equal(cons_op_ofm.quantization):
+                if check_quantized_tens_scaling_equal(prev_op_ifm, prev_op_ofm) and check_quantized_tens_scaling_equal(
+                    cons_op_ifm, cons_op_ofm
+                ):
                     op.set_input_tensor(prev_op_ifm, 0)
                     op.set_output_tensor(cons_op_ofm)
                     return op
@@ -966,6 +976,8 @@ def fuse_activation_function_with_prev(op, arch, nng):
     if not op.attrs.get("is_nop", False) or op.activation is None:
         return op
     ifm, ofm = op.get_ifm_ofm()
+    if ifm is None or ofm is None:
+        return op
     # finds the input(s) to the operation
     prev_op = ifm.ops[0]
     # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index f4dd5796..dfb7bc7d 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -24,6 +24,8 @@ from .data_type import DataType
 from .numeric_util import is_integer
 from .operation import get_slice_offsets
 from .operation import Op
+from .tensor import check_quantized_tens_scaling_equal
+from .tensor import check_tens_quantized


 # Custom decorator function to allow formatting docstrings containing "{}"
@@ -730,17 +732,22 @@ class SupportedOperators:

     @classmethod
     def check_quantization_restrictions_binary_elem_wise(cls, op):
-        # makes sure IFM1, IFM2 and OFM quantization are equal for binary ops
+        # checks that IFM1, IFM2 and OFM quantization are equal for binary ops
+        assert len(op.inputs) >= 2 and len(op.outputs) == 1
+
         if (
-            op.inputs[0].quantization is None
-            or not op.inputs[0].is_scaling_equal(op.inputs[1])
-            or not op.inputs[0].is_scaling_equal(op.outputs[0])
+            not check_tens_quantized(op.inputs[0])
+            or not check_tens_quantized(op.inputs[1])
+            or not check_tens_quantized(op.outputs[0])
         ):
-            print(
-                "Warning: Input/output tensors with different quantization is unsupported for the", op.type, "operator"
-            )
+            warn_cpu(op, "has non-quantised input and/or output tensors")
             return False
+
+        if not check_quantized_tens_scaling_equal(op.inputs[0], op.inputs[1]) or not check_quantized_tens_scaling_equal(
+            op.inputs[0], op.outputs[0]
+        ):
+            warn_cpu(op, "has input/output tensors with different quantisation which is illegal")
+            return False
+
         return True
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 98dfa3d3..84af8edb 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -23,6 +23,7 @@ from functools import lru_cache
 import numpy as np

 from . import numeric_util
+from .data_type import BaseType
 from .data_type import DataType
 from .ethos_u55_regs.ethos_u55_regs import resampling_mode
 from .operation import Op
@@ -229,11 +230,22 @@ class QuantizationParameters:
         return res

     def is_scaling_equal(self, other):
-        if other is None or not isinstance(other, QuantizationParameters):
+        # quantisation parameter scaling is not equal if 'other' is None because
+        # it implies that the tensor it belongs to is not quantised. otherwise,
+        # it depends upon whether the scale and zero point are equal
+
+        if other is None:
             return False

+        assert isinstance(other, QuantizationParameters)
+
         return self.scale_f32 == other.scale_f32 and self.zero_point == other.zero_point

+    def is_valid(self):
+        # quantisation parameters are considered valid if they have a scale and zero point
+
+        return None not in (self.scale_f32, self.zero_point)
+

 def create_const_tensor(name, shape, dtype, values, value_dtype=None, purpose=TensorPurpose.Unknown, quantization=None):
     # Tensor
@@ -765,9 +777,6 @@ class Tensor:
             return True
         return False

-    def is_scaling_equal(self, tens):
-        return self.quantization.is_scaling_equal(tens.quantization)
-
     def equivalent(self, tens):
         return self.equivalence_id == tens.equivalence_id

@@ -785,7 +794,33 @@ class Tensor:
         else:
             return self.shape.copy()

+    def is_quantized(self):
+        # a tensor is quantized if it has an integral type and it contains valid quantization params
+
+        if (self.dtype.type & BaseType.Int) == 0 or self.quantization is None:
+            return False
+
+        assert isinstance(self.quantization, QuantizationParameters)
+        assert self.quantization.is_valid()
+
+        return True
+
     def __str__(self):
         return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype)

     __repr__ = __str__
+
+
+def check_tens_quantized(tens):
+    # checks that a tensor is quantized
+
+    return isinstance(tens, Tensor) and tens.is_quantized()
+
+
+def check_quantized_tens_scaling_equal(tens_a, tens_b):
+    # checks that the scaling of two quantized tensors is equal
+
+    assert check_tens_quantized(tens_a)
+    assert check_tens_quantized(tens_b)
+
+    return tens_a.quantization.is_scaling_equal(tens_b.quantization)
--
cgit v1.2.1
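
Usage note (not part of the patch): the sketch below shows how the new tensor.py helpers are intended to be called once this change is applied. It is a minimal illustration only; the Tensor(shape, dtype, name) construction and the direct assignment of scale_f32/zero_point are assumptions about the surrounding vela code rather than anything this diff introduces.

    from ethosu.vela.data_type import DataType
    from ethosu.vela.tensor import check_quantized_tens_scaling_equal
    from ethosu.vela.tensor import check_tens_quantized
    from ethosu.vela.tensor import QuantizationParameters
    from ethosu.vela.tensor import Tensor

    # Two uint8 feature maps that share the same scale and zero point
    quant_a = QuantizationParameters()
    quant_a.scale_f32 = 0.5
    quant_a.zero_point = 128

    quant_b = QuantizationParameters()
    quant_b.scale_f32 = 0.5
    quant_b.zero_point = 128

    ifm = Tensor([1, 8, 8, 4], DataType.uint8, "ifm")
    ifm.quantization = quant_a

    ofm = Tensor([1, 8, 8, 4], DataType.uint8, "ofm")
    ofm.quantization = quant_b

    # is_quantized() requires an integral dtype and valid (scale, zero point) params
    assert check_tens_quantized(ifm)
    # scaling compares equal because both scale_f32 and zero_point match
    assert check_quantized_tens_scaling_equal(ifm, ofm)

    # A tensor with no quantisation parameters is simply reported as not quantised,
    # instead of failing on a None quantization attribute as the old per-tensor
    # is_scaling_equal() could
    bias = Tensor([4], DataType.float32, "bias")
    assert not check_tens_quantized(bias)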