aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/graph_optimiser_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/graph_optimiser_util.py')
-rw-r--r--ethosu/vela/graph_optimiser_util.py9
1 files changed, 8 insertions, 1 deletions
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index 8a94f361..f1b9e1aa 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -39,6 +39,10 @@ memory_only_ops = (
Op.Identity,
)
+# This list contains ops that requires its ofm shape to be intact in order
+# to be able to decompose it correctly in the graph optimiser step
+ofm_not_replaceable_ops = (Op.Mean,)
+
def _avoid_nhcwb16_for_concat(tens):
# If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
@@ -300,8 +304,11 @@ def bypass_memory_only_ops(op, arch, nng):
ifm_has_multiple_cons = len(op.ifm.consumer_list) > 1
ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
+ producer_ofm_not_replaceable = any(
+ ifm_prod is not None and ifm_prod.type in ofm_not_replaceable_ops for ifm_prod in op.ifm.ops
+ )
- if ifm_has_multiple_cons or ifm_is_cpu_produced:
+ if ifm_has_multiple_cons or ifm_is_cpu_produced or producer_ofm_not_replaceable:
# Convert to a memcpy op
op.type = Op.Memcpy
DebugDatabase.add_optimised(op, op)