aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py157
1 file changed, 157 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index a1cbb3e2..44f5d6ae 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -50,6 +50,7 @@ from .operation import Operation
from .operation import Padding
from .operation_util import create_add_nop
from .operation_util import create_avgpool_nop
+from .operation_util import create_depthwise_maxpool
from .operation_util import get_pad_values_from_input
from .scaling import quantise_scale
from .shape4d import Shape4D
@@ -460,6 +461,161 @@ def convert_resize_to_upscale_and_average_pool(op):
return op
+def convert_argmax_to_depthwise_conv_and_max_pool(op, arch, nng):
+ """
+ Convert ArgMax to DWConv2D->MaxPool->DWConv2D, see details below.
+
+ Example:
+ arr = [4, [00000100,
+ 6, = 00000110, # <-- This is the largest value, so we're expecting argmax(arr) = 1
+ 5] 00000101]
+
+ Use 16-bit precision and shift all values 7 bits to the left:
+ Shifted_arr = [0000001000000000,
+ 0000001100000000,
+ 0000001010000000]
+
+ Add "c - index of channel" to each channel:
+ Shifted_arr_plus_reverse_idx = [0000001000000010, (+2)
+ 0000001100000001, (+1)
+ 0000001010000000] (+0)
+
+ The index is reversed since ArgMax selects the lowest index if the maximum value is found at two indices. The
+ index acts as a tie-breaker between channels with equal values, and since we want the smallest channel index to
+ be chosen we reverse the index before the maxpool and then subtract the index from the number of channels after
+ the maxpool to get the correct index.
+
+ Find the maximum value in the array:
+ val = max(shifted_arr_plus_reverse_idx) = 0000001100000001
+
+ Subtract the reverse index (the 7 lowest bits of val) from the number of channels minus one:
+ idx = (c-1) - reverse_idx = 2 - 1 = 1
+
+ Both steps above are done in one go by a LUT activation that cuts off the 9 most significant bits and subtracts:
+ idx = LUT(val) = 0000000000000001 = 1
+ """
+
+ if op.type == Op.ArgMax:
+ ifm, ofm = op.inputs[0], op.outputs[0]
+ identity_quant = QuantizationParameters()
+ identity_quant.zero_point = 0
+ identity_quant.scale_f32 = 1.0
+ if ofm.quantization is None:
+ ofm.quantization = identity_quant
+ # Add last dimension to ofm shape
+ ofm.shape += [1]
+ ofm.ops = []
+
+ # Create 1x1 Depthwise convolution with 2**7 weights for each channel to convert precision to 16 bit and shift
+ # all values 7 bits to the left
+ # Set necessary depthwise attributes
+ dw_op_attrs = {
+ "padding": Padding.VALID,
+ "stride_h": 1,
+ "stride_w": 1,
+ "strides": (1, 1, 1, 1),
+ "depth_multiplier": 1,
+ "channel_multiplier": 1,
+ "dilation_h_factor": 1,
+ "dilation_w_factor": 1,
+ "dilation": (1, 1, 1, 1),
+ "explicit_padding": None,
+ }
+ op.name = "depthwise_conv_SHL_7"
+ op.type = Op.DepthwiseConv2DBias
+ op.attrs.update(dw_op_attrs)
+ n, h, w, c = ifm.shape
+ shape = [1, 1, 1, c]
+ kernel = np.dstack([2**7] * c)
+ op.inputs = []
+ op.add_input_tensor(ifm)
+ op.add_input_tensor(
+ create_const_tensor(
+ "weights",
+ shape,
+ DataType.uint8,
+ np.array(kernel).reshape(shape),
+ quantization=identity_quant,
+ ),
+ )
+ # Let the bias for each channel be the "reverse" index of the channel it is in, ie c - channel_idx
+ reverse_idxs = list(reversed(range(c)))
+ bias_tensor = create_const_tensor(op.name + "_bias", [c], DataType.int64, reverse_idxs)
+ op.add_input_tensor(bias_tensor)
+
+ intermediate_tens = Tensor([n, h, w, c], DataType.int16, "int16_and_shifted_7_bits_left")
+ intermediate_tens.quantization = ifm.quantization
+ op.set_output_tensor(intermediate_tens)
+ op.set_ifm_ofm_shapes()
+ orig_ifm_shape = op.ifm_shapes[0]
+ DebugDatabase.add_optimised(op, op)
+
+ # To extract 7 least significant bits and swap reverse index back to real index using a LUT activation, we set
+ # the base value to c-1 and slope to -128. The 16-bit LUT uses a table of 32-bit values where the top 16 bits
+ # represent the slope and bottom 16 bits the base which are used to interpolate the activation value.
+ slope = (-128 & 0xFFFF) << 16 # Top 16 bits of 32 bit LUT table value
+ base = c - 1 # Bottom 16 bits of the LUT table value
+ lut_tensor = create_const_tensor(
+ "maxpool_LUT_extract_7_LSB",
+ [1, 1, 1, 512],
+ DataType.uint32,
+ [slope + base] * 512,
+ TensorPurpose.LUT,
+ )
+
+ # Split large feature maps into smaller chunks since the Depthwise Maxpool height dimension can overflow due to
+ # flattening the ifm to (H*W)xCx1
+ max_height = 2**16 // orig_ifm_shape.width
+ num_full_height_ops = orig_ifm_shape.height // max_height
+ last_op_height = orig_ifm_shape.height - max_height * num_full_height_ops
+ op_heights = [max_height] * num_full_height_ops
+ if last_op_height > 0:
+ op_heights.append(last_op_height)
+
+ # Create maxpool output tensor which is reshaped to 1x(H*W)x1x1. The product H*W might be larger than the
+ # maximum allowed height, but that's handled by reading and writing the data in chunks
+ maxpool_ofm = Tensor([1, orig_ifm_shape.height * orig_ifm_shape.width, 1, 1], DataType.int16, "argmax_maxpool")
+ maxpool_ofm.quantization = identity_quant
+
+ for op_idx, op_height in enumerate(op_heights):
+ maxpool_op = create_depthwise_maxpool(
+ f"dw_maxpool_{op_idx}", intermediate_tens, orig_ifm_shape, identity_quant
+ )
+ maxpool_op.outputs = [maxpool_ofm]
+ maxpool_ofm.ops.append(maxpool_op)
+ maxpool_op.ofm_shapes = [Shape4D(maxpool_ofm.shape)]
+ maxpool_op.set_activation_lut(lut_tensor)
+
+ # Set read and write shapes/offsets to read/write chunks of the IFM/OFM
+ maxpool_op.read_shapes[0] = Shape4D([1, op_height * orig_ifm_shape.width, orig_ifm_shape.depth, 1])
+ maxpool_op.read_offsets[0] = Shape4D([0, sum(op_heights[:op_idx]) * orig_ifm_shape.width, 0, 0])
+ maxpool_op.write_shape = Shape4D([1, op_height * orig_ifm_shape.width, 1, 1])
+ maxpool_op.write_offset = Shape4D([0, sum(op_heights[:op_idx]) * orig_ifm_shape.width, 0, 0])
+ DebugDatabase.add_optimised(op, maxpool_op)
+
+ # Convert output to OFM dtype and reshape back to original OFM shape with 1x1 DWConv
+ dw_conv = Operation(Op.DepthwiseConv2DBias, f"depthwise_conv_convert_to_32bit_{op_idx}")
+ dw_conv.attrs.update(dw_op_attrs)
+ dw_conv.inputs = [maxpool_op.ofm]
+ dw_conv.add_input_tensor(
+ create_const_tensor(
+ "weights",
+ [1, 1, 1, 1],
+ DataType.uint8,
+ np.array([1]).reshape([1, 1, 1, 1]),
+ quantization=identity_quant,
+ ),
+ )
+ dw_conv.add_input_tensor(create_const_tensor(dw_conv.name + "_bias", [1], DataType.int64, [0]))
+ ofm.ops.append(dw_conv)
+ dw_conv.outputs = [ofm]
+ dw_conv.ifm_shapes.append(Shape4D([1, orig_ifm_shape.height, orig_ifm_shape.width, 1]))
+ dw_conv.ofm_shapes.append(Shape4D(ofm.shape))
+ DebugDatabase.add_optimised(op, dw_conv)
+
+ return op
+
+
def convert_resizebilinear_to_depthwise_convolutions(op, half_pixel_centers=True):
def _compute_interpolation_values(index, input_size, output_size):
scale = input_size / output_size
@@ -1976,6 +2132,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
fixup_conv2d_backprop,
fixup_relus_with_differing_ifm_ofm_scaling,
reorder_depthwise_weights,
+ convert_argmax_to_depthwise_conv_and_max_pool,
fixup_resize,
fixup_bias_tensors,
fixup_asymmetric_weights,