diff options
author | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2021-08-23 15:33:59 +0200 |
---|---|---|
committer | patrik.gustavsson <patrik.gustavsson@arm.com> | 2021-09-03 12:19:48 +0000 |
commit | df99510f04aef99d1b8e9be9bfcde8fc1738b65f (patch) | |
tree | 00668b0e74f95da5cc51a41b9340d8c88fbc7ffe /ethosu/vela/tosa_graph_optimiser.py | |
parent | cce872bc3de3ed5f9bf1aa1a8cf9ce41cf2b2520 (diff) | |
download | ethos-u-vela-df99510f04aef99d1b8e9be9bfcde8fc1738b65f.tar.gz |
TOSA: Added Depthwise support
This is mainly to add support for depthwise conv2d
with dephmultiplier = 1.
(But there are no testcases suited, all I have sourced
has depth_multiplier set to 2, which is not supported.)
-Added support for depthwise conv2d.
-Added support for removing Transpose of constant data
-Added support for removing reshape
Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I143e6246becfa78fd9f7510af0bf0d6b3fbbf2c7
Diffstat (limited to 'ethosu/vela/tosa_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tosa_graph_optimiser.py | 60 |
1 files changed, 59 insertions, 1 deletions
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py index 44e0f8ec..169da40d 100644 --- a/ethosu/vela/tosa_graph_optimiser.py +++ b/ethosu/vela/tosa_graph_optimiser.py @@ -19,7 +19,10 @@ from . import rewrite_graph from .api import NpuRoundingMode from .data_type import DataType from .debug_database import DebugDatabase +from .graph_optimiser_util import bypass_reshape_and_squeeze_ops from .graph_optimiser_util import calc_explicit_padding +from .graph_optimiser_util import convert_depthwise_to_conv +from .graph_optimiser_util import fix_sg_input_output from .graph_optimiser_util import needed_total_padding from .graph_optimiser_util import set_ifm_ofm_op_shapes from .graph_optimiser_util import set_tensor_equivalence @@ -74,6 +77,43 @@ def add_padding_fields(op, arch, nng): return op +def remove_const_transpose(op, arch, nng): + if op.type == Op.Transpose: + removed = False + if len(op.ifm.ops) == 1: + prev_op = op.ifm.ops[0] + if prev_op.type == Op.Const: + # Transpose the Tensor and data and remove Transpose + # TODO move to Tensor? + reorder = op.attrs["perms"] + shape = op.ifm.shape.copy() + tens = op.ifm + + tens.shape = [shape[idx] for idx in reorder] + tens.bandwidth_shape = tens.shape + tens.storage_shape = tens.shape + + if tens.values is not None: + tens.values = tens.values.transpose(reorder) + + op.ofm.values = tens.values + # Bypass the Transpose op + prev_op.set_output_tensor(op.ofm) + DebugDatabase.add_optimised(op, prev_op) + removed = True + + if not removed: + print("Cannot remove Transpose, and handling of Transpose is not supported") + assert False + + return op + + +def remove_reshapes(op, arch): + if op.run_on_npu and op.type == Op.Reshape: + bypass_reshape_and_squeeze_ops(op) + + def rewrite_activation(op, arch, nng): if op.type not in (Op.ReluN, Op.Clamp): return op @@ -206,6 +246,7 @@ def fixup_quantization(op, arch, nng): def supported_operator_check(op, arch, nng): op.run_on_npu = arch.tosa_supported_operators.is_operator_supported(op) + assert op.run_on_npu or op.type in (Op.Placeholder, Op.SubgraphInput, Op.Const) return op @@ -221,8 +262,25 @@ def tosa_optimise_graph(nng, arch): nng, sg, arch, [], pre_process_list, rewrite_unsupported=False, ) + # Removal of Transpose + for idx, sg in enumerate(nng.subgraphs): + nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( + nng, sg, arch, [], [remove_const_transpose], rewrite_unsupported=False, + ) + + # Handle sg input output + for idx, sg in enumerate(nng.subgraphs): + nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( + nng, sg, arch, [], [fix_sg_input_output], rewrite_unsupported=False, + ) + + # Removal of reshapes + for sg in nng.subgraphs: + rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes]) + sg.refresh_after_modification() + # Rewite Operators step - op_rewrite_list = [set_tensor_equivalence, rewrite_rescale] + op_rewrite_list = [set_tensor_equivalence, rewrite_rescale, convert_depthwise_to_conv] for idx, sg in enumerate(nng.subgraphs): nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( |