From fcb1a00cfd4216782f4fc4429ce66c592a0b8030 Mon Sep 17 00:00:00 2001 From: Patrik Gustavsson Date: Wed, 3 Feb 2021 09:13:57 +0100 Subject: MLBEDSW-3951 Consider reshaping in pass packing Consider reshaping in pass packing, when deciding if operators can be packed. For the cases where there is a reshape between ops they cannot be fused. Signed-off-by: Patrik Gustavsson Change-Id: I8f2833b3fff156e9633ce0189d1d0df9109a6622 --- ethosu/vela/pass_packing.py | 39 ++++++++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py index a95e3839..c973b9c3 100644 --- a/ethosu/vela/pass_packing.py +++ b/ethosu/vela/pass_packing.py @@ -150,7 +150,7 @@ test_sequence = [ # ops_set npu_pre_ops, # incompatible_pack_flags - PassFlags.Cpu | PassFlags.MemoryOnly | PassFlags.ElementWise, + PassFlags.Cpu | PassFlags.MemoryOnly, # flags_to_set PassFlags.Npu | PassFlags.Mac | PassFlags.Pre | PassFlags.ElementWise, # flags_to_clear @@ -296,21 +296,9 @@ def pack_into_passes(nng, arch, verbose_packing=False): for inp in reversed(curr_op.inputs): if inp is None: continue - can_pack = True - if len(inp.ops) == 1: - next_op = inp.ops[0] - for outp in next_op.outputs: - consumers = outp.consumers() - if len(consumers) > 1 or (len(consumers) == 1 and consumers[0] != curr_op): - can_pack = False - break - else: - can_pack = False - - if can_pack: - to_process.append((next_op, inp)) + if can_pack(inp, curr_op): + to_process.append((inp.ops[0], inp)) else: - assert inp is not None input_set.add(inp) break @@ -469,6 +457,27 @@ def pack_into_passes(nng, arch, verbose_packing=False): return None + def can_pack(inp, curr_op): + if len(inp.ops) == 1: + next_op = inp.ops[0] + for outp in next_op.outputs: + consumers = outp.consumers() + if len(consumers) > 1 or (len(consumers) == 1 and consumers[0] != curr_op): + return False + + # There cannot be any reshaping between next_op ofm and corresponding curr_op ifm 
+ if len(curr_op.ifm_shapes) != 0 and len(next_op.ofm_shapes) != 0: + if inp == curr_op.ifm and next_op.ofm_shapes[0] != curr_op.ifm_shapes[0]: + return False + elif ( + curr_op.ifm2 is not None and inp == curr_op.ifm2 and next_op.ofm_shapes[0] != curr_op.ifm_shapes[1] + ): + return False + else: + return False + + return True + def add_input_list(inp_to_add, inp_set, inp_refcnts, lut_list, ordered_inp_list): if inp_to_add in inp_set: if inp_refcnts[inp_to_add] == 0: -- cgit v1.2.1