diff options
Diffstat (limited to 'ethosu/vela/tosa_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tosa_graph_optimiser.py | 56 |
1 files changed, 48 insertions, 8 deletions
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py index 9e72a6c1..778aa2ac 100644 --- a/ethosu/vela/tosa_graph_optimiser.py +++ b/ethosu/vela/tosa_graph_optimiser.py @@ -876,7 +876,12 @@ def tosa_optimise_graph(nng, arch): # TODO the supported operator checking need to be split in semantic and HW checks for idx, sg in enumerate(nng.subgraphs): nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [], [supported_operator_check], rewrite_unsupported=False, + nng, + sg, + arch, + [], + [supported_operator_check], + rewrite_unsupported=False, ) # Decomposing and rewrite of concat @@ -893,7 +898,12 @@ def tosa_optimise_graph(nng, arch): # Handle sg input output for idx, sg in enumerate(nng.subgraphs): nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [], [fix_sg_input_output_tosa], rewrite_unsupported=True, + nng, + sg, + arch, + [], + [fix_sg_input_output_tosa], + rewrite_unsupported=True, ) # Removal of reshapes @@ -909,19 +919,34 @@ def tosa_optimise_graph(nng, arch): for idx, sg in enumerate(nng.subgraphs): nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [], [set_ifm_ofm_op_shapes], rewrite_unsupported=False, + nng, + sg, + arch, + [], + [set_ifm_ofm_op_shapes], + rewrite_unsupported=False, ) # Removal of Transpose for idx, sg in enumerate(nng.subgraphs): nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [], [remove_const_transpose], rewrite_unsupported=False, + nng, + sg, + arch, + [], + [remove_const_transpose], + rewrite_unsupported=False, ) # TODO, when and where to best handle calc_scaling_avgpool for idx, sg in enumerate(nng.subgraphs): nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [], [calc_scaling_avgpool], rewrite_unsupported=False, + nng, + sg, + arch, + [], + [calc_scaling_avgpool], + rewrite_unsupported=False, ) # Rewite Operators step @@ -929,13 +954,22 @@ def tosa_optimise_graph(nng, arch): for idx, sg in enumerate(nng.subgraphs): nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False, + nng, + sg, + arch, + [], + op_rewrite_list, + rewrite_unsupported=False, ) # Post-processing step 1 for idx, sg in enumerate(nng.subgraphs): nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [], [rewrite_activation, add_padding_fields], + nng, + sg, + arch, + [], + [rewrite_activation, add_padding_fields], ) # Removal of Slice, need to be done after optimisation has been performed, @@ -946,6 +980,12 @@ def tosa_optimise_graph(nng, arch): # Post-processing step 2 for idx, sg in enumerate(nng.subgraphs): - nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(nng, sg, arch, [], [fixup_quantization],) + nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( + nng, + sg, + arch, + [], + [fixup_quantization], + ) return nng |