aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohan Gunnarsson <johan.gunnarsson@arm.com>2023-08-29 15:33:10 +0200
committerFredrik Svedberg <fredrik.svedberg@arm.com>2023-09-05 10:51:42 +0000
commit24570f098bd765842a4c9bc56e28f22295073467 (patch)
tree5eac0ddd12308b7aff7e3bd77111b7e48838983c
parent985563791a811e1ea3b5137f97e5a5fc4dafd4b1 (diff)
downloadethos-u-vela-24570f098bd765842a4c9bc56e28f22295073467.tar.gz
MLBEDSW-7968: Add fixup for strides when kernel size equals IFM shape
Some networks contain Pool ops where the filter (W, H), the IFM (W, H) and the stride (W, H) are all equal. The stride is technically too large for the NPU, but these ops can still run on the NPU because the filter covers the whole IFM, so the window never slides. To support these ops, the stride is fixed up so that later checks don't place the op on the CPU. Change-Id: I8f0a46b26fb94ee76c33748589536cc5ba07ea59 Signed-off-by: Johan Gunnarsson <johan.gunnarsson@arm.com>
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 218f499..9b98b8f 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -2328,6 +2328,20 @@ def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
return op
+def fixup_pool_strides(op: Operation, arch, nng):
+    """Set Pool strides to 1x1 when kernel size, stride and IFM (W, H) are all equal.
+
+    In that case the kernel covers the whole IFM and the window never slides, so the
+    stride value is irrelevant and is replaced by 1 in both spatial dimensions.
+    Ops that are not pool ops, that carry no "strides" attribute, or whose stride
+    differs from the kernel/IFM size are left untouched.
+
+    The arch and nng parameters are not used by this function.
+    Returns the same op object, updated in place when the fixup applies.
+    """
+    if not op.type.is_pool_op():
+        return op
+
+    strides = op.attrs.get("strides")
+    if strides is None:
+        # No stride attribute to compare against or rewrite; leave the op as it is.
+        return op
+
+    ifm, _ = op.get_ifm_ofm()
+    kernel_w, kernel_h = op.get_kernel_size()
+    # Assumes "strides" is laid out as (N, H, W, C), matching the ifm.shape[1] (H) and
+    # ifm.shape[2] (W) indexing used below - TODO confirm against the attribute producer.
+    stride_n, stride_h, stride_w, stride_c = strides
+    # Require stride == kernel == IFM. Checking only kernel == IFM would also rewrite ops
+    # whose stride is smaller than the kernel; with padding such a window can still slide,
+    # so forcing the stride to 1 would change the result.
+    if kernel_w == stride_w == ifm.shape[2] and kernel_h == stride_h == ifm.shape[1]:
+        op.attrs["strides"] = (stride_n, 1, 1, stride_c)
+        op.attrs["stride_w"] = 1
+        op.attrs["stride_h"] = 1
+        # NOTE(review): if later passes derive padding from the stride (SAME-style padding),
+        # a stride of 1 may yield non-zero padding - confirm whether padding should also be
+        # forced to VALID here.
+
+    return op
+
+
def fixup_dilation_gt2(op: Operation, arch, nng) -> Operation:
"""Fixup Conv2DBias and DepthwiseConv2DBias to allow dilation greater than 2."""
assert op.run_on_npu
@@ -2573,6 +2587,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
optimise_quantize,
convert_shape_op_to_constant_tensor,
fixup_or_check_asymmetric_weights(force_symmetric_int_weights),
+ fixup_pool_strides,
]
for idx, sg in enumerate(nng.subgraphs):