aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_graph_optimiser.py
diff options
context:
space:
mode:
authorAyaan Masood <Ayaan.Masood@arm.com>2022-06-29 11:30:57 +0100
committerAyaan Masood <Ayaan.Masood@arm.com>2022-06-29 11:30:57 +0100
commit4965faee41300393cd8d74da4b399fa4c4ee9030 (patch)
tree1054d6f89be70ec471007132dec97d325ecc0067 /ethosu/vela/tflite_graph_optimiser.py
parent68b8f2f9457d56df3211be5318e3682332bcefbf (diff)
downloadethos-u-vela-4965faee41300393cd8d74da4b399fa4c4ee9030.tar.gz
MLBEDSW-6313 Static optimisation for Shape OP
*Shape OP value is available at compile time hence it can be optimised *Disconnected shape OP at compile time from parent tensor *Transformed shape OP tensor into constant Change-Id: I0a024269e2b592c6146dd72e62d7a41951fb727a Signed-off-by: Ayaan Masood <Ayaan.Masood@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py41
1 files changed, 41 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 06395784..cf3985e4 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -1391,12 +1391,43 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
return op
+def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
+ """Static optimisation for SHAPE operator output value known at compile time"""
+
+ # Disconnect SHAPE operator from its parent and transform SHAPE OP into constant
+
+ if op.type == Op.Shape and op.run_on_npu:
+
+ ifm, ofm = op.get_ifm_ofm()
+
+ if len(ifm.shape) != ofm.shape[0]:
+ return op
+
+ # Remove reference of the current shape op from the parent tensor's consumer list
+ ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
+
+ # Clear any references to parent node
+ op.inputs = []
+
+ # Convert this SHAPE op to const
+ op.type = Op.Const
+
+ # Add size calculation to shape output tensors
+ ofm.values = np.array(ifm.shape)
+
+ return op
+
+
def supported_operator_check(op, arch, nng):
op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
return op
def tflite_optimise_graph(nng, arch):
+
+ # Compile time optimisations
+ optimisation_list = [convert_shape_op_to_constant_tensor]
+
# Pre-processing step
pre_process_list = [
supported_operator_check,
@@ -1409,6 +1440,16 @@ def tflite_optimise_graph(nng, arch):
sg,
arch,
[],
+ optimisation_list,
+ rewrite_unsupported=False,
+ )
+
+ for idx, sg in enumerate(nng.subgraphs):
+ nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+ nng,
+ sg,
+ arch,
+ [],
pre_process_list,
rewrite_unsupported=False,
)