diff options
author | Ayaan Masood <Ayaan.Masood@arm.com> | 2022-06-29 11:30:57 +0100 |
---|---|---|
committer | Ayaan Masood <Ayaan.Masood@arm.com> | 2022-06-29 11:30:57 +0100 |
commit | 4965faee41300393cd8d74da4b399fa4c4ee9030 (patch) | |
tree | 1054d6f89be70ec471007132dec97d325ecc0067 /ethosu/vela/tflite_graph_optimiser.py | |
parent | 68b8f2f9457d56df3211be5318e3682332bcefbf (diff) | |
download | ethos-u-vela-4965faee41300393cd8d74da4b399fa4c4ee9030.tar.gz |
MLBEDSW-6313 Static optimisation for Shape OP
* Shape OP value is available at compile time, hence it can be optimised
* Disconnected the Shape OP from its parent tensor at compile time
* Transformed the Shape OP output tensor into a constant
Change-Id: I0a024269e2b592c6146dd72e62d7a41951fb727a
Signed-off-by: Ayaan Masood <Ayaan.Masood@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 41 |
1 file changed, 41 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 06395784..cf3985e4 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -1391,12 +1391,43 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
     return op
 
 
+def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
+    """Static optimisation for SHAPE operator output value known at compile time"""
+
+    # Disconnect SHAPE operator from its parent and transform SHAPE OP into constant
+
+    if op.type == Op.Shape and op.run_on_npu:
+
+        ifm, ofm = op.get_ifm_ofm()
+
+        if len(ifm.shape) != ofm.shape[0]:
+            return op
+
+        # Remove reference of the current shape op from the parent tensor's consumer list
+        ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
+
+        # Clear any references to parent node
+        op.inputs = []
+
+        # Convert this SHAPE op to const
+        op.type = Op.Const
+
+        # Add size calculation to shape output tensors
+        ofm.values = np.array(ifm.shape)
+
+    return op
+
+
 def supported_operator_check(op, arch, nng):
     op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
     return op
 
 
 def tflite_optimise_graph(nng, arch):
+
+    # Compile time optimisations
+    optimisation_list = [convert_shape_op_to_constant_tensor]
+
     # Pre-processing step
     pre_process_list = [
         supported_operator_check,
@@ -1409,6 +1440,16 @@
             sg,
             arch,
             [],
+            optimisation_list,
+            rewrite_unsupported=False,
+        )
+
+    for idx, sg in enumerate(nng.subgraphs):
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng,
+            sg,
+            arch,
+            [],
             pre_process_list,
             rewrite_unsupported=False,
         )