From 190b63a6ae6908625dffab203a8137c27aaec5fd Mon Sep 17 00:00:00 2001 From: Johan Alfven Date: Thu, 4 Apr 2024 13:26:18 +0200 Subject: MLBEDSW-8886: Regression: Output diff on LSTM - Fix regression caused by too strict constraints on SplitSpliceRead causing output diff for LSTM. - As long as the SplitSpliceRead shape fits within the consumer ifm shape it is ok to move the read. Change-Id: Ia6f508f99638c3aedccc7fd9f31405527bb64f87 Signed-off-by: Johan Alfven --- ethosu/vela/graph_optimiser_util.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) (limited to 'ethosu') diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py index 44a08f55..46762e4d 100644 --- a/ethosu/vela/graph_optimiser_util.py +++ b/ethosu/vela/graph_optimiser_util.py @@ -213,13 +213,18 @@ def set_ifm_ofm_op_shapes(op, arch, nng): def check_splitsliceread_to_consumer_shape(op, cons_op): assert op.type == Op.SplitSliceRead - # SplitSliceRead ofm shape must match consumer ifm shape + # SplitSliceRead ofm shape must fit within the consumer ifm shape if cons_op.ifm == op.ofm: - return cons_op.ifm_shapes[0] == op.ofm_shapes[0] + cons_shape = cons_op.ifm_shapes[0].as_list() + read_shape = op.ofm_shapes[0].as_list() elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == op.ofm: - return cons_op.ifm_shapes[1] == op.ofm_shapes[0] + cons_shape = cons_op.ifm_shapes[1].as_list() + read_shape = op.ofm_shapes[0].as_list() + else: + return False - return False + # All read shape values <= consumer shape values + return all(read_shape[idx] <= x for idx, x in enumerate(cons_shape)) def move_splitsliceread_to_consumer(op, cons_op): -- cgit v1.2.1