diff options
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 59 |
1 files changed, 57 insertions, 2 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 76383a4b..78906374 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -949,7 +949,57 @@ def reorder_depthwise_weights(op, arch, nng): return op -def fixup_strided_conv(op: Operation, arch, nng) -> Operation: +def convert_avg_pool_to_conv2d(op: Operation, arch, nng) -> Operation: + """Convert strided Average Pools with stride >= 4 to Conv2D.""" + if op.type != Op.AvgPool: + return op + + stride_x, stride_y = op.get_kernel_stride() + # For strides <= 3 no optimization is needed + if stride_x <= 3: + return op + h, w = op.attrs["filter_height"], op.attrs["filter_width"] + inputs = op.inputs[0] + shape = inputs.shape + + # Set necessary conv2d attributes + op.attrs.update( + { + "stride_h": stride_y, + "stride_w": stride_x, + "dilation_h_factor": 1, + "dilation_w_factor": 1, + "strides": (1, stride_y, stride_x, 1), + "dilation": (1, 1, 1, 1), + } + ) + + # Change op type + op.type = Op.Conv2DBias + op.name += "_conv2d" + + op.rounding_mode = RoundingMode.AwayZero + shape = [h, w, 1, op.ofm.shape[-1]] + weights = np.full(shape, 1) + quant = QuantizationParameters(scale_f32=1 / (h * w), zero_point=0) + # Add unit weight tensor + op.add_input_tensor( + create_const_tensor( + "weights", + shape, + inputs.dtype, + weights, + quantization=quant, + ), + ) + op.weights.values = np.reshape(op.inputs[1].values, shape) + + # Set IFM/OFM shapes after changing op type + op.set_ifm_ofm_shapes() + return op + + +def fixup_strided_conv(op: Operation, arch, nng): """Optimize or fixup strided Conv2DBias Optimization: Reduce, when possible, the Conv2DBias stride from N with 1 > N > 4 to 1 @@ -1853,7 +1903,11 @@ def fixup_bias_tensors(op, arch, nng, dtype=None): if dtype is None: dtype = DataType.int64 if op.ifm.dtype == DataType.int16 else DataType.int32 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], dtype, bias_values) - op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0]) + bias_index = op.type.info.indices.biases[0] + if bias_index < len(op.inputs): + op.set_input_tensor(bias_tensor, bias_index) + else: + op.add_input_tensor(bias_tensor) return op @@ -2349,6 +2403,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights): convert_prelu, convert_mul_max_to_abs_or_lrelu, convert_lrelu, + convert_avg_pool_to_conv2d, fixup_strided_conv, convert_hardswish_to_lut, rewrite_fully_connected_input, |