diff options
author | Rickard Bolin <rickard.bolin@arm.com> | 2024-01-31 12:05:11 +0000 |
---|---|---|
committer | Rickard Bolin <rickard.bolin@arm.com> | 2024-05-16 14:08:21 +0000 |
commit | be78a053a57da7bdae240690c933824c0861f55b (patch) | |
tree | e6eabce902b42fcbdc7ef4cf7cfbc8136e11246d /ethosu/vela/tflite_graph_optimiser.py | |
parent | 891468561ecfc61d27adcdc92b41ec216eaa1b08 (diff) | |
download | ethos-u-vela-main.tar.gz |
MLBEDSW-8561: Striding support in H/W for StridedSliceHEAD3.12.0.rc1main
Change-Id: Ie6f39d9c4125f7c16d27621de47cd76143c2e636
Signed-off-by: Rickard Bolin <rickard.bolin@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 38 |
1 files changed, 26 insertions, 12 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 3af8588c..ccbb1f28 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -141,7 +141,7 @@ def rewrite_split_ops(tens, arch, nng): if not split_op.run_on_npu: return tens - inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis() + inp, outputs, axis, offset_start, offset_end, strides_tens = split_op.get_split_inputs_axis() tens.ops = [] new_op = Operation(Op.SplitSliceRead, split_op.name) @@ -150,8 +150,10 @@ def rewrite_split_ops(tens, arch, nng): if None in (offset_end, offset_start): read_shape = None else: - # the read shape is relative to each start offset - read_shape = Shape4D([oe - os for oe, os in zip(offset_end, offset_start)]) + # The read shape is relative to each start offset + # Limit read shape to the size of the IFM - offset is not necessarily limited + ifm_dims = split_op.ifm_shapes[0].as_list() + read_shape = Shape4D([min(oe, ifm_dim) - os for oe, os, ifm_dim in zip(offset_end, offset_start, ifm_dims)]) # For Split the offset cannot be extracted from the tensor so it has to # be calculated from the index of the output tensor @@ -182,6 +184,9 @@ def rewrite_split_ops(tens, arch, nng): new_op.set_output_tensor(tens) new_op.ifm_shapes.append(Shape4D(inp.shape)) new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx]) + # Set stride multiplier in H/W if a stride tensor is provided + s_h, s_w = (strides_tens.values[-3], strides_tens.values[-2]) if strides_tens else (1, 1) + new_op.ifm_stride_multiplier[0] = [1, s_h, s_w] # C/H/W DebugDatabase.add_optimised(split_op, new_op) return tens @@ -193,18 +198,24 @@ def remove_SplitSliceRead(op, arch): # Check if it is possible to put the SplitSliceRead on the tensor consumer(s), # or if an avgpool need to be inserted # Not possible to move: + # - if ifm stride multiplier is larger than one in any dimension # - if consumer is a Transpose op since ifm shape has been reshaped and can not be changed # - if consumer is elementwise and ifm needs to be broadcasted - if op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape) and all( - consumer is not None - and consumer.run_on_npu - and consumer.type not in memory_only_ops - and consumer.original_type != Op.Transpose - and check_splitsliceread_to_consumer_shape(op, consumer) - and not ( - consumer.type.is_binary_elementwise_op() and Shape4D.from_list(consumer.ofm.shape) != op.ofm_shapes[0] + if ( + op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape) + and all(s_mul == 1 for s_mul in op.ifm_stride_multiplier[0]) + and all( + consumer is not None + and consumer.run_on_npu + and consumer.type not in memory_only_ops + and consumer.original_type != Op.Transpose + and check_splitsliceread_to_consumer_shape(op, consumer) + and not ( + consumer.type.is_binary_elementwise_op() + and Shape4D.from_list(consumer.ofm.shape) != op.ofm_shapes[0] + ) + for consumer in op.ofm.consumer_list ) - for consumer in op.ofm.consumer_list ): # SplitSliceRead can be performed by tensor consumer(s) for cons_op in list(op.ofm.consumer_list): @@ -219,6 +230,9 @@ def remove_SplitSliceRead(op, arch): avgpool_op.ofm_shapes.append(op.ofm_shapes[0]) avgpool_op.read_offsets[0] = op.read_offsets[0] avgpool_op.read_shapes[0] = op.read_shapes[0] + if any(s_mul != 1 for s_mul in op.ifm_stride_multiplier[0]): + avgpool_op.ifm_stride_multiplier[0] = op.ifm_stride_multiplier[0].copy() + avgpool_op.ifm.force_linear_format = True op.ifm.consumer_list.remove(op) DebugDatabase.add_optimised(op, avgpool_op) |