diff options
-rw-r--r-- | ethosu/vela/pass_packing.py | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index 66f7ffb3..f157e67b 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -525,17 +525,21 @@ def pack_into_passes(nng, arch, verbose_packing=False):
     pass_list_top = sorted(pass_list_top, key=lambda ps: -1 if ps.ops[0].op_index is None else ps.ops[0].op_index)
 
     # A concat is implemented by several AvgPool ops writing to the same ofm but with slice offset
-    # Group all AvgPool ops for a concat so that they run in one sequence (within the same cmd stream)
+    # If there is a cpu op in between, group all AvgPool ops for a concat so that they run
+    # within the same cmd stream
     last_idx = len(pass_list) - 1
     for npu_ps in reversed(pass_list):
         if npu_ps.placement == PassPlacement.Cpu or not npu_ps.ops[0].original_type.is_concat_op():
             continue
         # Concat pass found, search forward for the next avgpool op writing to the same ofm
         idx = pass_list.index(npu_ps)
+        concat_is_split_between_npu_ops = False
         for next_ps in pass_list[idx + 1 :]:
+            if next_ps.placement == PassPlacement.Cpu:
+                concat_is_split_between_npu_ops = True
             next_is_concat = next_ps.ops[0].original_type.is_concat_op()
-            if next_is_concat and next_ps.ops[0].ofm == npu_ps.ops[0].ofm:
-                # Avgpool writing to the same OFM, group them
+            if next_is_concat and next_ps.ops[0].ofm == npu_ps.ops[0].ofm and concat_is_split_between_npu_ops:
+                # Avgpool writing to the same OFM and there is a cpu op between them, group them
                 pass_list.remove(npu_ps)
                 insert_index = pass_list.index(next_ps)
                 pass_list.insert(insert_index, npu_ps)