author     Patrik Gustavsson <patrik.gustavsson@arm.com>  2020-09-10 08:19:36 +0200
committer  patrik.gustavsson <patrik.gustavsson@arm.com>  2020-09-11 12:19:22 +0000
commit     fa4cb29996ffe1e64e39655c2195af6ff02e887a (patch)
tree       bfd40a4b055b7ef7878b5a1405f3eb20172098a9
parent     e1cc3de77668b66ea413c570de7161e2bba89502 (diff)
download   ethos-u-vela-fa4cb29996ffe1e64e39655c2195af6ff02e887a.tar.gz
MLBEDSW-2994 Remove undesired reshape OPs
Added functionality for removing reshape OPs that enclose an ElementWise OP with only one non-constant tensor.

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: Idaac50cfbd732e2667668be2baa059673236cc56
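For illustration, a minimal sketch of the graph pattern this pass targets, using hypothetical toy Tensor/Operation classes rather than Vela's real API: a Reshape feeds an elementwise op whose only other input is constant, and the result feeds a second Reshape back to the original shape.

    # Hypothetical toy graph classes, only to show the pattern being bypassed.
    class Tensor:
        def __init__(self, shape):
            self.shape = shape
            self.ops = []            # producing operations
            self.consumer_list = []  # consuming operations

    class Operation:
        def __init__(self, op_type, inputs, output_shape):
            self.type = op_type
            self.inputs = inputs
            self.outputs = [Tensor(output_shape)]
            self.outputs[0].ops.append(self)
            for tens in inputs:
                tens.consumer_list.append(self)

    # ifm -> Reshape -> Mul(const) -> Reshape -> ofm
    ifm = Operation("Placeholder", [], [1, 16, 16, 8]).outputs[0]
    scale = Operation("Const", [], [1]).outputs[0]
    flat = Operation("Reshape", [ifm], [1, 2048]).outputs[0]
    mul = Operation("Mul", [flat, scale], [1, 2048])
    out_reshape = Operation("Reshape", [mul.outputs[0]], [1, 16, 16, 8])

    # When the enclosing Reshapes are the sole consumers, the outer ranks
    # match and quantization scaling is equal, the pass rewires the
    # elementwise op to bypass both Reshapes:
    mul.inputs[0] = ifm
    mul.outputs[0] = out_reshape.outputs[0]

After the rewrite both Reshape ops are left without consumers and drop out of the graph; the real pass additionally checks run_on_npu, the NpuBlockType and is_scaling_equal before rewiring, as the diff below shows.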
-rw-r--r--  ethosu/vela/graph_optimiser.py  36
1 file changed, 36 insertions(+), 0 deletions(-)
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index bd30fd3d..48684058 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -809,6 +809,41 @@ def convert_lrelu(op, arch):
     return convert_lrelu_to_mul_max(op, arch)


+def remove_unwanted_reshapes(op, arch):
+    # Try to remove reshapes enclosing an ElementWise operator with only one non-constant input
+    if not op.run_on_npu or op.attrs["npu_block_type"] != NpuBlockType.ElementWise:
+        return op
+
+    # Check if the ElementWise operator only has one non-constant input
+    non_const_tens = [x for x in op.inputs if x.ops[0].type != "Const"]
+    if len(non_const_tens) != 1:
+        return op
+    ifm = non_const_tens[0]
+
+    # Check if operation is enclosed by Reshapes that can be removed
+    ofm = op.outputs[0]
+    prev_op = ifm.ops[0]
+    if (
+        len(ifm.consumer_list) == 1
+        and prev_op.type == "Reshape"
+        and len(ofm.consumer_list) == 1
+        and ofm.consumer_list[0].type == "Reshape"
+    ):
+        # Operation is enclosed by reshapes, check if they can be removed
+        prev_op_ifm, _, _, prev_op_ofm = prev_op.get_ifm_weights_biases_ofm()
+        cons_op = ofm.consumer_list[0]
+        cons_op_ifm = ofm
+        cons_op_ofm = cons_op.outputs[0]
+        if len(prev_op_ifm.shape) == len(cons_op_ofm.shape):
+            # Check if quantization is the same in the input and output for the reshape ops
+            if prev_op_ifm.quantization.is_scaling_equal(
+                prev_op_ofm.quantization
+            ) and cons_op_ifm.quantization.is_scaling_equal(cons_op_ofm.quantization):
+                op.inputs[0] = prev_op_ifm
+                op.outputs[0] = cons_op_ofm
+    return op
+
+
 def fuse_activation_function_with_prev(op, arch):
     # if op is a no-op: attempts to move the activation function to the preceding op
     if not op.attrs.get("is_nop", False) or op.attrs.get("fused_activation_function", None) is None:
@@ -901,6 +936,7 @@ def optimise_graph_a(nng, arch, verbose_graph=False):
         fixup_resizebilinear,
         fixup_bias_tensors,
         convert_mul_max_to_abs_or_lrelu,
+        remove_unwanted_reshapes,
         convert_lrelu,
     ]
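For context, each function in this op_rewrite_list is applied per operation during graph optimisation. A minimal sketch of such a driver loop, using a hypothetical apply_op_rewrites helper (Vela's actual traversal lives in rewrite_graph.py and differs in detail):

    def apply_op_rewrites(ops, arch, op_rewrite_list):
        # Each rewrite takes (op, arch) and returns the op, possibly replaced
        # by a rewritten one; passes that do not apply return it unchanged.
        for idx, op in enumerate(ops):
            for rewrite in op_rewrite_list:
                op = rewrite(op, arch)
            ops[idx] = op
        return ops

This calling convention is why remove_unwanted_reshapes returns op on every path, including the early-out cases where the pattern does not match.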