aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohan Gunnarsson <johan.gunnarsson@arm.com>2023-09-07 12:43:49 +0200
committerJohan Alfven <johan.alfven@arm.com>2023-09-14 07:49:46 +0000
commitb4e804bb53aba48985abf3bf8466bc02310f60fc (patch)
treeda0c679fa2463a64920e1c1f0cbcac59e2ba4d3e
parent7ccc583c7cf36fc8bb8391594ee818263247a995 (diff)
downloadethos-u-vela-b4e804bb53aba48985abf3bf8466bc02310f60fc.tar.gz
MLBEDSW-8010: Refine fixup_pool_strides to also check stride
Only set stride to (1, 1) if kernel, stride and IFM shape all are equal. And also set padding to VALID to handle ops with SAME padding. Signed-off-by: Johan Gunnarsson <johan.gunnarsson@arm.com> Change-Id: Id3cc34686f09667ea21541fac432351555344e3d
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py12
1 files changed, 8 insertions, 4 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index f686fea..2fb75e6 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -2340,15 +2340,19 @@ def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
def fixup_pool_strides(op: Operation, arch, nng):
- """Fixup Pool strides when the kernel size is equal to IFM shape. Stride is then irrelevant."""
+ """Fixup Pool strides when the kernel size, IFM shape and stride are equal. Then stride can be changed
+ to (1, 1) and padding can be changed to VALID, so the strides are within the limits for the NPU."""
if op.type in (Op.AvgPool, Op.MaxPool, Op.QuantizedAvgPool, Op.QuantizedMaxPool):
ifm, _ = op.get_ifm_ofm()
kernel_w, kernel_h = op.get_kernel_size()
- if kernel_w == ifm.shape[2] and kernel_h == ifm.shape[1]:
- stride_n, _, _, stride_c = op.attrs["strides"]
- op.attrs["strides"] = (stride_n, 1, 1, stride_c)
+ stride_w, stride_h = op.get_kernel_stride()
+ if kernel_w == stride_w == ifm.shape[2] and kernel_h == stride_h == ifm.shape[1]:
+ if "strides" in op.attrs:
+ stride_n, _, _, stride_c = op.attrs["strides"]
+ op.attrs["strides"] = (stride_n, 1, 1, stride_c)
op.attrs["stride_w"] = 1
op.attrs["stride_h"] = 1
+ op.attrs["padding"] = Padding.VALID
return op