aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 218f499..9b98b8f 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -2328,6 +2328,20 @@ def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
return op
+def fixup_pool_strides(op: Operation, arch, nng):
+ """Fixup Pool strides when the kernel size is equal to IFM shape. Stride is then irrelevant."""
+ if op.type.is_pool_op():
+ ifm, _ = op.get_ifm_ofm()
+ kernel_w, kernel_h = op.get_kernel_size()
+ if kernel_w == ifm.shape[2] and kernel_h == ifm.shape[1]:
+ stride_n, _, _, stride_c = op.attrs["strides"]
+ op.attrs["strides"] = (stride_n, 1, 1, stride_c)
+ op.attrs["stride_w"] = 1
+ op.attrs["stride_h"] = 1
+
+ return op
+
+
def fixup_dilation_gt2(op: Operation, arch, nng) -> Operation:
"""Fixup Conv2DBias and DepthwiseConv2DBias to allow dilation greater than 2."""
assert op.run_on_npu
@@ -2573,6 +2587,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
optimise_quantize,
convert_shape_op_to_constant_tensor,
fixup_or_check_asymmetric_weights(force_symmetric_int_weights),
+ fixup_pool_strides,
]
for idx, sg in enumerate(nng.subgraphs):