diff options
author | Johan Gunnarsson <johan.gunnarsson@arm.com> | 2023-09-07 12:43:49 +0200 |
---|---|---|
committer | Johan Alfven <johan.alfven@arm.com> | 2023-09-14 07:49:46 +0000 |
commit | b4e804bb53aba48985abf3bf8466bc02310f60fc (patch) | |
tree | da0c679fa2463a64920e1c1f0cbcac59e2ba4d3e | |
parent | 7ccc583c7cf36fc8bb8391594ee818263247a995 (diff) | |
download | ethos-u-vela-b4e804bb53aba48985abf3bf8466bc02310f60fc.tar.gz |
MLBEDSW-8010: Refine fixup_pool_strides to also check stride
Only set stride to (1, 1) if kernel, stride and IFM shape all are
equal. And also set padding to VALID to handle ops with SAME padding.
Signed-off-by: Johan Gunnarsson <johan.gunnarsson@arm.com>
Change-Id: Id3cc34686f09667ea21541fac432351555344e3d
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index f686feaf..2fb75e61 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -2340,15 +2340,19 @@ def convert_shape_op_to_constant_tensor(op: Operation, arch, nng): def fixup_pool_strides(op: Operation, arch, nng): - """Fixup Pool strides when the kernel size is equal to IFM shape. Stride is then irrelevant.""" + """Fixup Pool strides when the kernel size, IFM shape and stride are equal. Then stride can be changed + to (1, 1) and padding can be changed to VALID, so the strides are within the limits for the NPU.""" if op.type in (Op.AvgPool, Op.MaxPool, Op.QuantizedAvgPool, Op.QuantizedMaxPool): ifm, _ = op.get_ifm_ofm() kernel_w, kernel_h = op.get_kernel_size() - if kernel_w == ifm.shape[2] and kernel_h == ifm.shape[1]: - stride_n, _, _, stride_c = op.attrs["strides"] - op.attrs["strides"] = (stride_n, 1, 1, stride_c) + stride_w, stride_h = op.get_kernel_stride() + if kernel_w == stride_w == ifm.shape[2] and kernel_h == stride_h == ifm.shape[1]: + if "strides" in op.attrs: + stride_n, _, _, stride_c = op.attrs["strides"] + op.attrs["strides"] = (stride_n, 1, 1, stride_c) op.attrs["stride_w"] = 1 op.attrs["stride_h"] = 1 + op.attrs["padding"] = Padding.VALID return op |