From 6986a079020ab6344c9191aa67af13beeb475593 Mon Sep 17 00:00:00 2001
From: Rickard Bolin
Date: Mon, 19 Dec 2022 12:33:40 +0000
Subject: MLBEDSW-6435: Implement support for ArgMax along depth dimension

- Add support for ArgMax along depth dimension with a depth limit of 127.
- Only supports 8-bit input and 32-bit output

Signed-off-by: Rickard Bolin
Change-Id: I5f6f0503135bebabbb1ca637f9729587b7c60740
---
 ethosu/vela/operation.py                  |   2 +-
 ethosu/vela/tflite_graph_optimiser.py     | 157 ++++++++++++++++++++++++++++++
 ethosu/vela/tflite_mapping.py             |   4 +-
 ethosu/vela/tflite_model_semantic.py      |  10 +-
 ethosu/vela/tflite_supported_operators.py |  39 ++++++--
 5 files changed, 199 insertions(+), 13 deletions(-)

(limited to 'ethosu')

diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 6be9dc25..67717104 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -134,7 +134,7 @@ class Op(Enum):
     Add = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
     AddN = OperatorInfo()
     Any = OperatorInfo()
-    ArgMax = OperatorInfo()
+    ArgMax = OperatorInfo(indices=NNG_IFM_INDICES)
     ArgMin = OperatorInfo()
     AvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
     Atan2 = OperatorInfo(indices=NNG_IFM_IFM2_INDICES)
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index a1cbb3e2..44f5d6ae 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -50,6 +50,7 @@ from .operation import Operation
 from .operation import Padding
 from .operation_util import create_add_nop
 from .operation_util import create_avgpool_nop
+from .operation_util import create_depthwise_maxpool
 from .operation_util import get_pad_values_from_input
 from .scaling import quantise_scale
 from .shape4d import Shape4D
@@ -460,6 +461,161 @@ def convert_resize_to_upscale_and_average_pool(op):
     return op
 
 
+def convert_argmax_to_depthwise_conv_and_max_pool(op, arch, nng):
+    """
+    Convert ArgMax to DWConv2D->MaxPool->DWConv2D, see details below.
+
+    Example:
+    arr = [4,    [00000100,
+           6, =   00000110,   # <-- This is the largest value, so we're expecting argmax(arr) = 1
+           5]     00000101]
+
+    Use 16-bit precision and shift all values 7 bits to the left:
+    Shifted_arr = [0000001000000000,
+                   0000001100000000,
+                   0000001010000000]
+
+    Add "(c-1) - index of channel" to each channel:
+    Shifted_arr_plus_reverse_idx = [0000001000000010, (+2)
+                                    0000001100000001, (+1)
+                                    0000001010000000] (+0)
+
+    The index is reversed since ArgMax selects the lowest index if the maximum value is found at two indices. The
+    reversed index acts as a tie-breaker between channels with equal values: we add it before the maxpool and then
+    subtract it from the number of channels minus one after the maxpool to get the correct index.
+
+    Find the maximum value in the array:
+    val = max(shifted_arr_plus_reverse_idx) = 0000001100000001
+
+    Subtract the reverse index, held in the 7 LSBs of val, from the number of channels minus one:
+    idx = (c-1) - reverse_idx = 2 - 1 = 1
+
+    Both the extraction of the 7 lowest bits (cutting off the 9 MSBs) and the subtraction are done by a LUT:
+    idx = LUT(val) = 0000000000000001 = 1
+    """
+
+    if op.type == Op.ArgMax:
+        ifm, ofm = op.inputs[0], op.outputs[0]
+        identity_quant = QuantizationParameters()
+        identity_quant.zero_point = 0
+        identity_quant.scale_f32 = 1.0
+        if ofm.quantization is None:
+            ofm.quantization = identity_quant
+        # Add last dimension to ofm shape
+        ofm.shape += [1]
+        ofm.ops = []
+
+        # Create a 1x1 depthwise convolution with a weight of 2**7 for each channel to convert the precision to 16 bit
+        # and shift all values 7 bits to the left
+        # Set necessary depthwise attributes
+        dw_op_attrs = {
+            "padding": Padding.VALID,
+            "stride_h": 1,
+            "stride_w": 1,
+            "strides": (1, 1, 1, 1),
+            "depth_multiplier": 1,
+            "channel_multiplier": 1,
+            "dilation_h_factor": 1,
+            "dilation_w_factor": 1,
+            "dilation": (1, 1, 1, 1),
+            "explicit_padding": None,
+        }
+        op.name = "depthwise_conv_SHL_7"
+        op.type = Op.DepthwiseConv2DBias
+        op.attrs.update(dw_op_attrs)
+        n, h, w, c = ifm.shape
+        shape = [1, 1, 1, c]
+        kernel = np.dstack([2**7] * c)
+        op.inputs = []
+        op.add_input_tensor(ifm)
+        op.add_input_tensor(
+            create_const_tensor(
+                "weights",
+                shape,
+                DataType.uint8,
+                np.array(kernel).reshape(shape),
+                quantization=identity_quant,
+            ),
+        )
+        # Let the bias for each channel be the "reverse" index of the channel it is in, i.e. (c-1) - channel_idx
+        reverse_idxs = list(reversed(range(c)))
+        bias_tensor = create_const_tensor(op.name + "_bias", [c], DataType.int64, reverse_idxs)
+        op.add_input_tensor(bias_tensor)
+
+        intermediate_tens = Tensor([n, h, w, c], DataType.int16, "int16_and_shifted_7_bits_left")
+        intermediate_tens.quantization = ifm.quantization
+        op.set_output_tensor(intermediate_tens)
+        op.set_ifm_ofm_shapes()
+        orig_ifm_shape = op.ifm_shapes[0]
+        DebugDatabase.add_optimised(op, op)
+
+        # To extract the 7 LSBs and swap the reverse index back to the real index with a LUT activation, we set the
+        # base value to c-1 and the slope to -128. The 16-bit LUT uses a table of 32-bit values where the top 16 bits
+        # hold the slope and the bottom 16 bits the base, which are used to interpolate the activation value.
+        slope = (-128 & 0xFFFF) << 16  # Top 16 bits of 32 bit LUT table value
+        base = c - 1  # Bottom 16 bits of the LUT table value
+        lut_tensor = create_const_tensor(
+            "maxpool_LUT_extract_7_LSB",
+            [1, 1, 1, 512],
+            DataType.uint32,
+            [slope + base] * 512,
+            TensorPurpose.LUT,
+        )
+
+        # Split large feature maps into smaller chunks since the Depthwise Maxpool height dimension can overflow due to
+        # flattening the ifm to (H*W)xCx1
+        max_height = 2**16 // orig_ifm_shape.width
+        num_full_height_ops = orig_ifm_shape.height // max_height
+        last_op_height = orig_ifm_shape.height - max_height * num_full_height_ops
+        op_heights = [max_height] * num_full_height_ops
+        if last_op_height > 0:
+            op_heights.append(last_op_height)
+
+        # Create maxpool output tensor which is reshaped to 1x(H*W)x1x1. The product H*W might be larger than the
+        # maximum allowed height, but that's handled by reading and writing the data in chunks
+        maxpool_ofm = Tensor([1, orig_ifm_shape.height * orig_ifm_shape.width, 1, 1], DataType.int16, "argmax_maxpool")
+        maxpool_ofm.quantization = identity_quant
+
+        for op_idx, op_height in enumerate(op_heights):
+            maxpool_op = create_depthwise_maxpool(
+                f"dw_maxpool_{op_idx}", intermediate_tens, orig_ifm_shape, identity_quant
+            )
+            maxpool_op.outputs = [maxpool_ofm]
+            maxpool_ofm.ops.append(maxpool_op)
+            maxpool_op.ofm_shapes = [Shape4D(maxpool_ofm.shape)]
+            maxpool_op.set_activation_lut(lut_tensor)
+
+            # Set read and write shapes/offsets to read/write chunks of the IFM/OFM
+            maxpool_op.read_shapes[0] = Shape4D([1, op_height * orig_ifm_shape.width, orig_ifm_shape.depth, 1])
+            maxpool_op.read_offsets[0] = Shape4D([0, sum(op_heights[:op_idx]) * orig_ifm_shape.width, 0, 0])
+            maxpool_op.write_shape = Shape4D([1, op_height * orig_ifm_shape.width, 1, 1])
+            maxpool_op.write_offset = Shape4D([0, sum(op_heights[:op_idx]) * orig_ifm_shape.width, 0, 0])
+            DebugDatabase.add_optimised(op, maxpool_op)
+
+        # Convert output to OFM dtype and reshape back to original OFM shape with 1x1 DWConv
+        dw_conv = Operation(Op.DepthwiseConv2DBias, f"depthwise_conv_convert_to_32bit_{op_idx}")
+        dw_conv.attrs.update(dw_op_attrs)
+        dw_conv.inputs = [maxpool_op.ofm]
+        dw_conv.add_input_tensor(
+            create_const_tensor(
+                "weights",
+                [1, 1, 1, 1],
+                DataType.uint8,
+                np.array([1]).reshape([1, 1, 1, 1]),
+                quantization=identity_quant,
+            ),
+        )
+        dw_conv.add_input_tensor(create_const_tensor(dw_conv.name + "_bias", [1], DataType.int64, [0]))
+        ofm.ops.append(dw_conv)
+        dw_conv.outputs = [ofm]
+        dw_conv.ifm_shapes.append(Shape4D([1, orig_ifm_shape.height, orig_ifm_shape.width, 1]))
+        dw_conv.ofm_shapes.append(Shape4D(ofm.shape))
+        DebugDatabase.add_optimised(op, dw_conv)
+
+    return op
+
+
 def convert_resizebilinear_to_depthwise_convolutions(op, half_pixel_centers=True):
     def _compute_interpolation_values(index, input_size, output_size):
         scale = input_size / output_size
@@ -1976,6 +2132,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
         fixup_conv2d_backprop,
         fixup_relus_with_differing_ifm_ofm_scaling,
         reorder_depthwise_weights,
+        convert_argmax_to_depthwise_conv_and_max_pool,
         fixup_resize,
         fixup_bias_tensors,
         fixup_asymmetric_weights,
diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py
index 8ec01737..98fe287d 100644
--- a/ethosu/vela/tflite_mapping.py
+++ b/ethosu/vela/tflite_mapping.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -749,7 +749,7 @@ builtin_operator_map = {
     BuiltinOperator.ARG_MAX: (
         Op.ArgMax,
         OptionsSerializer("ArgMaxOptions", (("output_type", datatype_deserialize, datatype_serialize),)),
-        TFLITE_NO_INDICES,
+        TFLITE_IFM_INDICES,
     ),
     BuiltinOperator.MINIMUM: (Op.Minimum, OptionsSerializer("MaximumMinimumOptions"), TFLITE_IFM_IFM2_INDICES),
     BuiltinOperator.LESS: (Op.Less, OptionsSerializer("LessOptions"), TFLITE_NO_INDICES),
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index 9f53a1e6..495d71a6 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -77,7 +77,9 @@ class TFLiteSemantic:
     )
     binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
     elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
-    shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV, Op.Mean, Op.ExpandDims, Op.Quantize))
+    shapeless_input_ops = binary_elem_wise_main_ops | set(
+        (Op.Split, Op.SplitV, Op.Mean, Op.ExpandDims, Op.Quantize, Op.ArgMax)
+    )
     reshape_ops = set(
         (
             Op.Reshape,
@@ -187,6 +189,9 @@ class TFLiteSemantic:
         self.specific_constraints[Op.Mean].append(TFLiteSemantic.constraint_mean_input_dims)
         self.specific_constraints[Op.Mean].append(TFLiteSemantic.constraint_mean_axis)
 
+        # ArgMax specific checks:
+        self.specific_constraints[Op.ArgMax].append(TFLiteSemantic.constraint_input_8bit)
+
     def is_operator_semantic_valid(self, op):
         ext_type = optype_to_builtintype(op.type)
 
@@ -226,6 +231,9 @@
                 TFLiteSemantic.constraint_tens_no_dynamic,
                 TFLiteSemantic.constraint_tens_output_scalar,
             ],
+            Op.ArgMax: [
+                TFLiteSemantic.constraint_tens_quant_none_check,
+            ],
         }
         return generic_constraints_exclude_list
 
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 26ccfeb6..fd9a9c20 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -81,6 +81,8 @@ class TFLiteSupportedOperators:
         | fc_vector_products
         # Mean (converts to depthwise conv)
         | set((Op.Mean,))
+        # ArgMax (converts to depthwise conv and maxpool)
+        | set((Op.ArgMax,))
     )
     unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)
     binary_elem_wise_min_max_ops = set(
@@ -106,15 +108,7 @@
     elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
     pad_ops = set((Op.Pad,))
     supported_int32_tensor_ops = (
-        set(
-            (
-                Op.ReduceSum,
-                Op.CLZ,
-                Op.Shape,
-            )
-        )
-        | binary_elem_wise_add_mul_sub
-        | binary_elem_wise_shift_ops
+        set((Op.ReduceSum, Op.CLZ, Op.Shape, Op.ArgMax)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
     )
 
     relu_ops = set(
@@ -321,6 +315,11 @@
         # Reshape specific checks:
         self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant)
 
+        # ArgMax specific checks:
+        self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_input_dimensions)
+        self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_axis)
+        self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_depth)
+
     def is_operator_supported(self, op):
         ext_type = optype_to_builtintype(op.type)
         if op.type not in TFLiteSupportedOperators.supported_operators:
@@ -873,3 +872,25 @@
         extra = ", ".join(extra)
 
         return valid, f"Op has non-const input(s): {extra}"
+
+    @staticmethod
+    def constraint_argmax_axis(op):
+        "Operation must be performed along the depth axis"
+        inp_dims = len(op.inputs[0].shape)
+        axis = op.inputs[1].values
+        return (
+            axis in (3, -1),
+            f"Axis is {axis} and number of input dimensions is {inp_dims}",
+        )
+
+    @staticmethod
+    def constraint_argmax_input_dimensions(op):
+        "Number of input dimensions must be 4"
+        inp_dims = len(op.inputs[0].shape)
+        return inp_dims == 4, f"Number of input dimensions is {inp_dims}"
+
+    @staticmethod
+    def constraint_argmax_depth(op):
+        "IFM depth must be no greater than 127"
+        ifm_depth = op.inputs[0].shape[-1]
+        return ifm_depth <= 127, f"IFM depth is {ifm_depth}"
--
cgit v1.2.1
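
For reference, the lowering implemented by convert_argmax_to_depthwise_conv_and_max_pool can be modelled in a few lines of NumPy. The sketch below is an illustration only and not part of the patch: the function name and the test input are made up for the example, and it assumes an NHWC 8-bit input with a depth of at most 127 (the limit enforced by constraint_argmax_depth).

import numpy as np

def argmax_via_shift_and_max(ifm):
    # Model of the DWConv -> depthwise MaxPool -> LUT sequence for argmax over depth.
    c = ifm.shape[-1]
    assert c <= 127, "the reverse channel index must fit in the 7 low bits"
    # DWConv stage: widen to 16 bit, shift left by 7 (weight 2**7) and add the
    # reversed channel index (c-1, ..., 0) as a per-channel bias.
    reverse_idx = np.arange(c - 1, -1, -1, dtype=np.int16)
    shifted = ifm.astype(np.int16) * 128 + reverse_idx
    # Depthwise maxpool stage: the maximum keeps the largest input value, and the
    # largest reverse index (i.e. the smallest channel index) wins any tie.
    val = shifted.max(axis=-1)
    # LUT stage: keep the 7 LSBs and map the reverse index back to the real index.
    return (c - 1) - (val & 0x7F)

# Quick check against NumPy's own argmax, including a tie in the second position.
x = np.array([[[[4, 6, 5], [7, 7, 1]]]], dtype=np.int8)
assert np.array_equal(argmax_via_shift_and_max(x), np.argmax(x, axis=-1))

Encoding the tie-breaker in the 7 low bits is what lets a single depthwise maxpool return both the maximum value and, via the LUT activation, the index of the first channel that attains it.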