aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela
diff options
context:
space:
mode:
author Johan Alfven <johan.alfven@arm.com> 2024-05-13 13:44:42 +0200
committer Johan Alfven <johan.alfven@arm.com> 2024-05-15 14:51:05 +0200
commit 891468561ecfc61d27adcdc92b41ec216eaa1b08 (patch)
tree 5302e1e70549122d0c39c581a34db876010eb23f /ethosu/vela
parent f49370003956d4f6f7d177114a68edb07b364fe9 (diff)
download ethos-u-vela-891468561ecfc61d27adcdc92b41ec216eaa1b08.tar.gz
MLBEDSW-9067: MLCE: Group Avgpool ops for concat
- Concat is implemented by several avgpool ops, all of them writing to the same ofm but with a slice offset. If a compiled network contains CPU fallbacks, the avgpool ops might end up running in different custom ops. This works fine as long as the runtime provides the same scratch area. If not, the output from the concat might be corrupt. - This fix adds an extra step to the pass packing so that all avgpool ops for a concat are grouped together and run within the same custom op in order to prevent possible corruption. Change-Id: I343e08d7b4046f969b3d9ec3479db6490cbe4170 Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Diffstat (limited to 'ethosu/vela')
-rw-r--r-- ethosu/vela/graph_optimiser_util.py 2
-rw-r--r-- ethosu/vela/pass_packing.py 19
2 files changed, 20 insertions, 1 deletions
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index 46762e4d..8a94f361 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -348,6 +348,8 @@ def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_o
"""Creates an average pool for the given concat op/input feature map"""
ofm = concat_op.ofm
avgpool_op = create_avgpool_nop(name)
+ # Enforce original type since this is used in pass packing to group concat ops
+ avgpool_op._original_type = concat_op.type
avgpool_op.inputs = [ifm]
avgpool_op.outputs = [ofm]
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index 0de0341d..66f7ffb3 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -524,6 +524,23 @@ def pack_into_passes(nng, arch, verbose_packing=False):
# Sort ops by op_index (same call order as in the original graph)
pass_list_top = sorted(pass_list_top, key=lambda ps: -1 if ps.ops[0].op_index is None else ps.ops[0].op_index)
+ # A concat is implemented by several AvgPool ops writing to the same ofm but with slice offset
+ # Group all AvgPool ops for a concat so that they run in one sequence (within the same cmd stream)
+ last_idx = len(pass_list) - 1
+ for npu_ps in reversed(pass_list):
+ if npu_ps.placement == PassPlacement.Cpu or not npu_ps.ops[0].original_type.is_concat_op():
+ continue
+ # Concat pass found, search forward for the next avgpool op writing to the same ofm
+ idx = pass_list.index(npu_ps)
+ for next_ps in pass_list[idx + 1 :]:
+ next_is_concat = next_ps.ops[0].original_type.is_concat_op()
+ if next_is_concat and next_ps.ops[0].ofm == npu_ps.ops[0].ofm:
+ # Avgpool writing to the same OFM, group them
+ pass_list.remove(npu_ps)
+ insert_index = pass_list.index(next_ps)
+ pass_list.insert(insert_index, npu_ps)
+ break
+
# Sort the rest of the list based on critera 2.
# Search from bottom of list and when a CPU pass is found
# search forward in the list and see if it is possible to join another CPU pass.