diff options
Diffstat (limited to 'ethosu/vela/graph_optimiser_util.py')
-rw-r--r-- | ethosu/vela/graph_optimiser_util.py | 17 |
1 files changed, 10 insertions, 7 deletions
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py index 73fbf6c7..3e15f126 100644 --- a/ethosu/vela/graph_optimiser_util.py +++ b/ethosu/vela/graph_optimiser_util.py @@ -49,15 +49,18 @@ def _avoid_nhcwb16_for_concat(tens): def _avoid_nhcwb16_for_split(tens): # If read offset is not a multiple of 16 in the C-dimension, NHCWB16 need to be avoided in the input + + # Return True if NHCWB16 needs to be avoided + def offset_not_aligned(read_offset): + return read_offset is not None and (read_offset.depth % 16) != 0 + for cons_op in tens.consumer_list: if cons_op.ifm == tens: - read_offset = cons_op.read_offsets[0] - elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == tens: - read_offset = cons_op.read_offsets[1] - else: - assert False - if read_offset is not None and (read_offset[-1] % 16) != 0: - return True + if offset_not_aligned(cons_op.read_offsets[0]): + return True + if cons_op.ifm2 is not None and cons_op.ifm2 == tens: + if offset_not_aligned(cons_op.read_offsets[1]): + return True return False |