aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/graph_optimiser_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/graph_optimiser_util.py')
-rw-r--r--ethosu/vela/graph_optimiser_util.py13
1 files changed, 9 insertions, 4 deletions
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index 44a08f55..46762e4d 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -213,13 +213,18 @@ def set_ifm_ofm_op_shapes(op, arch, nng):
def check_splitsliceread_to_consumer_shape(op, cons_op):
assert op.type == Op.SplitSliceRead
- # SplitSliceRead ofm shape must match consumer ifm shape
+ # SplitSliceRead ofm shape must fit within the consumer ifm shape
if cons_op.ifm == op.ofm:
- return cons_op.ifm_shapes[0] == op.ofm_shapes[0]
+ cons_shape = cons_op.ifm_shapes[0].as_list()
+ read_shape = op.ofm_shapes[0].as_list()
elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == op.ofm:
- return cons_op.ifm_shapes[1] == op.ofm_shapes[0]
+ cons_shape = cons_op.ifm_shapes[1].as_list()
+ read_shape = op.ofm_shapes[0].as_list()
+ else:
+ return False
- return False
+ # All read shape values <= consumer shape values
+ return all(read_shape[idx] <= x for idx, x in enumerate(cons_shape))
def move_splitsliceread_to_consumer(op, cons_op):