aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py20
1 files changed, 13 insertions, 7 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 3815eedd..576ead03 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -532,19 +532,25 @@ def reorder_depthwise_weights(op, arch, nng):
def optimise_strided_conv(op, arch, nng):
+ if op.type != Op.Conv2DBias or op.op_index != 0:
+ return op
stride_x, stride_y = op.get_kernel_stride()
- ifm_tensor, _, weight_tensor, _ = op.get_ifm_ifm2_weights_ofm()
+ weight_tensor = op.weights
+ ifm_shape = op.ifm_shapes[0]
if (
- op.type == Op.Conv2DBias
- and op.op_index == 0
- and stride_x == 2
- and op.ifm_shapes[0].depth <= 4
- and op.ifm_shapes[0].width % 2 == 0
+ stride_x == 2
+ and ifm_shape.depth <= 4
+ and ifm_shape.width % 2 == 0
and weight_tensor is not None
and weight_tensor.shape[1] >= 2
):
- ifm_shape = op.ifm_shapes[0]
+ k_w, _ = op.get_kernel_size()
+ curr_padding_x = needed_total_padding(ifm_shape.width, 2, k_w)
+ optimised_padding_x = needed_total_padding(ifm_shape.width // 2, 1, (k_w + 1) // 2)
+ if curr_padding_x != optimised_padding_x:
+ # Horizontal padding would become different after optimisation; this would not work
+ return op
# IFM
op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])