diff options
author | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2021-02-03 09:13:57 +0100 |
---|---|---|
committer | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2021-02-04 08:50:43 +0100 |
commit | fcb1a00cfd4216782f4fc4429ce66c592a0b8030 (patch) | |
tree | a6f3866608000a85ebdd1b524fb4a84f41733d56 | |
parent | c77615121c28409081d2ac6526694edebb8d7255 (diff) | |
download | ethos-u-vela-fcb1a00cfd4216782f4fc4429ce66c592a0b8030.tar.gz |
MLBEDSW-3951 Consider reshaping in pass packing
Consider reshaping in pass packing when deciding whether
operators can be packed.
When there is a reshape between two ops,
they cannot be fused.
Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I8f2833b3fff156e9633ce0189d1d0df9109a6622
-rw-r--r-- | ethosu/vela/pass_packing.py | 39 |
1 files changed, 24 insertions, 15 deletions
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py index a95e3839..c973b9c3 100644 --- a/ethosu/vela/pass_packing.py +++ b/ethosu/vela/pass_packing.py @@ -150,7 +150,7 @@ test_sequence = [ # ops_set npu_pre_ops, # incompatible_pack_flags - PassFlags.Cpu | PassFlags.MemoryOnly | PassFlags.ElementWise, + PassFlags.Cpu | PassFlags.MemoryOnly, # flags_to_set PassFlags.Npu | PassFlags.Mac | PassFlags.Pre | PassFlags.ElementWise, # flags_to_clear @@ -296,21 +296,9 @@ def pack_into_passes(nng, arch, verbose_packing=False): for inp in reversed(curr_op.inputs): if inp is None: continue - can_pack = True - if len(inp.ops) == 1: - next_op = inp.ops[0] - for outp in next_op.outputs: - consumers = outp.consumers() - if len(consumers) > 1 or (len(consumers) == 1 and consumers[0] != curr_op): - can_pack = False - break - else: - can_pack = False - - if can_pack: - to_process.append((next_op, inp)) + if can_pack(inp, curr_op): + to_process.append((inp.ops[0], inp)) else: - assert inp is not None input_set.add(inp) break @@ -469,6 +457,27 @@ def pack_into_passes(nng, arch, verbose_packing=False): return None + def can_pack(inp, curr_op): + if len(inp.ops) == 1: + next_op = inp.ops[0] + for outp in next_op.outputs: + consumers = outp.consumers() + if len(consumers) > 1 or (len(consumers) == 1 and consumers[0] != curr_op): + return False + + # There cannot be any reshaping between next_op ofm and corresponding curr_op ifm + if len(curr_op.ifm_shapes) != 0 and len(next_op.ofm_shapes) != 0: + if inp == curr_op.ifm and next_op.ofm_shapes[0] != curr_op.ifm_shapes[0]: + return False + elif ( + curr_op.ifm2 is not None and inp == curr_op.ifm2 and next_op.ofm_shapes[0] != curr_op.ifm_shapes[1] + ): + return False + else: + return False + + return True + def add_input_list(inp_to_add, inp_set, inp_refcnts, lut_list, ordered_inp_list): if inp_to_add in inp_set: if inp_refcnts[inp_to_add] == 0: |