diff options
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 21 |
1 files changed, 12 insertions, 9 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 44f5d6ae..518b6db0 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -942,19 +942,21 @@ def reorder_depthwise_weights(op, arch, nng): return op -def fixup_strided_conv(op, arch, nng): +def fixup_strided_conv(op: Operation, arch, nng): + """Optimize or fixup strided Conv2DBias + Optimization: + Reduce, when possible, the Conv2DBias stride from 2 to 1 by re-shaping + both IFM and filter. + + Fixup: + Introduce software support for Conv2DBias with stride_width = 4 by + reducing it to 1 when possible by re-shaping both IFM and filter. + """ if op.type != Op.Conv2DBias: return op stride_x, stride_y = op.get_kernel_stride() weight_tensor = op.weights ifm_shape = op.ifm_shapes[0] - - # Do not optimize if op is not the first in the network and stride is - # supported by the hardware - if op.op_index != 0 and stride_x < 4: - return op - op.ifm.needs_linear_format = True - if ( (stride_x == 2 or stride_x == 4) and ifm_shape.depth <= 4 @@ -1004,6 +1006,7 @@ def fixup_strided_conv(op, arch, nng): stride_x = 1 op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)}) + op.ifm.force_linear_format = True return op @@ -2125,7 +2128,6 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights): convert_prelu, convert_mul_max_to_abs_or_lrelu, convert_lrelu, - fixup_strided_conv, convert_hardswish_to_lut, rewrite_fully_connected_input, convert_batched_fc_shape, @@ -2139,6 +2141,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights): convert_tanh_sigmoid_to_lut, replace_pad_by_hw_pad, fixup_dilation_gt2, + fixup_strided_conv, ] for idx, sg in enumerate(nng.subgraphs): |