From 7647b0fe74e68792963c5602a083c215ee369182 Mon Sep 17 00:00:00 2001 From: Johan Alfven Date: Tue, 2 Apr 2024 20:56:09 +0200 Subject: MLBEDSW-8875: MLCE: Update criteria when to move SplitSpliceRead to consumer - When possible, a read slice from a split or stride is moved to the following op. The problem in this case was that the following op was a Maxpool op (from Softmax). The Maxpool op is using a different input shape compared to the original Softmax op, and this input shape was then changed when the read slice was applied to the Maxpool op. - The result is a faulty Maxpool op with an output diff. - The fix is to prevent moving the slice read when the consumer input shape differs from the Split/Stride ofm shape Change-Id: I649d89c38645fa51c20c3602954e2b8af9372076 Signed-off-by: Johan Alfven --- ethosu/vela/graph_optimiser_util.py | 13 ++++++++++++- ethosu/vela/tflite_graph_optimiser.py | 2 ++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py index c0099ffa..44a08f55 100644 --- a/ethosu/vela/graph_optimiser_util.py +++ b/ethosu/vela/graph_optimiser_util.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2021-2024 Arm Limited and/or its affiliates # # SPDX-License-Identifier: Apache-2.0 # @@ -211,6 +211,17 @@ def set_ifm_ofm_op_shapes(op, arch, nng): return op +def check_splitsliceread_to_consumer_shape(op, cons_op): + assert op.type == Op.SplitSliceRead + # SplitSliceRead ofm shape must match consumer ifm shape + if cons_op.ifm == op.ofm: + return cons_op.ifm_shapes[0] == op.ofm_shapes[0] + elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == op.ofm: + return cons_op.ifm_shapes[1] == op.ofm_shapes[0] + + return False + + def move_splitsliceread_to_consumer(op, cons_op): assert op.type == Op.SplitSliceRead diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 1e53e378..687e5d4f 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -34,6 +34,7 @@ from .errors import UnsupportedFeatureError from .ethos_u55_regs.ethos_u55_regs import resampling_mode from .graph_optimiser_util import bypass_memory_only_ops from .graph_optimiser_util import calc_explicit_padding +from .graph_optimiser_util import check_splitsliceread_to_consumer_shape from .graph_optimiser_util import convert_depthwise_to_conv from .graph_optimiser_util import create_avg_pool_for_concat from .graph_optimiser_util import memory_only_ops @@ -199,6 +200,7 @@ def remove_SplitSliceRead(op, arch): and consumer.run_on_npu and consumer.type not in memory_only_ops and consumer.original_type != Op.Transpose + and check_splitsliceread_to_consumer_shape(op, consumer) and not ( consumer.type.is_binary_elementwise_op() and Shape4D.from_list(consumer.ofm.shape) != op.ofm_shapes[0] ) -- cgit v1.2.1