diff options
Diffstat (limited to 'ethosu/vela/graph_optimiser_util.py')
-rw-r--r-- | ethosu/vela/graph_optimiser_util.py | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py index 8a94f361..f1b9e1aa 100644 --- a/ethosu/vela/graph_optimiser_util.py +++ b/ethosu/vela/graph_optimiser_util.py @@ -39,6 +39,10 @@ memory_only_ops = ( Op.Identity, ) +# This list contains ops that require their ofm shape to be intact in order +# to be able to decompose them correctly in the graph optimiser step +ofm_not_replaceable_ops = (Op.Mean,) + def _avoid_nhcwb16_for_concat(tens): # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a @@ -300,8 +304,11 @@ def bypass_memory_only_ops(op, arch, nng): ifm_has_multiple_cons = len(op.ifm.consumer_list) > 1 ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops) + producer_ofm_not_replaceable = any( + ifm_prod is not None and ifm_prod.type in ofm_not_replaceable_ops for ifm_prod in op.ifm.ops + ) - if ifm_has_multiple_cons or ifm_is_cpu_produced: + if ifm_has_multiple_cons or ifm_is_cpu_produced or producer_ofm_not_replaceable: # Convert to a memcpy op op.type = Op.Memcpy DebugDatabase.add_optimised(op, op) |