aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py172
1 files changed, 125 insertions, 47 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index d2899c4c..ed8fa1e3 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -279,10 +279,9 @@ def fixup_conv2d_backprop(op, arch, nng):
# Convert the op to an elementwise add
-def convert_resizebilinear_1x1_to_add(op):
- op.type = Op.Add
+def convert_resize_1x1_to_add(op):
+ op.type = Op.Add # original_type will stay as Op.ResizeBilinear or Op.ResizeNearestNeighbor
op.name = op.name + "_add"
- op.attrs["resizebilinear"] = True
# Create an input tensor filled with zeros
shape = op.ofm_shapes[0].as_list()
tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
@@ -301,12 +300,103 @@ def convert_resizebilinear_1x1_to_add(op):
return op
-# Convert ResizeBilinear to a number of 2x2 nearest neighbor upscaling and one avgpool op with kernel size dependent
-# on the upscaling factor. Avgpool kernel limit of 8x8 when padding is applied limits upscaling to 8x8.
-def convert_resizebilinear_to_upscale_and_average_pool(op):
+# Convert ResizeNearestNeighbor with align corners to a depthwise convolution. The IFM will already have been upscaled
+# apart from the final x2 scaling which will be done as part of this operation. The kernel contains a single coefficient
+# to select the appropriate nearest neighbor value
+def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
+ ifm = op.ifm
+ ofm = op.ofm
+ output_depth = ofm.shape[-1]
+ dw_op_attrs = {
+ "padding": Padding.VALID,
+ "stride_h": 1,
+ "stride_w": 1,
+ "strides": (1, 1, 1, 1),
+ "depth_multiplier": 1,
+ "channel_multiplier": 1,
+ "dilation_h_factor": 1,
+ "dilation_w_factor": 1,
+ "dilation": (1, 1, 1, 1),
+ }
+
+ # change resize nearest neighbor op to depthwise
+ op.type = Op.DepthwiseConv2DBias
+ op.attrs.update(dw_op_attrs)
+ op.set_input_tensor(ifm, 0) # ifm tensor index
+ op.activation = None
+
+ # add input resample to resize by x2
+ op.ifm_resampling_mode = resampling_mode.NEAREST
+
+ # don't care about the rounding mode as it is nearest neighbor
+
+ # setup weight tensor
+ weight_quant = QuantizationParameters()
+ weight_quant.scale_f32 = 1.0 # no scaling as only a single non-zero coeff to select the desired value
+ weight_quant.zero_point = 0
+ weight_quant.quant_dim = 0
+ ofm_dtype = ofm.dtype
+ if ofm_dtype == DataType.uint8:
+ weight_value_dtype = np.uint8
+ weight_quant.quant_min = 0
+ weight_quant.quant_max = (1 << ofm_dtype.bits) - 1
+ else:
+ if ofm_dtype == DataType.int8:
+ weight_value_dtype = np.int8
+ else:
+ assert ofm_dtype == DataType.int16
+ weight_value_dtype = np.int16
+
+ weight_quant.quant_min = -(1 << (ofm_dtype.bits - 1))
+ weight_quant.quant_max = (1 << (ofm_dtype.bits - 1)) - 1
+
+ weight_shape = [upscale_factor, upscale_factor, output_depth, output_depth] # HWIO
+
+ # the single non-zero coefficient used to select the desired value needs to be placed in the 'centre value', which
+ # is calculated by finding the 'centre position' ('*' in the diagram below) and then choosing the 'value' that is
+ # below-and-right (i.e. next) to it (D).
+ # 0---1---2
+ # | A | B |
+ # 1---*---+
+ # | C | D |
+ # 2---+---+
+ weight_values = [0] * (upscale_factor * upscale_factor)
+ centre_coeff = (upscale_factor // 2) * upscale_factor + (upscale_factor // 2)
+ weight_values[centre_coeff] = 1
+
+ # add weight tensor, this will discard the size tensor of the resize op
+ op.set_input_tensor(
+ create_const_tensor(
+ "weights",
+ weight_shape,
+ ofm.dtype,
+ np.array(weight_values).reshape(weight_shape),
+ value_dtype=weight_value_dtype,
+ quantization=weight_quant,
+ ),
+ 1, # inputs tensor weight index
+ )
+
+ # set up the bias tensor by assigning None and then calling the fix-up function to create a suitable tensor.
+ # need to append the bias tensor as resize ops only have 2 inputs
+ assert len(op.inputs) == 2
+ op.inputs.append(None)
+ fixup_bias_tensors(op, None, None)
+
+ # finally update the shapes in case we've changed the tensor shapes or connections
+ op.set_ifm_ofm_shapes()
+
+ return op
+
+
+# Convert ResizeBilinear/NearestNeighbor to a number of 1x1 average pools with nearest neighbor x2 upscaling and one
+# final average pool with a kernel size that depends upon the resize ops upscaling factor (x2, x4 or x8). The maximum
+# upscale factor is limited to x8 because of the 8x8 kernel size limit for average pool with padding.
+def convert_resize_to_upscale_and_average_pool(op):
pre_op = op
outputs = op.outputs
dtype = op.ifm.dtype
+
op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
op.attrs["padding"] = Padding.SAME # doesn't really matter as the kernel is 1x1
op.ifm_resampling_mode = resampling_mode.NEAREST
@@ -321,14 +411,14 @@ def convert_resizebilinear_to_upscale_and_average_pool(op):
# between different versions of numpy. This consistency ensures that the kernel dimensions are kept integral
n = int(np.log2(upscale_factor))
- # Perform 2x2 upscaling n-1 times
+ # Perform x2 upscaling n-1 times
scaled_op = pre_op
for count in range(n - 1):
if count > 0:
scaled_op = op.clone(f"_{count}")
scaled_op.inputs[0] = pre_op.outputs[0]
- # Nearest neighbor 2x2 upscaling
+ # Nearest neighbor x2 upscaling
upscaled_shape = upscaled_shape * 2
shape = op.ofm_shapes[0].as_list()
shape[1:3] = upscaled_shape
@@ -339,17 +429,30 @@ def convert_resizebilinear_to_upscale_and_average_pool(op):
scaled_op.set_ifm_ofm_shapes()
- # Last 2x2 upscaling also applies avgpool with kernel size dependent on the upscaling factor and adds
- # padding to the right and bottom.
+ # Last x2 upscaling
if n > 1:
scaled_op = op.clone(f"_{n-1}")
scaled_op.inputs[0] = pre_op.outputs[0]
- if op.attrs["align_corners"]:
- scaled_op.attrs["padding"] = Padding.VALID
- else:
- scaled_op.attrs["padding"] = Padding.EXPLICIT
- scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]
- scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
+
+ if scaled_op.original_type == Op.ResizeBilinear:
+ if scaled_op.attrs["align_corners"]:
+ # no padding
+ scaled_op.attrs["padding"] = Padding.VALID
+ else:
+ # padding to the right and bottom (limits average pool to 8x8 kernel)
+ scaled_op.attrs["padding"] = Padding.EXPLICIT
+ scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]
+
+ # kernel size dependent on the upscaling factor
+ scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
+ else: # Op.ResizeNearestNeighbor
+ if scaled_op.attrs["align_corners"]:
+ # use depthwise conv to select the correct value
+ scaled_op = convert_resizenn_ac_to_depthwise_conv(scaled_op, upscale_factor)
+ else:
+ # keep 1x1 kernel and average pool
+ pass
+
scaled_op.outputs = outputs
scaled_op.outputs[0].ops = [scaled_op]
scaled_op.set_ifm_ofm_shapes()
@@ -357,16 +460,16 @@ def convert_resizebilinear_to_upscale_and_average_pool(op):
return op
-def fixup_resizebilinear(op, arch, nng):
- if op.type == Op.ResizeBilinear and op.run_on_npu:
+def fixup_resize(op, arch, nng):
+ if op.type.is_resize_op() and op.run_on_npu:
if op.ifm_shapes[0] == op.ofm_shapes[0]:
- # Bypass nop resizebilinear
+ # Bypass the resize op which is essentially a NOP
op.inputs = op.inputs[:1]
op.type = Op.Identity
elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
- convert_resizebilinear_1x1_to_add(op)
+ convert_resize_1x1_to_add(op)
else:
- convert_resizebilinear_to_upscale_and_average_pool(op)
+ convert_resize_to_upscale_and_average_pool(op)
return op
@@ -1130,31 +1233,6 @@ def convert_pad(op: Operation, arch, nng):
return avgpool_op
-def add_attrs_to_resizebilinear(op, arch, nng):
- if op.type == Op.ResizeBilinear and op.run_on_npu:
- input_shape = op.ifm_shapes[0]
- upscaled_height = input_shape.height * 2
- upscaled_width = input_shape.width * 2
- out_shape = op.ofm_shapes[0]
- if not op.attrs["align_corners"] and out_shape.height == upscaled_height and out_shape.width == upscaled_width:
- # this means the output is supposed to be a x2 upscale,
- # so we need to do SAME padding
- op.attrs["padding"] = Padding.SAME
- elif (
- op.attrs["align_corners"]
- and out_shape.height == (upscaled_height - 1)
- and out_shape.width == (upscaled_width - 1)
- ):
- # here we can just run the avg pool without padding and
- # produce a (M * 2 - 1, N * 2 - 1) sized output
- op.attrs["padding"] = Padding.VALID
- else:
- return op
- op.ifm_resampling_mode = resampling_mode.NEAREST
- op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
- return op
-
-
def fixup_bias_tensors(op, arch, nng):
if op.type.needs_bias() and op.bias is None:
# Op has no bias, add bias tensor filled with zeros
@@ -1577,7 +1655,7 @@ def tflite_optimise_graph(nng, arch):
fixup_conv2d_backprop,
fixup_relus_with_differing_ifm_ofm_scaling,
reorder_depthwise_weights,
- fixup_resizebilinear,
+ fixup_resize,
fixup_bias_tensors,
fixup_asymmetric_weights,
convert_mul_max_to_abs_or_lrelu,