diff options
Diffstat (limited to 'ethosu/vela/tosa_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tosa_graph_optimiser.py | 60 |
1 file changed, 59 insertions, 1 deletion
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py index 44e0f8ec..169da40d 100644 --- a/ethosu/vela/tosa_graph_optimiser.py +++ b/ethosu/vela/tosa_graph_optimiser.py @@ -19,7 +19,10 @@ from . import rewrite_graph from .api import NpuRoundingMode from .data_type import DataType from .debug_database import DebugDatabase +from .graph_optimiser_util import bypass_reshape_and_squeeze_ops from .graph_optimiser_util import calc_explicit_padding +from .graph_optimiser_util import convert_depthwise_to_conv +from .graph_optimiser_util import fix_sg_input_output from .graph_optimiser_util import needed_total_padding from .graph_optimiser_util import set_ifm_ofm_op_shapes from .graph_optimiser_util import set_tensor_equivalence @@ -74,6 +77,43 @@ def add_padding_fields(op, arch, nng): return op +def remove_const_transpose(op, arch, nng): + if op.type == Op.Transpose: + removed = False + if len(op.ifm.ops) == 1: + prev_op = op.ifm.ops[0] + if prev_op.type == Op.Const: + # Transpose the Tensor and data and remove Transpose + # TODO move to Tensor? 
+ reorder = op.attrs["perms"] + shape = op.ifm.shape.copy() + tens = op.ifm + + tens.shape = [shape[idx] for idx in reorder] + tens.bandwidth_shape = tens.shape + tens.storage_shape = tens.shape + + if tens.values is not None: + tens.values = tens.values.transpose(reorder) + + op.ofm.values = tens.values + # Bypass the Transpose op + prev_op.set_output_tensor(op.ofm) + DebugDatabase.add_optimised(op, prev_op) + removed = True + + if not removed: + print("Cannot remove Transpose, and handling of Transpose is not supported") + assert False + + return op + + +def remove_reshapes(op, arch): + if op.run_on_npu and op.type == Op.Reshape: + bypass_reshape_and_squeeze_ops(op) + + def rewrite_activation(op, arch, nng): if op.type not in (Op.ReluN, Op.Clamp): return op @@ -206,6 +246,7 @@ def fixup_quantization(op, arch, nng): def supported_operator_check(op, arch, nng): op.run_on_npu = arch.tosa_supported_operators.is_operator_supported(op) + assert op.run_on_npu or op.type in (Op.Placeholder, Op.SubgraphInput, Op.Const) return op @@ -221,8 +262,25 @@ def tosa_optimise_graph(nng, arch): nng, sg, arch, [], pre_process_list, rewrite_unsupported=False, ) + # Removal of Transpose + for idx, sg in enumerate(nng.subgraphs): + nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( + nng, sg, arch, [], [remove_const_transpose], rewrite_unsupported=False, + ) + + # Handle sg input output + for idx, sg in enumerate(nng.subgraphs): + nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( + nng, sg, arch, [], [fix_sg_input_output], rewrite_unsupported=False, + ) + + # Removal of reshapes + for sg in nng.subgraphs: + rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes]) + sg.refresh_after_modification() + # Rewite Operators step - op_rewrite_list = [set_tensor_equivalence, rewrite_rescale] + op_rewrite_list = [set_tensor_equivalence, rewrite_rescale, convert_depthwise_to_conv] for idx, sg in enumerate(nng.subgraphs): nng.subgraphs[idx] = 
rewrite_graph.rewrite_graph_pre_order( |