From 4965faee41300393cd8d74da4b399fa4c4ee9030 Mon Sep 17 00:00:00 2001 From: Ayaan Masood Date: Wed, 29 Jun 2022 11:30:57 +0100 Subject: MLBEDSW-6313 Static optimisation for Shape OP *Shape OP value is available at compile time hence it can be optimised *Disconnected shape OP at compile time from parent tensor *Transformed shape OP tensor into constant Change-Id: I0a024269e2b592c6146dd72e62d7a41951fb727a Signed-off-by: Ayaan Masood --- ethosu/vela/tflite_graph_optimiser.py | 41 +++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) (limited to 'ethosu/vela/tflite_graph_optimiser.py') diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 06395784..cf3985e4 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -1391,18 +1391,59 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng): return op +def convert_shape_op_to_constant_tensor(op: Operation, arch, nng): + """Static optimisation for SHAPE operator output value known at compile time""" + + # Disconnect SHAPE operator from its parent and transform SHAPE OP into constant + + if op.type == Op.Shape and op.run_on_npu: + + ifm, ofm = op.get_ifm_ofm() + + if len(ifm.shape) != ofm.shape[0]: + return op + + # Remove reference of the current shape op from the parent tensor's consumer list + ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index] + + # Clear any references to parent node + op.inputs = [] + + # Convert this SHAPE op to const + op.type = Op.Const + + # Add size calculation to shape output tensors + ofm.values = np.array(ifm.shape) + + return op + + def supported_operator_check(op, arch, nng): op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op) return op def tflite_optimise_graph(nng, arch): + + # Compile time optimisations + optimisation_list = [convert_shape_op_to_constant_tensor] + # Pre-processing step pre_process_list = [ supported_operator_check, set_ifm_ofm_op_shapes, ] + for idx, sg in enumerate(nng.subgraphs): + nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( + nng, + sg, + arch, + [], + optimisation_list, + rewrite_unsupported=False, + ) + for idx, sg in enumerate(nng.subgraphs): nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( nng, -- cgit v1.2.1