From 2abd3dd75bd3d20e1a3aeaf12362f9872b40fa0a Mon Sep 17 00:00:00 2001 From: Diqing Zhong Date: Tue, 25 Aug 2020 10:40:36 +0200 Subject: MLBEDSW-2786: Fix IFM order in binary operation - Setup ifm/ifm2 based on primary op's inputs Change-Id: I727eab473165d7cc876b70fa8873fbc0c1480fb5 Signed-off-by: Diqing Zhong --- ethosu/vela/pass_packing.py | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) (limited to 'ethosu') diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py index a4caf0c0..9e36cd62 100644 --- a/ethosu/vela/pass_packing.py +++ b/ethosu/vela/pass_packing.py @@ -385,15 +385,23 @@ def pack_into_passes(nng, arch, verbose_packing=False): # to avoid that they would accidentally be assigned as ifm or ifm2 lut_list = [] input_refcounts = collections.defaultdict(int) - for op in ops_list: + input_ops_list = ops_list.copy() + + # Check primary_op first + if primary_op is not None: + for inp in primary_op.inputs: + if len(inp.ops) == 1 and inp.ops[0].type == "DMA" and inp.purpose == TensorPurpose.FeatureMap: + src_op = inp.ops[0] + if src_op in input_ops_list: + inp = src_op.inputs[0] + input_ops_list.remove(src_op) + add_input_list(inp, input_set, input_refcounts, lut_list, ordered_input_list) + input_ops_list.remove(primary_op) + + # Check rest of the list + for op in input_ops_list: for inp in op.inputs: - if inp in input_set: - if input_refcounts[inp] == 0: - if inp.purpose == TensorPurpose.LUT: - lut_list.append(inp) - else: - ordered_input_list.append(inp) - input_refcounts[inp] += 1 + add_input_list(inp, input_set, input_refcounts, lut_list, ordered_input_list) name = ops_list[0].name non_dma_ops = [op for op in ops_list if op.type != "DMA"] @@ -472,6 +480,15 @@ def pack_into_passes(nng, arch, verbose_packing=False): return None + def add_input_list(inp_to_add, inp_set, inp_refcnts, lut_list, ordered_inp_list): + if inp_to_add in inp_set: + if inp_refcnts[inp_to_add] == 0: + if inp_to_add.purpose == TensorPurpose.LUT: + lut_list.append(inp_to_add) + else: + ordered_inp_list.append(inp_to_add) + inp_refcnts[inp_to_add] += 1 + for sg in nng.subgraphs: reverse_pass_list = [] visit_op_refcount = collections.defaultdict(int) -- cgit v1.2.1