diff options
author | Raul Farkas <raul.farkas@arm.com> | 2023-05-09 10:39:52 +0100 |
---|---|---|
committer | Rickard Bolin <rickard.bolin@arm.com> | 2023-05-10 10:38:23 +0000 |
commit | 69782af3ff2cda96dff09ad66799b3ac8f16c19d (patch) | |
tree | b2a06c927be4c20c60a624bf8ae3626f8137f387 /ethosu | |
parent | ac111101cccbdeddb4bef91a0a7b142e33c365c1 (diff) | |
download | ethos-u-vela-69782af3ff2cda96dff09ad66799b3ac8f16c19d.tar.gz |
Revert "MLBEDSW-6343: Remove op_index constraint"
This reverts commit 72c6a2414205e033279f80b622cdf479c05a4f5b.
Reason for revert: Fix performance regression caused by breaking cascades in certain models
Change-Id: I5aba6e3c59ab27c5129f4a3f0c320ed18df78943
Signed-off-by: Raul Farkas <raul.farkas@arm.com>
Diffstat (limited to 'ethosu')
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 1b70165e..07f65a44 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -964,6 +964,12 @@ def fixup_strided_conv(op: Operation, arch, nng): stride_x, stride_y = op.get_kernel_stride() weight_tensor = op.weights ifm_shape = op.ifm_shapes[0] + + # Do not optimize if op is not the first in the network and stride is + # supported by the hardware + if op.op_index != 0 and stride_x < 4: + return op + if ( (stride_x == 2 or stride_x == 4) and ifm_shape.depth <= 4 @@ -1013,7 +1019,6 @@ def fixup_strided_conv(op: Operation, arch, nng): stride_x = 1 op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)}) - op.ifm.force_linear_format = True return op @@ -2238,6 +2243,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights): convert_prelu, convert_mul_max_to_abs_or_lrelu, convert_lrelu, + fixup_strided_conv, convert_hardswish_to_lut, rewrite_fully_connected_input, convert_batched_fc_shape, @@ -2251,7 +2257,6 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights): convert_tanh_sigmoid_to_lut, replace_pad_by_hw_pad, fixup_dilation_gt2, - fixup_strided_conv, ] for idx, sg in enumerate(nng.subgraphs): |