aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py12
1 files changed, 8 insertions, 4 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index f686fea..2fb75e6 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -2340,15 +2340,19 @@ def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
def fixup_pool_strides(op: Operation, arch, nng):
- """Fixup Pool strides when the kernel size is equal to IFM shape. Stride is then irrelevant."""
+ """Fixup Pool strides when the kernel size, IFM shape and stride are equal. Then stride can be changed
+ to (1, 1) and padding can be changed to VALID, so the strides are within the limits for the NPU."""
if op.type in (Op.AvgPool, Op.MaxPool, Op.QuantizedAvgPool, Op.QuantizedMaxPool):
ifm, _ = op.get_ifm_ofm()
kernel_w, kernel_h = op.get_kernel_size()
- if kernel_w == ifm.shape[2] and kernel_h == ifm.shape[1]:
- stride_n, _, _, stride_c = op.attrs["strides"]
- op.attrs["strides"] = (stride_n, 1, 1, stride_c)
+ stride_w, stride_h = op.get_kernel_stride()
+ if kernel_w == stride_w == ifm.shape[2] and kernel_h == stride_h == ifm.shape[1]:
+ if "strides" in op.attrs:
+ stride_n, _, _, stride_c = op.attrs["strides"]
+ op.attrs["strides"] = (stride_n, 1, 1, stride_c)
op.attrs["stride_w"] = 1
op.attrs["stride_h"] = 1
+ op.attrs["padding"] = Padding.VALID
return op