diff options
-rw-r--r-- | ethosu/vela/graph_optimiser_util.py | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py index 44a08f55..46762e4d 100644 --- a/ethosu/vela/graph_optimiser_util.py +++ b/ethosu/vela/graph_optimiser_util.py @@ -213,13 +213,18 @@ def set_ifm_ofm_op_shapes(op, arch, nng): def check_splitsliceread_to_consumer_shape(op, cons_op): assert op.type == Op.SplitSliceRead - # SplitSliceRead ofm shape must match consumer ifm shape + # SplitSliceRead ofm shape must fit within the consumer ifm shape if cons_op.ifm == op.ofm: - return cons_op.ifm_shapes[0] == op.ofm_shapes[0] + cons_shape = cons_op.ifm_shapes[0].as_list() + read_shape = op.ofm_shapes[0].as_list() elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == op.ofm: - return cons_op.ifm_shapes[1] == op.ofm_shapes[0] + cons_shape = cons_op.ifm_shapes[1].as_list() + read_shape = op.ofm_shapes[0].as_list() + else: + return False - return False + # All read shape values <= consumer shape values + return all(read_shape[idx] <= x for idx, x in enumerate(cons_shape)) def move_splitsliceread_to_consumer(op, cons_op): |