diff options
author | Louis Verhaard <louis.verhaard@arm.com> | 2022-03-17 14:06:00 +0100 |
---|---|---|
committer | Louis Verhaard <louis.verhaard@arm.com> | 2022-03-17 14:08:21 +0100 |
commit | 43d275875bb78163604ec116e06153e53d2fcbc1 (patch) | |
tree | 0c1abc4856681f3be60817e2ca72952d72f6c6b8 /ethosu/vela | |
parent | 5c8f1e598748c0429a88aa35a1f12c731892f9b1 (diff) | |
download | ethos-u-vela-43d275875bb78163604ec116e06153e53d2fcbc1.tar.gz |
MLBEDSW-5332: Bug fix optimise_strided_conv
Added check that horizontal padding is unaffected when applying
graph optimization "optimise_strided_conv".
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
Change-Id: I7032a44163e300cdf62cf615b4b10a1417e38eaa
Diffstat (limited to 'ethosu/vela')
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 20 |
1 file changed, 13 insertions, 7 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 3815eedd..576ead03 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -532,19 +532,25 @@ def reorder_depthwise_weights(op, arch, nng):
 
 
 def optimise_strided_conv(op, arch, nng):
+    if op.type != Op.Conv2DBias or op.op_index != 0:
+        return op
     stride_x, stride_y = op.get_kernel_stride()
-    ifm_tensor, _, weight_tensor, _ = op.get_ifm_ifm2_weights_ofm()
+    weight_tensor = op.weights
+    ifm_shape = op.ifm_shapes[0]
 
     if (
-        op.type == Op.Conv2DBias
-        and op.op_index == 0
-        and stride_x == 2
-        and op.ifm_shapes[0].depth <= 4
-        and op.ifm_shapes[0].width % 2 == 0
+        stride_x == 2
+        and ifm_shape.depth <= 4
+        and ifm_shape.width % 2 == 0
         and weight_tensor is not None
         and weight_tensor.shape[1] >= 2
     ):
-        ifm_shape = op.ifm_shapes[0]
+        k_w, _ = op.get_kernel_size()
+        curr_padding_x = needed_total_padding(ifm_shape.width, 2, k_w)
+        optimised_padding_x = needed_total_padding(ifm_shape.width // 2, 1, (k_w + 1) // 2)
+        if curr_padding_x != optimised_padding_x:
+            # Horizontal padding would become different after optimisation; this would not work
+            return op
         # IFM
         op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])
 