aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tosa_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tosa_graph_optimiser.py')
-rw-r--r--ethosu/vela/tosa_graph_optimiser.py56
1 files changed, 48 insertions, 8 deletions
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index 9e72a6c1..778aa2ac 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -876,7 +876,12 @@ def tosa_optimise_graph(nng, arch):
# TODO the supported operator checking need to be split in semantic and HW checks
for idx, sg in enumerate(nng.subgraphs):
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng, sg, arch, [], [supported_operator_check], rewrite_unsupported=False,
+ nng,
+ sg,
+ arch,
+ [],
+ [supported_operator_check],
+ rewrite_unsupported=False,
)
# Decomposing and rewrite of concat
@@ -893,7 +898,12 @@ def tosa_optimise_graph(nng, arch):
# Handle sg input output
for idx, sg in enumerate(nng.subgraphs):
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng, sg, arch, [], [fix_sg_input_output_tosa], rewrite_unsupported=True,
+ nng,
+ sg,
+ arch,
+ [],
+ [fix_sg_input_output_tosa],
+ rewrite_unsupported=True,
)
# Removal of reshapes
@@ -909,19 +919,34 @@ def tosa_optimise_graph(nng, arch):
for idx, sg in enumerate(nng.subgraphs):
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng, sg, arch, [], [set_ifm_ofm_op_shapes], rewrite_unsupported=False,
+ nng,
+ sg,
+ arch,
+ [],
+ [set_ifm_ofm_op_shapes],
+ rewrite_unsupported=False,
)
# Removal of Transpose
for idx, sg in enumerate(nng.subgraphs):
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng, sg, arch, [], [remove_const_transpose], rewrite_unsupported=False,
+ nng,
+ sg,
+ arch,
+ [],
+ [remove_const_transpose],
+ rewrite_unsupported=False,
)
# TODO, when and where to best handle calc_scaling_avgpool
for idx, sg in enumerate(nng.subgraphs):
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng, sg, arch, [], [calc_scaling_avgpool], rewrite_unsupported=False,
+ nng,
+ sg,
+ arch,
+ [],
+ [calc_scaling_avgpool],
+ rewrite_unsupported=False,
)
# Rewite Operators step
@@ -929,13 +954,22 @@ def tosa_optimise_graph(nng, arch):
for idx, sg in enumerate(nng.subgraphs):
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
+ nng,
+ sg,
+ arch,
+ [],
+ op_rewrite_list,
+ rewrite_unsupported=False,
)
# Post-processing step 1
for idx, sg in enumerate(nng.subgraphs):
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng, sg, arch, [], [rewrite_activation, add_padding_fields],
+ nng,
+ sg,
+ arch,
+ [],
+ [rewrite_activation, add_padding_fields],
)
# Removal of Slice, need to be done after optimisation has been performed,
@@ -946,6 +980,12 @@ def tosa_optimise_graph(nng, arch):
# Post-processing step 2
for idx, sg in enumerate(nng.subgraphs):
- nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(nng, sg, arch, [], [fixup_quantization],)
+ nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+ nng,
+ sg,
+ arch,
+ [],
+ [fixup_quantization],
+ )
return nng