about summary refs log tree commit diff
diff options
context:
space:
mode:
author Raul Farkas <raul.farkas@arm.com> 2023-05-09 10:39:52 +0100
committer Rickard Bolin <rickard.bolin@arm.com> 2023-05-10 10:38:23 +0000
commit 69782af3ff2cda96dff09ad66799b3ac8f16c19d (patch)
tree b2a06c927be4c20c60a624bf8ae3626f8137f387
parent ac111101cccbdeddb4bef91a0a7b142e33c365c1 (diff)
download ethos-u-vela-69782af3ff2cda96dff09ad66799b3ac8f16c19d.tar.gz
Revert "MLBEDSW-6343: Remove op_index constraint"
This reverts commit 72c6a2414205e033279f80b622cdf479c05a4f5b.

Reason for revert: Fix performance regression caused by breaking cascades in certain models

Change-Id: I5aba6e3c59ab27c5129f4a3f0c320ed18df78943
Signed-off-by: Raul Farkas <raul.farkas@arm.com>
-rw-r--r-- ethosu/vela/tflite_graph_optimiser.py 9
1 files changed, 7 insertions, 2 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 1b70165e..07f65a44 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -964,6 +964,12 @@ def fixup_strided_conv(op: Operation, arch, nng):
stride_x, stride_y = op.get_kernel_stride()
weight_tensor = op.weights
ifm_shape = op.ifm_shapes[0]
+
+ # Do not optimize if op is not the first in the network and stride is
+ # supported by the hardware
+ if op.op_index != 0 and stride_x < 4:
+ return op
+
if (
(stride_x == 2 or stride_x == 4)
and ifm_shape.depth <= 4
@@ -1013,7 +1019,6 @@ def fixup_strided_conv(op: Operation, arch, nng):
stride_x = 1
op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})
- op.ifm.force_linear_format = True
return op
@@ -2238,6 +2243,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
convert_prelu,
convert_mul_max_to_abs_or_lrelu,
convert_lrelu,
+ fixup_strided_conv,
convert_hardswish_to_lut,
rewrite_fully_connected_input,
convert_batched_fc_shape,
@@ -2251,7 +2257,6 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
convert_tanh_sigmoid_to_lut,
replace_pad_by_hw_pad,
fixup_dilation_gt2,
- fixup_strided_conv,
]
for idx, sg in enumerate(nng.subgraphs):