aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 6c85bb43..ff2f5a08 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -185,7 +185,7 @@ def remove_SplitSliceRead(op, arch):
len(op.ofm.consumer_list) == 1
and op.ofm.consumer_list[0] is not None
and op.ofm.consumer_list[0].run_on_npu
- and op.ofm.consumer_list[0].type != Op.Reshape
+ and op.ofm.consumer_list[0].type not in (Op.Reshape, Op.Squeeze)
and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
):
# SplitSliceRead can be performed by tensor consumer
@@ -245,10 +245,10 @@ def insert_copy_op_after_tens(tens):
def fix_sg_input_output(op, arch, nng):
- if not op.run_on_npu or op.type != Op.Reshape:
+ if not op.run_on_npu or op.type not in (Op.Reshape, Op.Squeeze):
return op
- # For the Reshape operators we want to remove, tensors are removed.
+ # For the Reshape/Squeeze operators we want to remove, tensors are removed.
# But in order to to do this, they cannot be outputs of the sg,
# this need to be fixed prior to the removal.
# Solution is to add a avgpool NOP, to maintain the original tensor.
@@ -259,12 +259,12 @@ def fix_sg_input_output(op, arch, nng):
ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
- # Check if ifm/ofm is produced repectivly consumed by CPU
+ # Check if ifm/ofm is produced respectively consumed by CPU
ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
- # Both ifm and ofm need to persist, but only ifm need a copy, in order to remove the Reshape
+ # Both ifm and ofm need to persist, but only ifm need a copy, in order to remove the Reshape/Squeeze
insert_copy_op_after_tens(op.ifm)
return op
@@ -1062,7 +1062,7 @@ def convert_tanh_sigmoid_to_lut(op, arch, nng):
def remove_reshape_and_squeeze_ops(op, arch):
- if op.run_on_npu and (op.type == Op.Reshape or op.type == Op.Squeeze):
+ if op.run_on_npu and op.type in (Op.Reshape, Op.Squeeze):
ofm = op.ofm
ifm = op.ifm