diff options
author | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2021-04-14 17:54:10 +0200 |
---|---|---|
committer | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2021-04-15 15:19:04 +0200 |
commit | 3645d009628bbb00185132e70d257d2c038716e7 (patch) | |
tree | 974c41393b836e7aeef9b1c1dda2faef3f63ac94 /ethosu/vela/graph_optimiser.py | |
parent | 1d6d5c47c2000facc377620a64084738339ccda9 (diff) | |
download | ethos-u-vela-3645d009628bbb00185132e70d257d2c038716e7.tar.gz |
MLBEDSW-4397 Fix Reshape ifm/ofm prod/cons by cpu op
Not only the sg inputs/outputs need to be considered
before removing a Reshape.
Added a check for whether the Reshape ifm is produced by, or the
ofm is consumed by, a CPU op. The handling is the same as when the
tensor is an sg input/output.
Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: If509e1d23e3f22ed4c963d8dabd8c00c6b9c07e3
Diffstat (limited to 'ethosu/vela/graph_optimiser.py')
-rw-r--r-- | ethosu/vela/graph_optimiser.py | 17 |
1 files changed, 13 insertions, 4 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py index dd540a79..b708b62e 100644 --- a/ethosu/vela/graph_optimiser.py +++ b/ethosu/vela/graph_optimiser.py @@ -344,14 +344,19 @@ def fix_sg_input_output(op, arch, nng): # But in order to to do this, they cannot be outputs of the sg, # this need to be fixed prior to the removal. # Solution is to add a avgpool NOP, to maintain the original tensor. + # This is also valid when reshape ifm/ofm is produced respectively + # consumed by CPU # Check if operator ifm/ofm are sg ifm/ofm ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const) ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list) ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list) + # Check if ifm/ofm is produced repectivly consumed by CPU + ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops) + ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list) - if op.type == Op.Reshape and (ifm_is_sg_ofm or ifm_is_sg_ifm) and ofm_is_sg_ofm: - # Both ifm and ofm are sg outputs, only ifm need a copy, in order to remove the Reshape + if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed): + # Both ifm and ofm need to persist, but only ifm need a copy, in order to remove the Reshape insert_copy_op_after_tens(op.ifm) return op @@ -1194,10 +1199,14 @@ def remove_reshapes(op, arch): ifm_is_sg_ifm = ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const) ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in ifm.consumer_list) ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in ofm.consumer_list) + # Check if ifm/ofm is produced repectivly consumed by CPU + ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops) + ofm_is_cpu_consumed = any(ofm_cons is not None and not 
ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list) + # This case should be handled prior to this function - assert not ((ifm_is_sg_ifm or ifm_is_sg_ofm) and ofm_is_sg_ofm) + assert not ((ifm_is_sg_ifm or ifm_is_sg_ofm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed)) - if ofm_is_sg_ofm: + if ofm_is_sg_ofm or ofm_is_cpu_consumed: # Bypassed by replacing ifm with ofm ofm.ops = [] for prev_op in ifm.ops: |