diff options
author | Raul Farkas <raul.farkas@arm.com> | 2023-03-16 16:38:05 +0000 |
---|---|---|
committer | Raul Farkas <raul.farkas@arm.com> | 2023-03-27 16:35:56 +0100 |
commit | 72c6a2414205e033279f80b622cdf479c05a4f5b (patch) | |
tree | 35dedce67cedd2fe5533cf0beb2942a7f31199e3 /ethosu/vela/tflite_graph_optimiser.py | |
parent | 430002df36f79d035e31e8304fb8b176129cd3cc (diff) | |
download | ethos-u-vela-72c6a2414205e033279f80b622cdf479c05a4f5b.tar.gz |
MLBEDSW-6343: Remove op_index constraint
Remove op_index constraint and force linear format for all Conv2D that
have strides that can be optimised.
Change-Id: Idef3508ab074ea9abeacac030eaaa15a00ad1211
Signed-off-by: Raul Farkas <raul.farkas@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 21 |
1 files changed, 12 insertions, 9 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 44f5d6ae..518b6db0 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -942,19 +942,21 @@ def reorder_depthwise_weights(op, arch, nng): return op -def fixup_strided_conv(op, arch, nng): +def fixup_strided_conv(op: Operation, arch, nng): + """Optimize or fixup strided Conv2DBias + Optimization: + Reduce, when possible, the Conv2DBias stride from 2 to 1 by re-shaping + both IFM and filter. + + Fixup: + Introduce software support for Conv2DBias with stride_width = 4 by + reducing it to 1 when possible by re-shaping both IFM and filter. + """ if op.type != Op.Conv2DBias: return op stride_x, stride_y = op.get_kernel_stride() weight_tensor = op.weights ifm_shape = op.ifm_shapes[0] - - # Do not optimize if op is not the first in the network and stride is - # supported by the hardware - if op.op_index != 0 and stride_x < 4: - return op - op.ifm.needs_linear_format = True - if ( (stride_x == 2 or stride_x == 4) and ifm_shape.depth <= 4 @@ -1004,6 +1006,7 @@ def fixup_strided_conv(op, arch, nng): stride_x = 1 op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)}) + op.ifm.force_linear_format = True return op @@ -2125,7 +2128,6 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights): convert_prelu, convert_mul_max_to_abs_or_lrelu, convert_lrelu, - fixup_strided_conv, convert_hardswish_to_lut, rewrite_fully_connected_input, convert_batched_fc_shape, @@ -2139,6 +2141,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights): convert_tanh_sigmoid_to_lut, replace_pad_by_hw_pad, fixup_dilation_gt2, + fixup_strided_conv, ] for idx, sg in enumerate(nng.subgraphs): |