From b4e804bb53aba48985abf3bf8466bc02310f60fc Mon Sep 17 00:00:00 2001 From: Johan Gunnarsson Date: Thu, 7 Sep 2023 12:43:49 +0200 Subject: MLBEDSW-8010: Refine fixup_pool_strides to also check stride Only set stride to (1, 1) if kernel, stride and IFM shape all are equal. And also set padding to VALID to handle ops with SAME padding. Signed-off-by: Johan Gunnarsson Change-Id: Id3cc34686f09667ea21541fac432351555344e3d --- ethosu/vela/tflite_graph_optimiser.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index f686fea..2fb75e6 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -2340,15 +2340,19 @@ def convert_shape_op_to_constant_tensor(op: Operation, arch, nng): def fixup_pool_strides(op: Operation, arch, nng): - """Fixup Pool strides when the kernel size is equal to IFM shape. Stride is then irrelevant.""" + """Fixup Pool strides when the kernel size, IFM shape and stride are equal. Then stride can be changed + to (1, 1) and padding can be changed to VALID, so the strides are within the limits for the NPU.""" if op.type in (Op.AvgPool, Op.MaxPool, Op.QuantizedAvgPool, Op.QuantizedMaxPool): ifm, _ = op.get_ifm_ofm() kernel_w, kernel_h = op.get_kernel_size() - if kernel_w == ifm.shape[2] and kernel_h == ifm.shape[1]: - stride_n, _, _, stride_c = op.attrs["strides"] - op.attrs["strides"] = (stride_n, 1, 1, stride_c) + stride_w, stride_h = op.get_kernel_stride() + if kernel_w == stride_w == ifm.shape[2] and kernel_h == stride_h == ifm.shape[1]: + if "strides" in op.attrs: + stride_n, _, _, stride_c = op.attrs["strides"] + op.attrs["strides"] = (stride_n, 1, 1, stride_c) op.attrs["stride_w"] = 1 op.attrs["stride_h"] = 1 + op.attrs["padding"] = Padding.VALID return op -- cgit v1.2.1