diff options
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index f686fea..2fb75e6 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -2340,15 +2340,19 @@ def convert_shape_op_to_constant_tensor(op: Operation, arch, nng): def fixup_pool_strides(op: Operation, arch, nng): - """Fixup Pool strides when the kernel size is equal to IFM shape. Stride is then irrelevant.""" + """Fixup Pool strides when the kernel size, IFM shape and stride are equal. Then stride can be changed + to (1, 1) and padding can be changed to VALID, so the strides are within the limits for the NPU.""" if op.type in (Op.AvgPool, Op.MaxPool, Op.QuantizedAvgPool, Op.QuantizedMaxPool): ifm, _ = op.get_ifm_ofm() kernel_w, kernel_h = op.get_kernel_size() - if kernel_w == ifm.shape[2] and kernel_h == ifm.shape[1]: - stride_n, _, _, stride_c = op.attrs["strides"] - op.attrs["strides"] = (stride_n, 1, 1, stride_c) + stride_w, stride_h = op.get_kernel_stride() + if kernel_w == stride_w == ifm.shape[2] and kernel_h == stride_h == ifm.shape[1]: + if "strides" in op.attrs: + stride_n, _, _, stride_c = op.attrs["strides"] + op.attrs["strides"] = (stride_n, 1, 1, stride_c) op.attrs["stride_w"] = 1 op.attrs["stride_h"] = 1 + op.attrs["padding"] = Padding.VALID return op |