aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLouis Verhaard <louis.verhaard@arm.com>2022-03-17 14:06:00 +0100
committerLouis Verhaard <louis.verhaard@arm.com>2022-03-17 14:08:21 +0100
commit43d275875bb78163604ec116e06153e53d2fcbc1 (patch)
tree0c1abc4856681f3be60817e2ca72952d72f6c6b8
parent5c8f1e598748c0429a88aa35a1f12c731892f9b1 (diff)
downloadethos-u-vela-43d275875bb78163604ec116e06153e53d2fcbc1.tar.gz
MLBEDSW-5332: Bug fix optimise_strided_conv
Added check that horizontal padding is unaffected when applying graph optimization "optimise_strided_conv". Signed-off-by: Louis Verhaard <louis.verhaard@arm.com> Change-Id: I7032a44163e300cdf62cf615b4b10a1417e38eaa
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py20
1 files changed, 13 insertions, 7 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 3815eed..576ead0 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -532,19 +532,25 @@ def reorder_depthwise_weights(op, arch, nng):
def optimise_strided_conv(op, arch, nng):
+ if op.type != Op.Conv2DBias or op.op_index != 0:
+ return op
stride_x, stride_y = op.get_kernel_stride()
- ifm_tensor, _, weight_tensor, _ = op.get_ifm_ifm2_weights_ofm()
+ weight_tensor = op.weights
+ ifm_shape = op.ifm_shapes[0]
if (
- op.type == Op.Conv2DBias
- and op.op_index == 0
- and stride_x == 2
- and op.ifm_shapes[0].depth <= 4
- and op.ifm_shapes[0].width % 2 == 0
+ stride_x == 2
+ and ifm_shape.depth <= 4
+ and ifm_shape.width % 2 == 0
and weight_tensor is not None
and weight_tensor.shape[1] >= 2
):
- ifm_shape = op.ifm_shapes[0]
+ k_w, _ = op.get_kernel_size()
+ curr_padding_x = needed_total_padding(ifm_shape.width, 2, k_w)
+ optimised_padding_x = needed_total_padding(ifm_shape.width // 2, 1, (k_w + 1) // 2)
+ if curr_padding_x != optimised_padding_x:
+ # Horizontal padding would become different after optimisation; this would not work
+ return op
# IFM
op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])