aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py93
1 files changed, 92 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 85fb8ba..cc947bc 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -191,8 +191,12 @@ def remove_SplitSliceRead(op, arch):
if op.type == Op.SplitSliceRead:
# Check if it is possible to put the SplitSliceRead on the tensor consumer(s),
# or if an avgpool need to be inserted
+ # Not possible to do if consumer is a Transpose op since ifm shape has been reshaped and can not be changed
if op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape) and all(
- consumer is not None and consumer.run_on_npu and consumer.type not in memory_only_ops
+ consumer is not None
+ and consumer.run_on_npu
+ and consumer.type not in memory_only_ops
+ and consumer.original_type != Op.Transpose
for consumer in op.ofm.consumer_list
):
# SplitSliceRead can be performed by tensor consumer(s)
@@ -2535,6 +2539,92 @@ def fixup_dilation_gt2(op: Operation, arch, nng) -> Operation:
return op
+def fixup_transpose(op, arch, nng):
+ """
+ Convert Transpose to AvgPool where the strides for height and width is swapped on the OFM
+ in order to achieve the transpose. It is only possible to swap height and width on the op.
+
+ Shape (2,3) transposed to Shape (3,2)
+ |0|1|2| ifm_stride_w = 1 |0|3| ofm_stride_w = 1
+ |4|5|6| ifm_stride_h = 3 |1|4| ofm_stride_h = 2
+ |2|5|
+
+ To achieve the above with the AvgPool, the ofm_shape must be set equal to the ifm_shape.
+ The reason is that AvgPool uses the ofm shape when looping over the memory. So if the
+ ofm shape is not equal to the ifm shape the full ifm will not be read.
+ When looping over the values the following formula is used:
+
+ IFM [h_pos, w_pos] = h_pos * ifm_stride_h + w_pos * ifm_stride_w
+ OFM [h_pos, w_pos] = h_pos * ofm_stride_w + w_pos * ofm_stride_h (stride has been swapped)
+
+ Below code changes op to an AvgPool and sets the correct shapes. The actual stride swap
+ is done when creating the ofm featuremap. As seen there are several corner cases
+ when it is possible to transpose the depth channel.
+ """
+ if op.type == Op.Transpose:
+ op.name = f"{op.name}_avgpool"
+ op.type = Op.AvgPool
+ op.attrs["padding"] = Padding.VALID
+ op.attrs["stride_w"] = 1
+ op.attrs["stride_h"] = 1
+ op.attrs["filter_width"] = 1
+ op.attrs["filter_height"] = 1
+ op.attrs["strides"] = [1, 1, 1, 1]
+ op.attrs["ksize"] = [1, 1, 1, 1]
+ # Swapping strides only works in linear format (ofm)
+ op.ofm.force_linear_format = True
+
+ # Convert IFM to correct 4D shape
+ perm = op.inputs[1]
+ ifm_shape = op.ifm.shape
+
+ # IFM rank 2 case
+ if len(ifm_shape) == 2:
+ # IFM shape: WxC -> 1xWxCx1
+ op.ifm_shapes[0] = Shape4D([1, ifm_shape[0], ifm_shape[1], 1])
+
+ # IFM rank 3 cases
+ elif len(ifm_shape) == 3:
+ # Check if HxWxC -> WxHxC
+ if perm.values[0] == 1 and perm.values[1] == 0:
+ # IFM shape: HxWxC -> 1xHxWxC
+ op.ifm_shapes[0] = Shape4D([1, ifm_shape[0], ifm_shape[1], ifm_shape[2]])
+
+ # Check if 1xWxC -> 1xCxW
+ elif ifm_shape[0] == 1 and perm.values[1] == 2 and perm.values[2] == 1:
+ # IFM shape: 1xWxC -> 1xWxCx1
+ op.ifm_shapes[0] = Shape4D([1, ifm_shape[1], ifm_shape[2], 1])
+
+ # Check if Hx1xC -> Cx1xH
+ elif ifm_shape[1] == 1 and perm.values[0] == 2 and perm.values[2] == 0:
+ # IFM shape: Hx1xC -> 1xHxCx1
+ op.ifm_shapes[0] = Shape4D([1, ifm_shape[0], ifm_shape[2], 1])
+
+ # IFM rank 4 cases
+ elif len(ifm_shape) == 4:
+ # Check if 1xHxWxC -> 1xWxHxC
+ if perm.values[1] == 2 and perm.values[2] == 1:
+ # IFM shape is correct
+ pass
+
+ # Check if 1x1xWxC -> 1x1xCxW
+ elif ifm_shape[1] == 1 and perm.values[2] == 3 and perm.values[3] == 2:
+ # IFM shape: 1x1xWxC -> 1xWxCx1
+ op.ifm_shapes[0] = Shape4D([1, ifm_shape[2], ifm_shape[3], 1])
+
+ # Check if 1xHx1xC -> 1xCx1xH
+ elif ifm_shape[2] == 1 and perm.values[1] == 3 and perm.values[3] == 1:
+ # IFM shape: 1xHx1xC -> 1xHxCx1
+ op.ifm_shapes[0] = Shape4D([1, ifm_shape[1], ifm_shape[3], 1])
+
+ # OFM shape must use IFM shape
+ op.ofm_shapes[0] = op.ifm_shapes[0]
+
+ DebugDatabase.add_optimised(op, op)
+
+ return op
+
+
def fixup_reshape(op, arch, nng):
def _get_explicit_shape(implicit_shape, total_size):
# the explicit shape is a copy of the implicit shape but with the special -1 (remaining size) value converted to
@@ -2824,6 +2914,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
convert_quantize,
replace_pad_by_hw_pad,
fixup_dilation_gt2,
+ fixup_transpose,
]
for idx, sg in enumerate(nng.subgraphs):