From fa4cb29996ffe1e64e39655c2195af6ff02e887a Mon Sep 17 00:00:00 2001 From: Patrik Gustavsson Date: Thu, 10 Sep 2020 08:19:36 +0200 Subject: MLBEDSW-2994 Remove undesired reshape OPs Addded functionality for removing reshape OPs, that enclose an Elementwize OP with only one non-constant Tensor. Signed-off-by: Patrik Gustavsson Change-Id: Idaac50cfbd732e2667668be2baa059673236cc56 --- ethosu/vela/graph_optimiser.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) (limited to 'ethosu/vela') diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py index bd30fd3d..48684058 100644 --- a/ethosu/vela/graph_optimiser.py +++ b/ethosu/vela/graph_optimiser.py @@ -809,6 +809,41 @@ def convert_lrelu(op, arch): return convert_lrelu_to_mul_max(op, arch) +def remove_unwanted_reshapes(op, arch): + # Try to remove reshapes enclosing ElementWise operator with only one non-constant input + if not op.run_on_npu or op.attrs["npu_block_type"] != NpuBlockType.ElementWise: + return op + + # Check if the ElementWise operator only have one non-constant input + non_const_tens = [x for x in op.inputs if x.ops[0].type != "Const"] + if len(non_const_tens) != 1: + return op + ifm = non_const_tens[0] + + # Check if operation is enclosed by Reshapes that can be removed + ofm = op.outputs[0] + prev_op = ifm.ops[0] + if ( + len(ifm.consumer_list) == 1 + and prev_op.type == "Reshape" + and len(ofm.consumer_list) == 1 + and ofm.consumer_list[0].type == "Reshape" + ): + # Operation is enclosed by reshapes, check if they can be removed + prev_op_ifm, _, _, prev_op_ofm = prev_op.get_ifm_weights_biases_ofm() + cons_op = ofm.consumer_list[0] + cons_op_ifm = ofm + cons_op_ofm = cons_op.outputs[0] + if len(prev_op_ifm.shape) == len(cons_op_ofm.shape): + # Check if quantization is the same in the input and output for the reshape ops + if prev_op_ifm.quantization.is_scaling_equal( + prev_op_ofm.quantization + ) and cons_op_ifm.quantization.is_scaling_equal(cons_op_ofm.quantization): + op.inputs[0] = prev_op_ifm + op.outputs[0] = cons_op_ofm + return op + + def fuse_activation_function_with_prev(op, arch): # if op is a no-op: attempts to move the activation function to the preceding op if not op.attrs.get("is_nop", False) or op.attrs.get("fused_activation_function", None) is None: @@ -901,6 +936,7 @@ def optimise_graph_a(nng, arch, verbose_graph=False): fixup_resizebilinear, fixup_bias_tensors, convert_mul_max_to_abs_or_lrelu, + remove_unwanted_reshapes, convert_lrelu, ] -- cgit v1.2.1