aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 29598032..6c85bb43 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -1061,8 +1061,8 @@ def convert_tanh_sigmoid_to_lut(op, arch, nng):
return op
-def remove_reshapes(op, arch):
- if op.run_on_npu and op.type == Op.Reshape:
+def remove_reshape_and_squeeze_ops(op, arch):
+ if op.run_on_npu and (op.type == Op.Reshape or op.type == Op.Squeeze):
ofm = op.ofm
ifm = op.ifm
@@ -1073,11 +1073,11 @@ def remove_reshapes(op, arch):
# or the reshape need to be replace with a NOP.
return
- # Check if Reshape ifm/ofm are network ifm/ofm
+ # Check if ifm/ofm are network ifm/ofm
ifm_is_sg_ifm = ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in ifm.consumer_list)
ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in ofm.consumer_list)
- # Check if ifm/ofm is produced repectivly consumed by CPU
+ # Check if ifm/ofm is produced respectively consumed by CPU
ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
@@ -1097,7 +1097,7 @@ def remove_reshapes(op, arch):
if cons_ifm == ifm:
ifm_cons.set_input_tensor(ofm, ifm_idx)
else:
- # Bypassed Reshape by replacing ofm with ifm
+ # Bypassed by replacing ofm with ifm
for cons in ofm.consumer_list:
for ifm_idx, cons_ifm in enumerate(cons.inputs):
if cons_ifm == ofm:
@@ -1567,9 +1567,9 @@ def tflite_optimise_graph(nng, arch):
nng, sg, arch, [], [fix_sg_input_output], rewrite_unsupported=False,
)
- # Removal of reshapes
+ # Removal of reshapes and squeeze
for sg in nng.subgraphs:
- rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes])
+ rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshape_and_squeeze_ops])
sg.refresh_after_modification()
# Rewrite of operators