aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py141
1 files changed, 140 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 6b454e3d..27513d3d 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -464,6 +464,143 @@ def convert_resize_to_upscale_and_average_pool(op):
return op
+def convert_resizebilinear_to_depthwise_convolutions(op, half_pixel_centers=True):  # Rewrites op in place as a 1x1 AvgPool feeding four 2x2 depthwise convolutions that all write op.ofm; returns op
+ def _compute_interpolation_values(index, input_size, output_size):  # Returns (source coordinate for output position index, its floor clamped to >= 0)
+ scale = input_size / output_size
+ scaled_value = (index + 0.5 * half_pixel_centers) * scale - 0.5 * half_pixel_centers  # the bool acts as 0/1, switching the half-pixel shift on or off
+ lower_bound = max(np.floor(scaled_value), 0)
+
+ return scaled_value, lower_bound
+
+ def _compute_kernels(input_height, input_width, output_height, output_width):  # Returns four 2x2 kernels in the order (y, x) = (1, 1), (1, 2), (2, 1), (2, 2)
+ kernels = []
+ for y in (1, 2):
+ for x in (1, 2):
+ sv_h, lb_h = _compute_interpolation_values(y, input_height, output_height)
+ sv_w, lb_w = _compute_interpolation_values(x, input_width, output_width)
+
+ # Interpolation values calculated for (x, y) = ([1, 2], [1, 2]) will always generalize to the whole
+ # input for upscale = 2 and input sizes >= 2x2 and be in the correct order for going left-to-right,
+ # top-to-bottom - same as the depthwise convolution strides across each tile
+ kernel = np.zeros((2, 2))  # each entry is a product of the fractional offsets (sv - lb) and/or their complements
+ kernel[1, 1] = (1 - (sv_h - lb_h)) * (1 - (sv_w - lb_w))
+ kernel[0, 1] = (sv_h - lb_h) * (1 - (sv_w - lb_w))
+ kernel[1, 0] = (1 - (sv_h - lb_h)) * (sv_w - lb_w)
+ kernel[0, 0] = (sv_h - lb_h) * (sv_w - lb_w)
+ kernel *= 16  # pairs with the 1/16 weight quantization scale set in _build_convolutions
+ kernels.append(kernel)
+
+ return kernels
+
+ def _build_convolutions(op, kernels):  # Retypes op as the AvgPool, attaches one depthwise convolution per kernel to op.ofm, returns op
+ dw_op_attrs = {
+ "padding": Padding.TILE,
+ "stride_h": 1,
+ "stride_w": 1,
+ "strides": (1, 1, 1, 1),
+ "depth_multiplier": 1,
+ "channel_multiplier": 1,
+ "dilation_h_factor": 1,
+ "dilation_w_factor": 1,
+ "dilation": (1, 1, 1, 1),
+ }
+ ifm = op.ifm
+ ofm = op.ofm
+ ofm.ops = []  # cleared here; the depthwise convolutions are appended in the loop below
+ elem_size = 2 if ofm.dtype == DataType.int16 else 1  # presumably bytes per OFM element; treats every non-int16 dtype as size 1 - TODO confirm
+
+ n, h, w, c = ifm.shape
+ _, _, ow, _ = ofm.shape
+
+ intermediate_tens = Tensor(ifm.shape, ifm.dtype, "intermediate_tens")  # same shape and dtype as the IFM
+ intermediate_tens.quantization = op.outputs[0].quantization.clone()  # cloned before next_after is set on ofm.quantization further down
+ avgpool_op = op  # alias, not a copy: the original op object itself is renamed and retyped
+ avgpool_op.name = "rb_init_avgpool"
+ avgpool_op.type = Op.AvgPool
+ avgpool_op.attrs["padding"] = Padding.VALID
+ avgpool_op.attrs["stride_w"] = 1
+ avgpool_op.attrs["stride_h"] = 1
+ avgpool_op.attrs["filter_width"] = 1
+ avgpool_op.attrs["filter_height"] = 1
+ avgpool_op.attrs["strides"] = [1, 1, 1, 1]
+ avgpool_op.attrs["ksize"] = [1, 1, 1, 1]
+
+ avgpool_op.add_input_tensor(ifm)  # NOTE(review): op.ifm was already set on entry - confirm this does not leave a duplicate input
+ avgpool_op.set_output_tensor(intermediate_tens)
+ avgpool_op.set_ifm_ofm_shapes()
+
+ dw_conv = Operation(Op.DepthwiseConv2DBias, "depthwise_conv")  # configured per iteration in the loop below, then replaced by a clone
+ dw_conv._original_type = Op.ResizeBilinear
+ dw_conv.write_shape = Shape4D(n, h, w, c)  # IFM dimensions, not OFM dimensions
+ dw_conv.write_offset = Shape4D(0, 0, 0, 0)
+
+ # Set the output rounding mode. Resize bilinear requires rounding away from zero. Therefore, we need to
+ # adjust the accumulated value by a "small" amount before applying natural rounding. The "small" amount
+ # should be big enough to cause a x.5 to be rounded correctly but small enough not to cause smaller
+ # values to be incorrectly rounded
+ ofm.quantization.next_after = True
+ dw_conv.rounding_mode = NpuRoundingMode.NATURAL
+
+ # Double height and width stride to write the output of each of the four depthwise convolutions below
+ # interleaved with each other when combined with OFM tile base offsets.
+ dw_conv.ofm_stride_multiplier = [1, 2, 2] # C/H/W
+
+ # Choose tile padding direction - pad by 1 with edge values in two directions.
+ # For example, TL (top left) will pad top and left in H/W-plane in all channels.
+ directions = [[1, 1, 0, 0], [1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 1, 1]] # TL, TR, BL, BR
+ for i in (0, 1):
+ for j in (0, 1):
+ index = i * 2 + j  # 0..3; selects both the kernel and the padding direction
+ dw_conv.name = f"depthwise_conv_{index}"
+ dw_op_attrs["explicit_padding"] = directions[index]
+ dw_conv.attrs.update(dw_op_attrs)
+
+ # This will offset the start of the write by modifying the Tile 0 base address
+ dw_conv.tile_base_offsets_ofm[0] = (i * ow + j) * c * elem_size  # offset of output position (h=i, w=j), assuming a linear NHWC layout - TODO confirm
+
+ ofm.ops.append(dw_conv)
+ dw_conv.outputs = [ofm]
+
+ kernel = kernels[index]
+ shape = [2, 2, 1, c]
+ kernel = np.dstack([kernel] * c)  # stack c copies of the 2x2 kernel along a third axis -> (2, 2, c)
+
+ quant = QuantizationParameters()
+ quant.zero_point = 0
+ quant.scale_f32 = 1.0 / 16  # inverse of the x16 applied to the kernels in _compute_kernels
+
+ dw_conv.inputs = []  # start from an empty input list on every iteration
+ dw_conv.add_input_tensor(intermediate_tens)
+ dw_conv.add_input_tensor(
+ create_const_tensor(
+ "weights",
+ shape,
+ intermediate_tens.dtype,
+ np.array(kernel).reshape(shape),
+ value_dtype=np.int8,
+ quantization=quant,
+ ),
+ )
+
+ # Set up the bias tensor by assigning None and then calling the fix-up function to create a suitable tensor.
+ # The bias tensor needs to be appended because resize ops only have 2 inputs.
+ assert len(dw_conv.inputs) == 2
+ dw_conv.inputs.append(None)
+ fixup_bias_tensors(dw_conv, None, None)
+
+ dw_conv.set_ifm_ofm_shapes()
+ dw_conv = dw_conv.clone(f"_{index}")  # presumably so the op already appended to ofm.ops is not modified by the next iteration; the final clone is never used
+ return op
+
+ _, input_height, input_width, _ = op.ifm.shape
+ _, output_height, output_width, _ = op.ofm.shape
+
+ kernels = _compute_kernels(input_height, input_width, output_height, output_width)
+ op = _build_convolutions(op, kernels)  # returns the same op object, now retyped as AvgPool
+
+ return op
+
+
def fixup_resize(op, arch, nng):
if op.type.is_resize_op() and op.run_on_npu:
if op.ifm_shapes[0] == op.ofm_shapes[0]:
@@ -472,6 +609,8 @@ def fixup_resize(op, arch, nng):
op.type = Op.Identity
elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
convert_resize_1x1_to_add(op)
+ elif op.type == Op.ResizeBilinear and op.attrs.get("half_pixel_centers", False):
+ convert_resizebilinear_to_depthwise_convolutions(op)
else:
convert_resize_to_upscale_and_average_pool(op)