diff options
author | Diqing Zhong <diqing.zhong@arm.com> | 2020-08-25 10:40:36 +0200 |
---|---|---|
committer | tim.hall <tim.hall@arm.com> | 2020-08-27 13:59:58 +0000 |
commit | 2abd3dd75bd3d20e1a3aeaf12362f9872b40fa0a (patch) | |
tree | e8f63ebcfbb1c16ee8cf00ab4e569d4e39ecfd89 /ethosu | |
parent | f0c59bf945d7746961fa05186d1353ed91f587bc (diff) | |
download | ethos-u-vela-2abd3dd75bd3d20e1a3aeaf12362f9872b40fa0a.tar.gz |
MLBEDSW-2786: Fix IFM order in binary operation
- Setup ifm/ifm2 based on primary op's inputs
Change-Id: I727eab473165d7cc876b70fa8873fbc0c1480fb5
Signed-off-by: Diqing Zhong <diqing.zhong@arm.com>
Diffstat (limited to 'ethosu')
-rw-r--r-- | ethosu/vela/pass_packing.py | 33 |
1 files changed, 25 insertions, 8 deletions
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py index a4caf0c0..9e36cd62 100644 --- a/ethosu/vela/pass_packing.py +++ b/ethosu/vela/pass_packing.py @@ -385,15 +385,23 @@ def pack_into_passes(nng, arch, verbose_packing=False): # to avoid that they would accidentally be assigned as ifm or ifm2 lut_list = [] input_refcounts = collections.defaultdict(int) - for op in ops_list: + input_ops_list = ops_list.copy() + + # Check primary_op first + if primary_op is not None: + for inp in primary_op.inputs: + if len(inp.ops) == 1 and inp.ops[0].type == "DMA" and inp.purpose == TensorPurpose.FeatureMap: + src_op = inp.ops[0] + if src_op in input_ops_list: + inp = src_op.inputs[0] + input_ops_list.remove(src_op) + add_input_list(inp, input_set, input_refcounts, lut_list, ordered_input_list) + input_ops_list.remove(primary_op) + + # Check rest of the list + for op in input_ops_list: for inp in op.inputs: - if inp in input_set: - if input_refcounts[inp] == 0: - if inp.purpose == TensorPurpose.LUT: - lut_list.append(inp) - else: - ordered_input_list.append(inp) - input_refcounts[inp] += 1 + add_input_list(inp, input_set, input_refcounts, lut_list, ordered_input_list) name = ops_list[0].name non_dma_ops = [op for op in ops_list if op.type != "DMA"] @@ -472,6 +480,15 @@ def pack_into_passes(nng, arch, verbose_packing=False): return None + def add_input_list(inp_to_add, inp_set, inp_refcnts, lut_list, ordered_inp_list): + if inp_to_add in inp_set: + if inp_refcnts[inp_to_add] == 0: + if inp_to_add.purpose == TensorPurpose.LUT: + lut_list.append(inp_to_add) + else: + ordered_inp_list.append(inp_to_add) + inp_refcnts[inp_to_add] += 1 + for sg in nng.subgraphs: reverse_pass_list = [] visit_op_refcount = collections.defaultdict(int) |