aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPatrik Gustavsson <patrik.gustavsson@arm.com>2021-02-03 09:13:57 +0100
committerPatrik Gustavsson <patrik.gustavsson@arm.com>2021-02-04 08:50:43 +0100
commitfcb1a00cfd4216782f4fc4429ce66c592a0b8030 (patch)
treea6f3866608000a85ebdd1b524fb4a84f41733d56
parentc77615121c28409081d2ac6526694edebb8d7255 (diff)
downloadethos-u-vela-fcb1a00cfd4216782f4fc4429ce66c592a0b8030.tar.gz
MLBEDSW-3951 Consider reshaping in pass packing
Consider reshaping in pass packing when deciding if operators can be packed. In cases where there is a reshape between ops, they cannot be fused. Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com> Change-Id: I8f2833b3fff156e9633ce0189d1d0df9109a6622
-rw-r--r--ethosu/vela/pass_packing.py39
1 files changed, 24 insertions, 15 deletions
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index a95e383..c973b9c 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -150,7 +150,7 @@ test_sequence = [
# ops_set
npu_pre_ops,
# incompatible_pack_flags
- PassFlags.Cpu | PassFlags.MemoryOnly | PassFlags.ElementWise,
+ PassFlags.Cpu | PassFlags.MemoryOnly,
# flags_to_set
PassFlags.Npu | PassFlags.Mac | PassFlags.Pre | PassFlags.ElementWise,
# flags_to_clear
@@ -296,21 +296,9 @@ def pack_into_passes(nng, arch, verbose_packing=False):
for inp in reversed(curr_op.inputs):
if inp is None:
continue
- can_pack = True
- if len(inp.ops) == 1:
- next_op = inp.ops[0]
- for outp in next_op.outputs:
- consumers = outp.consumers()
- if len(consumers) > 1 or (len(consumers) == 1 and consumers[0] != curr_op):
- can_pack = False
- break
- else:
- can_pack = False
-
- if can_pack:
- to_process.append((next_op, inp))
+ if can_pack(inp, curr_op):
+ to_process.append((inp.ops[0], inp))
else:
- assert inp is not None
input_set.add(inp)
break
@@ -469,6 +457,27 @@ def pack_into_passes(nng, arch, verbose_packing=False):
return None
+ def can_pack(inp, curr_op):
+ if len(inp.ops) == 1:
+ next_op = inp.ops[0]
+ for outp in next_op.outputs:
+ consumers = outp.consumers()
+ if len(consumers) > 1 or (len(consumers) == 1 and consumers[0] != curr_op):
+ return False
+
+ # There cannot be any reshaping between next_op ofm and corresponding curr_op ifm
+ if len(curr_op.ifm_shapes) != 0 and len(next_op.ofm_shapes) != 0:
+ if inp == curr_op.ifm and next_op.ofm_shapes[0] != curr_op.ifm_shapes[0]:
+ return False
+ elif (
+ curr_op.ifm2 is not None and inp == curr_op.ifm2 and next_op.ofm_shapes[0] != curr_op.ifm_shapes[1]
+ ):
+ return False
+ else:
+ return False
+
+ return True
+
def add_input_list(inp_to_add, inp_set, inp_refcnts, lut_list, ordered_inp_list):
if inp_to_add in inp_set:
if inp_refcnts[inp_to_add] == 0: