diff options
author | Johan Gunnarsson <johan.gunnarsson@arm.com> | 2023-08-29 15:33:10 +0200 |
---|---|---|
committer | Fredrik Svedberg <fredrik.svedberg@arm.com> | 2023-09-05 10:51:42 +0000 |
commit | 24570f098bd765842a4c9bc56e28f22295073467 (patch) | |
tree | 5eac0ddd12308b7aff7e3bd77111b7e48838983c | |
parent | 985563791a811e1ea3b5137f97e5a5fc4dafd4b1 (diff) | |
download | ethos-u-vela-24570f098bd765842a4c9bc56e28f22295073467.tar.gz |
MLBEDSW-7968: Add fixup for strides when kernel size equals IFM shape
There are networks out there containing Pool ops whose filter (W, H)
equals both the IFM (W, H) and the stride (W, H). The stride is
technically too large for the NPU, but we can still run these ops on
the NPU because the filter covers the whole IFM, so the window never
slides. To support these ops, we rewrite the stride to 1 so that later
checks don't place the op on the CPU.
Change-Id: I8f0a46b26fb94ee76c33748589536cc5ba07ea59
Signed-off-by: Johan Gunnarsson <johan.gunnarsson@arm.com>
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 218f499a..9b98b8fa 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -2328,6 +2328,20 @@ def convert_shape_op_to_constant_tensor(op: Operation, arch, nng): return op +def fixup_pool_strides(op: Operation, arch, nng): + """Fixup Pool strides when the kernel size is equal to IFM shape. Stride is then irrelevant.""" + if op.type.is_pool_op(): + ifm, _ = op.get_ifm_ofm() + kernel_w, kernel_h = op.get_kernel_size() + if kernel_w == ifm.shape[2] and kernel_h == ifm.shape[1]: + stride_n, _, _, stride_c = op.attrs["strides"] + op.attrs["strides"] = (stride_n, 1, 1, stride_c) + op.attrs["stride_w"] = 1 + op.attrs["stride_h"] = 1 + + return op + + def fixup_dilation_gt2(op: Operation, arch, nng) -> Operation: """Fixup Conv2DBias and DepthwiseConv2DBias to allow dilation greater than 2.""" assert op.run_on_npu @@ -2573,6 +2587,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights): optimise_quantize, convert_shape_op_to_constant_tensor, fixup_or_check_asymmetric_weights(force_symmetric_int_weights), + fixup_pool_strides, ] for idx, sg in enumerate(nng.subgraphs): |