aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ethosu/vela/pass_packing.py10
1 files changed, 7 insertions, 3 deletions
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index 66f7ffb3..f157e67b 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -525,17 +525,21 @@ def pack_into_passes(nng, arch, verbose_packing=False):
pass_list_top = sorted(pass_list_top, key=lambda ps: -1 if ps.ops[0].op_index is None else ps.ops[0].op_index)
# A concat is implemented by several AvgPool ops writing to the same ofm but with slice offset
- # Group all AvgPool ops for a concat so that they run in one sequence (within the same cmd stream)
+ # If there is a cpu op in between, group all AvgPool ops for a concat so that they run
+ # within the same cmd stream
last_idx = len(pass_list) - 1
for npu_ps in reversed(pass_list):
if npu_ps.placement == PassPlacement.Cpu or not npu_ps.ops[0].original_type.is_concat_op():
continue
# Concat pass found, search forward for the next avgpool op writing to the same ofm
idx = pass_list.index(npu_ps)
+ concat_is_split_between_npu_ops = False
for next_ps in pass_list[idx + 1 :]:
+ if next_ps.placement == PassPlacement.Cpu:
+ concat_is_split_between_npu_ops = True
next_is_concat = next_ps.ops[0].original_type.is_concat_op()
- if next_is_concat and next_ps.ops[0].ofm == npu_ps.ops[0].ofm:
- # Avgpool writing to the same OFM, group them
+ if next_is_concat and next_ps.ops[0].ofm == npu_ps.ops[0].ofm and concat_is_split_between_npu_ops:
+ # Avgpool writing to the same OFM and there is a cpu op between them, group them
pass_list.remove(npu_ps)
insert_index = pass_list.index(next_ps)
pass_list.insert(insert_index, npu_ps)