From 885033b5bf2f6513b438f273b2bc71964f0c6c59 Mon Sep 17 00:00:00 2001
From: Tim Hall
Date: Thu, 21 Jul 2022 11:46:03 +0100
Subject: MLBEDSW-4157: Add RESIZE_NEAREST_NEIGHBOR support

- Changed ResizeBilinear to support ResizeNearestNeighbor as well for 1x1 IFM, IFM equal OFM, and non-align corners
- Added support for ResizeNearestNeighbor with align corners by converting to a DepthwiseConv
- Updated supported operator unit tests
- Added is_resize() helper function and some associated refactoring

Signed-off-by: Tim Hall
Change-Id: Id5bdf2a25e8aa6a4f28b7236250abf768141ce37
---
 ethosu/vela/tflite_graph_optimiser.py | 172 ++++++++++++++++++++++++----------
 1 file changed, 125 insertions(+), 47 deletions(-)

(limited to 'ethosu/vela/tflite_graph_optimiser.py')

diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index d2899c4c..ed8fa1e3 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -279,10 +279,9 @@ def fixup_conv2d_backprop(op, arch, nng):
 
 
 # Convert the op to an elementwise add
-def convert_resizebilinear_1x1_to_add(op):
-    op.type = Op.Add
+def convert_resize_1x1_to_add(op):
+    op.type = Op.Add  # original_type will stay as Op.ResizeBilinear or Op.ResizeNearestNeighbor
     op.name = op.name + "_add"
-    op.attrs["resizebilinear"] = True
     # Create an input tensor filled with zeros
     shape = op.ofm_shapes[0].as_list()
     tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
@@ -301,12 +300,103 @@ def convert_resizebilinear_1x1_to_add(op):
     return op
 
 
-# Convert ResizeBilinear to a number of 2x2 nearest neighbor upscaling and one avgpool op with kernel size dependent
-# on the upscaling factor. Avgpool kernel limit of 8x8 when padding is applied limits upscaling to 8x8.
-def convert_resizebilinear_to_upscale_and_average_pool(op):
+# Convert ResizeNearestNeighbor with align corners to a depthwise convolution. The IFM will already have been upscaled
+# apart from the final x2 scaling which will be done as part of this operation. The kernel contains a single coefficient
+# to select the appropriate nearest neighbor value
+def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
+    ifm = op.ifm
+    ofm = op.ofm
+    output_depth = ofm.shape[-1]
+    dw_op_attrs = {
+        "padding": Padding.VALID,
+        "stride_h": 1,
+        "stride_w": 1,
+        "strides": (1, 1, 1, 1),
+        "depth_multiplier": 1,
+        "channel_multiplier": 1,
+        "dilation_h_factor": 1,
+        "dilation_w_factor": 1,
+        "dilation": (1, 1, 1, 1),
+    }
+
+    # change the resize op to a depthwise conv
+    op.type = Op.DepthwiseConv2DBias
+    op.attrs.update(dw_op_attrs)
+    op.set_input_tensor(ifm, 0)  # ifm tensor index
+    op.activation = None
+
+    # add input resample to resize by x2
+    op.ifm_resampling_mode = resampling_mode.NEAREST
+
+    # don't care about the rounding mode as it is nearest neighbor
+
+    # setup weight tensor
+    weight_quant = QuantizationParameters()
+    weight_quant.scale_f32 = 1.0  # no scaling as only a single non-zero coeff to select the desired value
+    weight_quant.zero_point = 0
+    weight_quant.quant_dim = 0
+    ofm_dtype = ofm.dtype
+    if ofm_dtype == DataType.uint8:
+        weight_value_dtype = np.uint8
+        weight_quant.quant_min = 0
+        weight_quant.quant_max = (1 << ofm_dtype.bits) - 1
+    else:
+        if ofm_dtype == DataType.int8:
+            weight_value_dtype = np.int8
+        else:
+            assert ofm_dtype == DataType.int16
+            weight_value_dtype = np.int16
+
+        weight_quant.quant_min = -(1 << (ofm_dtype.bits - 1))
+        weight_quant.quant_max = (1 << (ofm_dtype.bits - 1)) - 1
+
+    weight_shape = [upscale_factor, upscale_factor, output_depth, output_depth]  # HWIO
+
+    # the single non-zero coefficient used to select the desired value needs to be placed in the 'centre value', which
+    # is calculated by finding the 'centre position' ('*' in the diagram below) and then choosing the 'value' that is
+    # below-and-right (i.e. next) to it (D).
+    # 0---1---2
+    # | A | B |
+    # 1---*---+
+    # | C | D |
+    # 2---+---+
+    weight_values = [0] * (upscale_factor * upscale_factor)
+    centre_coeff = (upscale_factor // 2) * upscale_factor + (upscale_factor // 2)
+    weight_values[centre_coeff] = 1
+
+    # add weight tensor, this will discard the size tensor of the resize op
+    op.set_input_tensor(
+        create_const_tensor(
+            "weights",
+            weight_shape,
+            ofm.dtype,
+            np.array(weight_values).reshape(weight_shape),
+            value_dtype=weight_value_dtype,
+            quantization=weight_quant,
+        ),
+        1,  # inputs tensor weight index
+    )
+
+    # set up the bias tensor by assigning None and then calling the fix-up function to create a suitable tensor.
+    # need to append the bias tensor as resize ops only have 2 inputs
+    assert len(op.inputs) == 2
+    op.inputs.append(None)
+    fixup_bias_tensors(op, None, None)
+
+    # finally update the shape in case we've changed the tensor shapes or connections
+    op.set_ifm_ofm_shapes()
+
+    return op
+
+
+# Convert ResizeBilinear/NearestNeighbor to a number of 1x1 average pools with nearest neighbor x2 upscaling and one
+# final average pool with a kernel size that depends upon the resize op's upscaling factor (x2, x4 or x8). The maximum
+# upscale factor is limited to x8 because of the 8x8 kernel size limit for average pool with padding.
+def convert_resize_to_upscale_and_average_pool(op):
     pre_op = op
     outputs = op.outputs
     dtype = op.ifm.dtype
+    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
     op.attrs["padding"] = Padding.SAME  # doesn't really matter as the kernel is 1x1
     op.ifm_resampling_mode = resampling_mode.NEAREST
 
@@ -321,14 +411,14 @@ def convert_resizebilinear_to_upscale_and_average_pool(op):
     # between different versions of numpy. This consistency ensures that the kernel dimensions are kept integral
     n = int(np.log2(upscale_factor))
 
-    # Perform 2x2 upscaling n-1 times
+    # Perform x2 upscaling n-1 times
     scaled_op = pre_op
     for count in range(n - 1):
         if count > 0:
             scaled_op = op.clone(f"_{count}")
             scaled_op.inputs[0] = pre_op.outputs[0]
 
-        # Nearest neighbor 2x2 upscaling
+        # Nearest neighbor x2 upscaling
         upscaled_shape = upscaled_shape * 2
         shape = op.ofm_shapes[0].as_list()
         shape[1:3] = upscaled_shape
@@ -339,17 +429,30 @@ def convert_resizebilinear_to_upscale_and_average_pool(op):
         scaled_op.set_ifm_ofm_shapes()
 
-    # Last 2x2 upscaling also applies avgpool with kernel size dependent on the upscaling factor and adds
-    # padding to the right and bottom.
+    # Last x2 upscaling
     if n > 1:
         scaled_op = op.clone(f"_{n-1}")
         scaled_op.inputs[0] = pre_op.outputs[0]
-        if op.attrs["align_corners"]:
-            scaled_op.attrs["padding"] = Padding.VALID
-        else:
-            scaled_op.attrs["padding"] = Padding.EXPLICIT
-            scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]
-        scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
+
+        if scaled_op.original_type == Op.ResizeBilinear:
+            if scaled_op.attrs["align_corners"]:
+                # no padding
+                scaled_op.attrs["padding"] = Padding.VALID
+            else:
+                # padding to the right and bottom (limits average pool to 8x8 kernel)
+                scaled_op.attrs["padding"] = Padding.EXPLICIT
+                scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]
+
+            # kernel size dependent on the upscaling factor
+            scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
+        else:  # Op.ResizeNearestNeighbor
+            if scaled_op.attrs["align_corners"]:
+                # use depthwise conv to select the correct value
+                scaled_op = convert_resizenn_ac_to_depthwise_conv(scaled_op, upscale_factor)
+            else:
+                # keep 1x1 kernel and average pool
+                pass
+
         scaled_op.outputs = outputs
         scaled_op.outputs[0].ops = [scaled_op]
         scaled_op.set_ifm_ofm_shapes()
 
@@ -357,16 +460,16 @@
     return op
 
 
-def fixup_resizebilinear(op, arch, nng):
-    if op.type == Op.ResizeBilinear and op.run_on_npu:
+def fixup_resize(op, arch, nng):
+    if op.type.is_resize_op() and op.run_on_npu:
         if op.ifm_shapes[0] == op.ofm_shapes[0]:
-            # Bypass nop resizebilinear
+            # Bypass the resize op which is essentially a NOP
             op.inputs = op.inputs[:1]
             op.type = Op.Identity
         elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
-            convert_resizebilinear_1x1_to_add(op)
+            convert_resize_1x1_to_add(op)
         else:
-            convert_resizebilinear_to_upscale_and_average_pool(op)
+            convert_resize_to_upscale_and_average_pool(op)
 
     return op
 
@@ -1130,31 +1233,6 @@ def convert_pad(op: Operation, arch, nng):
     return avgpool_op
 
 
-def add_attrs_to_resizebilinear(op, arch, nng):
-    if op.type == Op.ResizeBilinear and op.run_on_npu:
-        input_shape = op.ifm_shapes[0]
-        upscaled_height = input_shape.height * 2
-        upscaled_width = input_shape.width * 2
-        out_shape = op.ofm_shapes[0]
-        if not op.attrs["align_corners"] and out_shape.height == upscaled_height and out_shape.width == upscaled_width:
-            # this means the output is supposed to be a x2 upscale,
-            # so we need to do SAME padding
-            op.attrs["padding"] = Padding.SAME
-        elif (
-            op.attrs["align_corners"]
-            and out_shape.height == (upscaled_height - 1)
-            and out_shape.width == (upscaled_width - 1)
-        ):
-            # here we can just run the avg pool without padding and
-            # produce a (M * 2 - 1, N * 2 - 1) sized output
-            op.attrs["padding"] = Padding.VALID
-        else:
-            return op
-    op.ifm_resampling_mode = resampling_mode.NEAREST
-    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
-    return op
-
-
 def fixup_bias_tensors(op, arch, nng):
     if op.type.needs_bias() and op.bias is None:
         # Op has no bias, add bias tensor filled with zeros
@@ -1577,7 +1655,7 @@ def tflite_optimise_graph(nng, arch):
         fixup_conv2d_backprop,
         fixup_relus_with_differing_ifm_ofm_scaling,
         reorder_depthwise_weights,
-        fixup_resizebilinear,
+        fixup_resize,
         fixup_bias_tensors,
         fixup_asymmetric_weights,
         convert_mul_max_to_abs_or_lrelu,
--
cgit v1.2.1
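
For readers who want to sanity-check the centre-coefficient placement used by convert_resizenn_ac_to_depthwise_conv, the short NumPy sketch below reproduces just that part of the logic in isolation. It is an illustrative, standalone snippet and not code from the patch: the helper name centre_coefficient_kernel is hypothetical, and the per-channel HWIO weight shape, quantization and bias handling set up by the patch are deliberately omitted.

# Illustrative sketch only: mirrors the weight_values/centre_coeff arithmetic from the patch.
import numpy as np


def centre_coefficient_kernel(upscale_factor: int) -> np.ndarray:
    # Kernel of zeros with a single 1 placed just below-and-right of the kernel centre
    # (position 'D' in the patch's diagram), so a convolution with this kernel selects
    # one nearest-neighbor sample from the x2-upscaled IFM.
    weight_values = np.zeros(upscale_factor * upscale_factor, dtype=np.int8)
    centre_coeff = (upscale_factor // 2) * upscale_factor + (upscale_factor // 2)
    weight_values[centre_coeff] = 1
    return weight_values.reshape(upscale_factor, upscale_factor)


if __name__ == "__main__":
    for factor in (2, 4, 8):  # the upscale factors the conversion supports (x2, x4 or x8)
        print(f"upscale_factor={factor}")
        print(centre_coefficient_kernel(factor))

For upscale_factor=2 this prints a 2x2 kernel with the 1 in the bottom-right position, which matches the 'D' cell in the diagram carried in the patch comment.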