From 3e7157ba59f12aa0d277a9b3a7cb3f8a19267338 Mon Sep 17 00:00:00 2001
From: Raul Farkas
Date: Tue, 9 May 2023 09:09:17 +0100
Subject: MLBEDSW-7315: Add support for AvgPool with stride_width > 3

* Convert AvgPool with stride_width > 3 and Valid padding to Conv2D to
  optimize it to run on NPU.

Change-Id: I06ab412357f0b09b1498f9019a9d1963a324ad34
Signed-off-by: Raul Farkas
---
 SUPPORTED_OPS.md                            |  3 +-
 ethosu/vela/high_level_command_to_npu_op.py |  8 ++++
 ethosu/vela/operation.py                    |  2 +-
 ethosu/vela/tflite_graph_optimiser.py       | 59 ++++++++++++++++++++++++++++-
 ethosu/vela/tflite_supported_operators.py   | 21 ++++++++--
 5 files changed, 85 insertions(+), 8 deletions(-)

diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md
index 947b585a..fdceb43c 100644
--- a/SUPPORTED_OPS.md
+++ b/SUPPORTED_OPS.md
@@ -134,7 +134,8 @@ This is a list of constraints that the AVERAGE_POOL_2D operator must satisfy in
 - Stride values for both width and height must be integer types
 - IFM and OFM data types must match
 - Kernel filter values for both width and height must be integer types
-- Stride values for both width and height must be in the range [1, 3]
+- Stride width must be greater than or equal to 1.
+  For stride width greater than 3, valid padding needs to be used.
 - Kernel filter values for both width and height must be in the range [1, 8]
 - VALID padding: Kernel filter height must be in the range [1, 256]
 - VALID padding: Product of kernel filter width and height must be in the range [1, 65536]
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 9526bd50..79ac3929 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -308,6 +308,14 @@ def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
     if tens.dtype == DataType.int32 and is_ifm_tensor:
         return True
     if ps.primary_op.rounding_mode == RoundingMode.AwayZero:
+        if (
+            ps.primary_op.original_type == Op.AvgPool
+            and ps.primary_op.type == Op.Conv2DBias
+            and ps.primary_op.attrs.get("padding", None) == Padding.VALID
+        ):
+            # Force zero point to 0 for AveragePool operators converted to a Conv2DBias with rounding away from
+            # zero.
+            return True
         if ps.primary_op.original_type == Op.ResizeBilinear and ps.primary_op.type == Op.DepthwiseConv2DBias:
             # Force zero point to 0 for ResizeBilinear operators converted to a DepthwiseConv with rounding away from
             # zero. This is because the reference kernel ignores the zero points.
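The SUPPORTED_OPS.md wording above is terse, so here is a minimal standalone sketch of the new AVERAGE_POOL_2D stride rule. It is illustrative only and not Vela code: the function name and the stride_w/stride_h/padding parameters are assumptions of this sketch, and the real checks are the constraint_stride_width_no_upper_limit and constraint_stride_range_no_padding methods added further down in this patch.

# Illustrative sketch of the new AVERAGE_POOL_2D stride rule (not part of the patch).
def avg_pool_stride_supported(stride_w: int, stride_h: int, padding: str) -> bool:
    # Stride width must be >= 1 and stride height must stay in [1, 3].
    if stride_w < 1 or not 1 <= stride_h <= 3:
        return False
    # Stride widths up to 3 run natively; wider strides are only accepted with
    # VALID padding, because the operator is then converted to a Conv2D.
    return stride_w <= 3 or padding == "VALID"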
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 3685c5ae..998d94ff 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -615,7 +615,7 @@ class Operation:
         is_supported = False
         if self.original_type == Op.ResizeBilinear and self.type == Op.DepthwiseConv2DBias:
             is_supported = True
-        if self.original_type == Op.AvgPool and self.type == Op.DepthwiseConv2DBias:
+        if self.original_type == Op.AvgPool and self.type in (Op.DepthwiseConv2DBias, Op.Conv2DBias):
             is_supported = True
 
         if is_supported:
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 76383a4b..78906374 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -949,7 +949,57 @@ def reorder_depthwise_weights(op, arch, nng):
     return op
 
 
-def fixup_strided_conv(op: Operation, arch, nng) -> Operation:
+def convert_avg_pool_to_conv2d(op: Operation, arch, nng) -> Operation:
+    """Convert strided Average Pools with stride >= 4 to Conv2D."""
+    if op.type != Op.AvgPool:
+        return op
+
+    stride_x, stride_y = op.get_kernel_stride()
+    # For strides <= 3 no optimization is needed
+    if stride_x <= 3:
+        return op
+    h, w = op.attrs["filter_height"], op.attrs["filter_width"]
+    inputs = op.inputs[0]
+    shape = inputs.shape
+
+    # Set necessary conv2d attributes
+    op.attrs.update(
+        {
+            "stride_h": stride_y,
+            "stride_w": stride_x,
+            "dilation_h_factor": 1,
+            "dilation_w_factor": 1,
+            "strides": (1, stride_y, stride_x, 1),
+            "dilation": (1, 1, 1, 1),
+        }
+    )
+
+    # Change op type
+    op.type = Op.Conv2DBias
+    op.name += "_conv2d"
+
+    op.rounding_mode = RoundingMode.AwayZero
+    shape = [h, w, 1, op.ofm.shape[-1]]
+    weights = np.full(shape, 1)
+    quant = QuantizationParameters(scale_f32=1 / (h * w), zero_point=0)
+    # Add unit weight tensor
+    op.add_input_tensor(
+        create_const_tensor(
+            "weights",
+            shape,
+            inputs.dtype,
+            weights,
+            quantization=quant,
+        ),
+    )
+    op.weights.values = np.reshape(op.inputs[1].values, shape)
+
+    # Set IFM/OFM shapes after changing op type
+    op.set_ifm_ofm_shapes()
+    return op
+
+
+def fixup_strided_conv(op: Operation, arch, nng):
     """Optimize or fixup strided Conv2DBias
 
     Optimization: Reduce, when possible, the Conv2DBias stride from N with 1 > N > 4 to 1
@@ -1853,7 +1903,11 @@ def fixup_bias_tensors(op, arch, nng, dtype=None):
         if dtype is None:
             dtype = DataType.int64 if op.ifm.dtype == DataType.int16 else DataType.int32
         bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], dtype, bias_values)
-        op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
+        bias_index = op.type.info.indices.biases[0]
+        if bias_index < len(op.inputs):
+            op.set_input_tensor(bias_tensor, bias_index)
+        else:
+            op.add_input_tensor(bias_tensor)
 
     return op
 
@@ -2349,6 +2403,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
         convert_prelu,
         convert_mul_max_to_abs_or_lrelu,
         convert_lrelu,
+        convert_avg_pool_to_conv2d,
         fixup_strided_conv,
         convert_hardswish_to_lut,
         rewrite_fully_connected_input,
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 25b68970..a24eebc5 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -220,7 +220,7 @@ class TFLiteSupportedOperators:
 
         # Conv specific ops:
         for op_type in TFLiteSupportedOperators.convolution_ops:
-            self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_conv_stride)
+            self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_stride_width_no_upper_limit)
 
         # Conv-like checks:
         for op_type in TFLiteSupportedOperators.convolution_like_ops:
@@ -244,10 +244,11 @@ class TFLiteSupportedOperators:
            self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_depth_multiplier)
 
         # Pooling checks:
-        for op_type in TFLiteSupportedOperators.pooling_ops:
+        for op_type in TFLiteSupportedOperators.pooling_ops - TFLiteSupportedOperators.avg_pooling_ops:
             self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_stride_range)
         # AVG pooling specific checks:
         for op_type in TFLiteSupportedOperators.avg_pooling_ops:
+            self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_stride_range_no_padding)
             self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_filter_range)
             self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_filter_height_range_valid_pad)
             self.specific_constraints[op_type].append(
@@ -545,7 +546,7 @@ class TFLiteSupportedOperators:
         return True, "Op has depth_multiplier=1"
 
     @staticmethod
-    def constraint_conv_stride(op):
+    def constraint_stride_width_no_upper_limit(op):
         """Stride width must be greater than or equal to 1.
         For stride widths greater than 3, the post-optimization stride needs to be less than or equal to 3.
         Stride height must be between 1 and 3."""
@@ -560,6 +561,17 @@ class TFLiteSupportedOperators:
 
         return valid, f"Op has stride WxH as: {w}x{h}"
 
+    @staticmethod
+    def constraint_stride_range_no_padding(op):
+        """Stride width must be greater than or equal to 1.
+        For stride width greater than 3, valid padding needs to be used."""
+        w, _ = op.get_kernel_stride()
+        valid, message = TFLiteSupportedOperators.constraint_stride_width_no_upper_limit(op)
+        padding = op.attrs.get("padding", None)
+        is_optimized_with_valid_padding = padding in (None, Padding.VALID) or w <= 3
+        valid = valid and is_optimized_with_valid_padding
+        return valid, f"{message}, padding: {padding}"
+
     @staticmethod
     def constraint_depthwise_conv_stride(op):
         "Stride values for both width and height must be between 1 and 3"
@@ -614,10 +626,11 @@ class TFLiteSupportedOperators:
     def constraint_filter_range(cls, op):
         "Kernel filter values for both width and height must be in the range [{}, {}]"
         if op.attrs["padding"] == Padding.SAME:
+            sw, _ = op.get_kernel_stride()
             w = op.kernel.width
             h = op.kernel.height
             filter_min, filter_max = cls.filter_range
-            valid = (filter_min <= w <= filter_max) and (filter_min <= h <= filter_max)
+            valid = ((filter_min <= w <= filter_max) or sw == w) and (filter_min <= h <= filter_max)
             return valid, f"Op has kernel filter WxH as: {w}x{h}"
         return True, "Op has padding=VALID"
 
-- 
cgit v1.2.1
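To see why the conversion in convert_avg_pool_to_conv2d is value-preserving, the NumPy sketch below compares a VALID-padded average pool against a convolution with all-ones weights and a 1 / (filter_height * filter_width) output scale, which mirrors the unit weight tensor and scale_f32 the patch creates. It is a single-channel, float-only illustration with helper names invented for this sketch, not Vela or TFLite code.

import numpy as np


def avg_pool_valid(x, kh, kw, sh, sw):
    # Reference: plain average pooling with VALID padding on a 2D array.
    oh = (x.shape[0] - kh) // sh + 1
    ow = (x.shape[1] - kw) // sw + 1
    return np.array(
        [[x[i * sh : i * sh + kh, j * sw : j * sw + kw].mean() for j in range(ow)] for i in range(oh)]
    )


def conv_unit_weights(x, kh, kw, sh, sw):
    # Same pooling expressed as a convolution with all-ones weights plus a
    # 1 / (kh * kw) output scale, i.e. the shape of the Conv2DBias the patch builds.
    weights = np.ones((kh, kw))
    oh = (x.shape[0] - kh) // sh + 1
    ow = (x.shape[1] - kw) // sw + 1
    out = np.array(
        [[np.sum(x[i * sh : i * sh + kh, j * sw : j * sw + kw] * weights) for j in range(ow)] for i in range(oh)]
    )
    return out / (kh * kw)


x = np.arange(8 * 16, dtype=np.float32).reshape(8, 16)
# stride_w = 4 (> 3) with VALID padding: both formulations give the same result.
assert np.allclose(avg_pool_valid(x, 2, 2, 1, 4), conv_unit_weights(x, 2, 2, 1, 4))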