diff options
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 93 |
1 files changed, 92 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 85fb8bad..cc947bcf 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -191,8 +191,12 @@ def remove_SplitSliceRead(op, arch): if op.type == Op.SplitSliceRead: # Check if it is possible to put the SplitSliceRead on the tensor consumer(s), # or if an avgpool need to be inserted + # Not possible to do if consumer is a Transpose op since ifm shape has been reshaped and can not be changed if op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape) and all( - consumer is not None and consumer.run_on_npu and consumer.type not in memory_only_ops + consumer is not None + and consumer.run_on_npu + and consumer.type not in memory_only_ops + and consumer.original_type != Op.Transpose for consumer in op.ofm.consumer_list ): # SplitSliceRead can be performed by tensor consumer(s) @@ -2535,6 +2539,92 @@ def fixup_dilation_gt2(op: Operation, arch, nng) -> Operation: return op +def fixup_transpose(op, arch, nng): + """ + Convert Transpose to AvgPool where the strides for height and width is swapped on the OFM + in order to achieve the transpose. It is only possible to swap height and width on the op. + + Shape (2,3) transposed to Shape (3,2) + |0|1|2| ifm_stride_w = 1 |0|3| ofm_stride_w = 1 + |4|5|6| ifm_stride_h = 3 |1|4| ofm_stride_h = 2 + |2|5| + + To achieve the above with the AvgPool, the ofm_shape must be set equal to the ifm_shape. + The reason is that AvgPool uses the ofm shape when looping over the memory. So if the + ofm shape is not equal to the ifm shape the full ifm will not be read. + When looping over the values the following formula is used: + + IFM [h_pos, w_pos] = h_pos * ifm_stride_h + w_pos * ifm_stride_w + OFM [h_pos, w_pos] = h_pos * ofm_stride_w + w_pos * ofm_stride_h (stride has been swapped) + + Below code changes op to an AvgPool and sets the correct shapes. The actual stride swap + is done when creating the ofm featuremap. As seen there are several corner cases + when it is possible to transpose the depth channel. + """ + if op.type == Op.Transpose: + op.name = f"{op.name}_avgpool" + op.type = Op.AvgPool + op.attrs["padding"] = Padding.VALID + op.attrs["stride_w"] = 1 + op.attrs["stride_h"] = 1 + op.attrs["filter_width"] = 1 + op.attrs["filter_height"] = 1 + op.attrs["strides"] = [1, 1, 1, 1] + op.attrs["ksize"] = [1, 1, 1, 1] + # Swapping strides only works in linear format (ofm) + op.ofm.force_linear_format = True + + # Convert IFM to correct 4D shape + perm = op.inputs[1] + ifm_shape = op.ifm.shape + + # IFM rank 2 case + if len(ifm_shape) == 2: + # IFM shape: WxC -> 1xWxCx1 + op.ifm_shapes[0] = Shape4D([1, ifm_shape[0], ifm_shape[1], 1]) + + # IFM rank 3 cases + elif len(ifm_shape) == 3: + # Check if HxWxC -> WxHxC + if perm.values[0] == 1 and perm.values[1] == 0: + # IFM shape: HxWxC -> 1xHxWxC + op.ifm_shapes[0] = Shape4D([1, ifm_shape[0], ifm_shape[1], ifm_shape[2]]) + + # Check if 1xWxC -> 1xCxW + elif ifm_shape[0] == 1 and perm.values[1] == 2 and perm.values[2] == 1: + # IFM shape: 1xWxC -> 1xWxCx1 + op.ifm_shapes[0] = Shape4D([1, ifm_shape[1], ifm_shape[2], 1]) + + # Check if Hx1xC -> Cx1xH + elif ifm_shape[1] == 1 and perm.values[0] == 2 and perm.values[2] == 0: + # IFM shape: Hx1xC -> 1xHxCx1 + op.ifm_shapes[0] = Shape4D([1, ifm_shape[0], ifm_shape[2], 1]) + + # IFM rank 4 cases + elif len(ifm_shape) == 4: + # Check if 1xHxWxC -> 1xWxHxC + if perm.values[1] == 2 and perm.values[2] == 1: + # IFM shape is correct + pass + + # Check if 1x1xWxC -> 1x1xCxW + elif ifm_shape[1] == 1 and perm.values[2] == 3 and perm.values[3] == 2: + # IFM shape: 1x1xWxC -> 1xWxCx1 + op.ifm_shapes[0] = Shape4D([1, ifm_shape[2], ifm_shape[3], 1]) + + # Check if 1xHx1xC -> 1xCx1xH + elif ifm_shape[2] == 1 and perm.values[1] == 3 and perm.values[3] == 1: + # IFM shape: 1xHx1xC -> 1xHxCx1 + op.ifm_shapes[0] = Shape4D([1, ifm_shape[1], ifm_shape[3], 1]) + + # OFM shape must use IFM shape + op.ofm_shapes[0] = op.ifm_shapes[0] + + DebugDatabase.add_optimised(op, op) + + return op + + def fixup_reshape(op, arch, nng): def _get_explicit_shape(implicit_shape, total_size): # the explicit shape is a copy of the implicit shape but with the special -1 (remaining size) value converted to @@ -2824,6 +2914,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights): convert_quantize, replace_pad_by_hw_pad, fixup_dilation_gt2, + fixup_transpose, ] for idx, sg in enumerate(nng.subgraphs): |