aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_graph_optimiser.py
diff options
context:
space:
mode:
authorJonas Ohlsson <jonas.ohlsson@arm.com>2021-08-25 11:38:03 +0200
committerJonas Ohlsson <jonas.ohlsson@arm.com>2021-08-25 14:06:57 +0200
commitfbfd96e79177f79376d7cced5fb06465a4e00055 (patch)
tree315f0120ce75981be24a73a5c10f09eb2ee96ae5 /ethosu/vela/tflite_graph_optimiser.py
parent4a5ec683de6aca21ab56be4963afd66829fbc0a0 (diff)
downloadethos-u-vela-fbfd96e79177f79376d7cced5fb06465a4e00055.tar.gz
Handle sg input and output for Squeeze operator (tag: 3.1.0.rc2)
Update to handle the case when the Squeeze Op ifm/ofm are the subgraph ifm/ofm, to facilitate the removal of the Squeeze Op. Adding NOP to maintain the original tensors. Updated pytests for squeeze operator. Signed-off-by: Jonas Ohlsson <jonas.ohlsson@arm.com> Change-Id: I623cae05e696fb16ccf29dedc42fd822601e9fd9
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py12
1 file changed, 6 insertions, 6 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 6c85bb43..ff2f5a08 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -185,7 +185,7 @@ def remove_SplitSliceRead(op, arch):
len(op.ofm.consumer_list) == 1
and op.ofm.consumer_list[0] is not None
and op.ofm.consumer_list[0].run_on_npu
- and op.ofm.consumer_list[0].type != Op.Reshape
+ and op.ofm.consumer_list[0].type not in (Op.Reshape, Op.Squeeze)
and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
):
# SplitSliceRead can be performed by tensor consumer
@@ -245,10 +245,10 @@ def insert_copy_op_after_tens(tens):
def fix_sg_input_output(op, arch, nng):
- if not op.run_on_npu or op.type != Op.Reshape:
+ if not op.run_on_npu or op.type not in (Op.Reshape, Op.Squeeze):
return op
- # For the Reshape operators we want to remove, tensors are removed.
+ # For the Reshape/Squeeze operators we want to remove, tensors are removed.
# But in order to do this, they cannot be outputs of the sg,
# this need to be fixed prior to the removal.
# Solution is to add an avgpool NOP, to maintain the original tensor.
@@ -259,12 +259,12 @@ def fix_sg_input_output(op, arch, nng):
ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
- # Check if ifm/ofm is produced repectivly consumed by CPU
+ # Check if ifm/ofm is produced respectively consumed by CPU
ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
- # Both ifm and ofm need to persist, but only ifm need a copy, in order to remove the Reshape
+ # Both ifm and ofm need to persist, but only ifm need a copy, in order to remove the Reshape/Squeeze
insert_copy_op_after_tens(op.ifm)
return op
@@ -1062,7 +1062,7 @@ def convert_tanh_sigmoid_to_lut(op, arch, nng):
def remove_reshape_and_squeeze_ops(op, arch):
- if op.run_on_npu and (op.type == Op.Reshape or op.type == Op.Squeeze):
+ if op.run_on_npu and op.type in (Op.Reshape, Op.Squeeze):
ofm = op.ofm
ifm = op.ifm