diff options
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 218f499..9b98b8f 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -2328,6 +2328,20 @@ def convert_shape_op_to_constant_tensor(op: Operation, arch, nng): return op +def fixup_pool_strides(op: Operation, arch, nng): + """Fixup Pool strides when the kernel size is equal to IFM shape. Stride is then irrelevant.""" + if op.type.is_pool_op(): + ifm, _ = op.get_ifm_ofm() + kernel_w, kernel_h = op.get_kernel_size() + if kernel_w == ifm.shape[2] and kernel_h == ifm.shape[1]: + stride_n, _, _, stride_c = op.attrs["strides"] + op.attrs["strides"] = (stride_n, 1, 1, stride_c) + op.attrs["stride_w"] = 1 + op.attrs["stride_h"] = 1 + + return op + + def fixup_dilation_gt2(op: Operation, arch, nng) -> Operation: """Fixup Conv2DBias and DepthwiseConv2DBias to allow dilation greater than 2.""" assert op.run_on_npu @@ -2573,6 +2587,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights): optimise_quantize, convert_shape_op_to_constant_tensor, fixup_or_check_asymmetric_weights(force_symmetric_int_weights), + fixup_pool_strides, ] for idx, sg in enumerate(nng.subgraphs): |