diff options
author | Johan Alfven <johan.alfven@arm.com> | 2024-04-02 20:56:09 +0200 |
---|---|---|
committer | Johan Alfven <johan.alfven@arm.com> | 2024-04-03 19:44:27 +0200 |
commit | 7647b0fe74e68792963c5602a083c215ee369182 (patch) | |
tree | 7cc4073de4677150161041bd8ef72ae558926a48 /ethosu/vela/tflite_graph_optimiser.py | |
parent | 55d90dd1f51e95e3b066ab2976b595107cc485c9 (diff) | |
download | ethos-u-vela-7647b0fe74e68792963c5602a083c215ee369182.tar.gz |
MLBEDSW-8875: MLCE: Update criteria when to move SplitSpliceRead to consumer
- When possible, a read slice from a split or stride is moved to
the following op. The problem in this case was that the following
op was a Maxpool op (from Softmax). The Maxpool op is using a
different input shape compared to the original Softmax op, and
this input shape was then changed when the read slice was applied
to the Maxpool op.
- The result is a faulty Maxpool op with an output diff.
- The fix is to prevent moving the slice read when the consumer
input shape differs from the Split/Stride ofm shape
Change-Id: I649d89c38645fa51c20c3602954e2b8af9372076
Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 2 |
1 files changed, 2 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 1e53e378..687e5d4f 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -34,6 +34,7 @@ from .errors import UnsupportedFeatureError from .ethos_u55_regs.ethos_u55_regs import resampling_mode from .graph_optimiser_util import bypass_memory_only_ops from .graph_optimiser_util import calc_explicit_padding +from .graph_optimiser_util import check_splitsliceread_to_consumer_shape from .graph_optimiser_util import convert_depthwise_to_conv from .graph_optimiser_util import create_avg_pool_for_concat from .graph_optimiser_util import memory_only_ops @@ -199,6 +200,7 @@ def remove_SplitSliceRead(op, arch): and consumer.run_on_npu and consumer.type not in memory_only_ops and consumer.original_type != Op.Transpose + and check_splitsliceread_to_consumer_shape(op, consumer) and not ( consumer.type.is_binary_elementwise_op() and Shape4D.from_list(consumer.ofm.shape) != op.ofm_shapes[0] ) |