aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPatrik Gustavsson <patrik.gustavsson@arm.com>2021-04-14 17:54:10 +0200
committerPatrik Gustavsson <patrik.gustavsson@arm.com>2021-04-15 15:19:04 +0200
commit3645d009628bbb00185132e70d257d2c038716e7 (patch)
tree974c41393b836e7aeef9b1c1dda2faef3f63ac94
parent1d6d5c47c2000facc377620a64084738339ccda9 (diff)
downloadethos-u-vela-3645d009628bbb00185132e70d257d2c038716e7.tar.gz
MLBEDSW-4397 Fix Reshape ifm/ofm prod/cons by cpu op
It is not sufficient to consider only the subgraph inputs/outputs before removing a Reshape. Added a check for whether the Reshape ifm is produced by, or the Reshape ofm is consumed by, a CPU operator. The handling is the same as when the tensor is a subgraph input/output. Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com> Change-Id: If509e1d23e3f22ed4c963d8dabd8c00c6b9c07e3
-rw-r--r--ethosu/vela/graph_optimiser.py17
1 files changed, 13 insertions, 4 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index dd540a79..b708b62e 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -344,14 +344,19 @@ def fix_sg_input_output(op, arch, nng):
# But in order to do this, they cannot be outputs of the sg,
# this needs to be fixed prior to the removal.
# Solution is to add an avgpool NOP, to maintain the original tensor.
+ # This is also valid when reshape ifm/ofm is produced respectively
+ # consumed by CPU
# Check if operator ifm/ofm are sg ifm/ofm
ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
+ # Check if the ifm is produced by, or the ofm is consumed by, a CPU op
+ ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
+ ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
- if op.type == Op.Reshape and (ifm_is_sg_ofm or ifm_is_sg_ifm) and ofm_is_sg_ofm:
- # Both ifm and ofm are sg outputs, only ifm need a copy, in order to remove the Reshape
+ if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
+ # Both ifm and ofm need to persist, but only ifm needs a copy, in order to remove the Reshape
insert_copy_op_after_tens(op.ifm)
return op
@@ -1194,10 +1199,14 @@ def remove_reshapes(op, arch):
ifm_is_sg_ifm = ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in ifm.consumer_list)
ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in ofm.consumer_list)
+ # Check if the ifm is produced by, or the ofm is consumed by, a CPU op
+ ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
+ ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
+
# This case should be handled prior to this function
- assert not ((ifm_is_sg_ifm or ifm_is_sg_ofm) and ofm_is_sg_ofm)
+ assert not ((ifm_is_sg_ifm or ifm_is_sg_ofm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed))
- if ofm_is_sg_ofm:
+ if ofm_is_sg_ofm or ofm_is_cpu_consumed:
# Bypassed by replacing ifm with ofm
ofm.ops = []
for prev_op in ifm.ops: