diff options
Diffstat (limited to 'ethosu/vela/pass_packing.py')
-rw-r--r-- | ethosu/vela/pass_packing.py | 39 |
1 files changed, 24 insertions, 15 deletions
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py index a95e3839..c973b9c3 100644 --- a/ethosu/vela/pass_packing.py +++ b/ethosu/vela/pass_packing.py @@ -150,7 +150,7 @@ test_sequence = [ # ops_set npu_pre_ops, # incompatible_pack_flags - PassFlags.Cpu | PassFlags.MemoryOnly | PassFlags.ElementWise, + PassFlags.Cpu | PassFlags.MemoryOnly, # flags_to_set PassFlags.Npu | PassFlags.Mac | PassFlags.Pre | PassFlags.ElementWise, # flags_to_clear @@ -296,21 +296,9 @@ def pack_into_passes(nng, arch, verbose_packing=False): for inp in reversed(curr_op.inputs): if inp is None: continue - can_pack = True - if len(inp.ops) == 1: - next_op = inp.ops[0] - for outp in next_op.outputs: - consumers = outp.consumers() - if len(consumers) > 1 or (len(consumers) == 1 and consumers[0] != curr_op): - can_pack = False - break - else: - can_pack = False - - if can_pack: - to_process.append((next_op, inp)) + if can_pack(inp, curr_op): + to_process.append((inp.ops[0], inp)) else: - assert inp is not None input_set.add(inp) break @@ -469,6 +457,27 @@ def pack_into_passes(nng, arch, verbose_packing=False): return None + def can_pack(inp, curr_op): + if len(inp.ops) == 1: + next_op = inp.ops[0] + for outp in next_op.outputs: + consumers = outp.consumers() + if len(consumers) > 1 or (len(consumers) == 1 and consumers[0] != curr_op): + return False + + # There cannot be any reshaping between next_op ofm and corresponding curr_op ifm + if len(curr_op.ifm_shapes) != 0 and len(next_op.ofm_shapes) != 0: + if inp == curr_op.ifm and next_op.ofm_shapes[0] != curr_op.ifm_shapes[0]: + return False + elif ( + curr_op.ifm2 is not None and inp == curr_op.ifm2 and next_op.ofm_shapes[0] != curr_op.ifm_shapes[1] + ): + return False + else: + return False + + return True + def add_input_list(inp_to_add, inp_set, inp_refcnts, lut_list, ordered_inp_list): if inp_to_add in inp_set: if inp_refcnts[inp_to_add] == 0: |