about summary refs log tree commit diff
path: root/ethosu/vela/tflite_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r-- ethosu/vela/tflite_graph_optimiser.py | 47
1 file changed, 34 insertions, 13 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 687e5d4f..ccbb1f28 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -141,7 +141,7 @@ def rewrite_split_ops(tens, arch, nng):
if not split_op.run_on_npu:
return tens
- inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()
+ inp, outputs, axis, offset_start, offset_end, strides_tens = split_op.get_split_inputs_axis()
tens.ops = []
new_op = Operation(Op.SplitSliceRead, split_op.name)
@@ -150,8 +150,10 @@ def rewrite_split_ops(tens, arch, nng):
if None in (offset_end, offset_start):
read_shape = None
else:
- # the read shape is relative to each start offset
- read_shape = Shape4D([oe - os for oe, os in zip(offset_end, offset_start)])
+ # The read shape is relative to each start offset
+ # Limit read shape to the size of the IFM - offset is not necessarily limited
+ ifm_dims = split_op.ifm_shapes[0].as_list()
+ read_shape = Shape4D([min(oe, ifm_dim) - os for oe, os, ifm_dim in zip(offset_end, offset_start, ifm_dims)])
# For Split the offset cannot be extracted from the tensor so it has to
# be calculated from the index of the output tensor
@@ -182,6 +184,9 @@ def rewrite_split_ops(tens, arch, nng):
new_op.set_output_tensor(tens)
new_op.ifm_shapes.append(Shape4D(inp.shape))
new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
+ # Set stride multiplier in H/W if a stride tensor is provided
+ s_h, s_w = (strides_tens.values[-3], strides_tens.values[-2]) if strides_tens else (1, 1)
+ new_op.ifm_stride_multiplier[0] = [1, s_h, s_w] # C/H/W
DebugDatabase.add_optimised(split_op, new_op)
return tens
@@ -193,18 +198,24 @@ def remove_SplitSliceRead(op, arch):
# Check if it is possible to put the SplitSliceRead on the tensor consumer(s),
# or if an avgpool need to be inserted
# Not possible to move:
+ # - if ifm stride multiplier is larger than one in any dimension
# - if consumer is a Transpose op since ifm shape has been reshaped and can not be changed
# - if consumer is elementwise and ifm needs to be broadcasted
- if op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape) and all(
- consumer is not None
- and consumer.run_on_npu
- and consumer.type not in memory_only_ops
- and consumer.original_type != Op.Transpose
- and check_splitsliceread_to_consumer_shape(op, consumer)
- and not (
- consumer.type.is_binary_elementwise_op() and Shape4D.from_list(consumer.ofm.shape) != op.ofm_shapes[0]
+ if (
+ op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
+ and all(s_mul == 1 for s_mul in op.ifm_stride_multiplier[0])
+ and all(
+ consumer is not None
+ and consumer.run_on_npu
+ and consumer.type not in memory_only_ops
+ and consumer.original_type != Op.Transpose
+ and check_splitsliceread_to_consumer_shape(op, consumer)
+ and not (
+ consumer.type.is_binary_elementwise_op()
+ and Shape4D.from_list(consumer.ofm.shape) != op.ofm_shapes[0]
+ )
+ for consumer in op.ofm.consumer_list
)
- for consumer in op.ofm.consumer_list
):
# SplitSliceRead can be performed by tensor consumer(s)
for cons_op in list(op.ofm.consumer_list):
@@ -219,6 +230,9 @@ def remove_SplitSliceRead(op, arch):
avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
avgpool_op.read_offsets[0] = op.read_offsets[0]
avgpool_op.read_shapes[0] = op.read_shapes[0]
+ if any(s_mul != 1 for s_mul in op.ifm_stride_multiplier[0]):
+ avgpool_op.ifm_stride_multiplier[0] = op.ifm_stride_multiplier[0].copy()
+ avgpool_op.ifm.force_linear_format = True
op.ifm.consumer_list.remove(op)
DebugDatabase.add_optimised(op, avgpool_op)
@@ -827,7 +841,7 @@ def convert_batched_fc_shape(op: Operation, arch, nng) -> Operation:
if op.type == Op.FullyConnected:
# Check if the first dimension indicates batching
if op.ifm_shapes[0].batch > 1:
- batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
+ batching_split = {4: (2, 2), 6: (2, 3), 8: (2, 4), 9: (3, 3), 12: (3, 4), 16: (4, 4)}
n = op.ifm_shapes[0].batch
h, w = batching_split.get(n, (1, n))
op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])
@@ -840,6 +854,13 @@ def convert_batched_fc_shape(op: Operation, arch, nng) -> Operation:
n = op.ofm_shapes[0].batch
h, w = batching_split.get(n, (1, n))
op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
+ if h == 1 and w > 4:
+ # If batch can not be found in the split set the weights are going to be
+ # read from memory several times. Convert op to conv2d since this
+ # enables weight buffering.
+ op.type = Op.Conv2DBias
+ op.attrs["padding"] = Padding.SAME
+ DebugDatabase.add_optimised(op, op)
return op