diff options
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 06395784..cf3985e4 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -1391,12 +1391,43 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng): return op +def convert_shape_op_to_constant_tensor(op: Operation, arch, nng): + """Static optimisation for SHAPE operator output value known at compile time""" + + # Disconnect SHAPE operator from its parent and transform SHAPE OP into constant + + if op.type == Op.Shape and op.run_on_npu: + + ifm, ofm = op.get_ifm_ofm() + + if len(ifm.shape) != ofm.shape[0]: + return op + + # Remove reference of the current shape op from the parent tensor's consumer list + ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index] + + # Clear any references to parent node + op.inputs = [] + + # Convert this SHAPE op to const + op.type = Op.Const + + # Add size calculation to shape output tensors + ofm.values = np.array(ifm.shape) + + return op + + def supported_operator_check(op, arch, nng): op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op) return op def tflite_optimise_graph(nng, arch): + + # Compile time optimisations + optimisation_list = [convert_shape_op_to_constant_tensor] + # Pre-processing step pre_process_list = [ supported_operator_check, @@ -1409,6 +1440,16 @@ def tflite_optimise_graph(nng, arch): sg, arch, [], + optimisation_list, + rewrite_unsupported=False, + ) + + for idx, sg in enumerate(nng.subgraphs): + nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( + nng, + sg, + arch, + [], pre_process_list, rewrite_unsupported=False, ) |