about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- SUPPORTED_OPS.md 18
-rw-r--r-- ethosu/vela/operation.py 2
-rw-r--r-- ethosu/vela/tflite_graph_optimiser.py 157
-rw-r--r-- ethosu/vela/tflite_mapping.py 4
-rw-r--r-- ethosu/vela/tflite_model_semantic.py 10
-rw-r--r-- ethosu/vela/tflite_supported_operators.py 39
6 files changed, 213 insertions, 17 deletions
diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md
index 860d1fe6..6f6167d2 100644
--- a/SUPPORTED_OPS.md
+++ b/SUPPORTED_OPS.md
@@ -1,7 +1,7 @@
# Supported Ops
This file was automatically generated by Vela using the `--supported-ops-report` parameter.
-Vela version: `3.7.1.dev2+g19f8967.d20230301`
+Vela version: `3.7.1.dev8+ga182a70.d20230322`
This file complies with
[**Gitiles Markdown syntax**](https://github.com/google/gitiles/blob/master/Documentation/markdown.md)
@@ -20,6 +20,7 @@ Please check the supported operator list for your chosen runtime for further inf
| --- | --- |
| ABS | [Generic](#tflite-generic-constraints), [Specific](#tflite-abs-constraints) |
| ADD | [Generic](#tflite-generic-constraints), [Specific](#tflite-add-constraints) |
+| ARG_MAX | [Generic](#tflite-generic-constraints), [Specific](#tflite-arg_max-constraints) |
| AVERAGE_POOL_2D | [Generic](#tflite-generic-constraints), [Specific](#tflite-average_pool_2d-constraints) |
| CONCATENATION | [Generic](#tflite-generic-constraints), [Specific](#tflite-concatenation-constraints) |
| CONV_2D | [Generic](#tflite-generic-constraints), [Specific](#tflite-conv_2d-constraints) |
@@ -64,14 +65,14 @@ This is a list of constraints most NPU operators must satisfy in order to be sch
- Input(s) and Output tensors must not be dynamic - [QUANTIZE]
- Input(s) and Output tensors must have a defined shape
- Output tensors cannot be scalar - [QUANTIZE]
-- Scalar Input tensors are only valid for op type: ADD, EXPAND_DIMS, MAXIMUM, MEAN, MINIMUM, MUL, QUANTIZE, SPLIT, SPLIT_V, SUB
+- Scalar Input tensors are only valid for op type: ADD, ARG_MAX, EXPAND_DIMS, MAXIMUM, MEAN, MINIMUM, MUL, QUANTIZE, SPLIT, SPLIT_V, SUB
- Input(s) and Output tensors must not be greater than 4D
-- Input(s), Output and Weight tensors must have quantization parameters - [SHAPE]
+- Input(s), Output and Weight tensors must have quantization parameters - [ARG_MAX, SHAPE]
- Input(s), Output and Weight tensors with quantization scales must be finite
- Input and Output tensors must have quantization scales that fit within float32 precision
- Constant tensors should not have NoneType-values
- Tensors must be of type: int16, int32, int8, uint8
-- Tensors which are int32 are only valid when op type is: ADD, MUL, SHAPE, SUB
+- Tensors which are int32 are only valid when op type is: ADD, ARG_MAX, MUL, SHAPE, SUB
- Tensor dimensions must be in the range [1, 65535]
- Per-axis quantization is only supported for the following op types: CONV_2D, DEPTHWISE_CONV_2D, TRANSPOSE_CONV
- IFM Tensor batch size must be 1 - [FULLY_CONNECTED, RESHAPE, SHAPE, SLICE, SOFTMAX, SPLIT, SPLIT_V, SQUEEZE, STRIDED_SLICE, UNPACK]
@@ -95,6 +96,15 @@ This is a list of constraints that the ADD operator must satisfy in order to be
- For IFM that are unsigned, OFM must either be the same type or int32
- Broadcasting is only allowed for rank indices with dimension 1, from either IFM1 or IFM2
+### TFLite ARG_MAX Constraints
+
+This is a list of constraints that the ARG_MAX operator must satisfy in order to be scheduled on the NPU.
+
+- IFM must be int8 or uint8
+- Number of input dimensions must be 4
+- Operation must be performed along the depth axis
+- IFM depth must be no greater than 127
+
### TFLite AVERAGE_POOL_2D Constraints
This is a list of constraints that the AVERAGE_POOL_2D operator must satisfy in order to be scheduled on the NPU.
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 6be9dc25..67717104 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -134,7 +134,7 @@ class Op(Enum):
Add = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
AddN = OperatorInfo()
Any = OperatorInfo()
- ArgMax = OperatorInfo()
+ ArgMax = OperatorInfo(indices=NNG_IFM_INDICES)
ArgMin = OperatorInfo()
AvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
Atan2 = OperatorInfo(indices=NNG_IFM_IFM2_INDICES)
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index a1cbb3e2..44f5d6ae 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -50,6 +50,7 @@ from .operation import Operation
from .operation import Padding
from .operation_util import create_add_nop
from .operation_util import create_avgpool_nop
+from .operation_util import create_depthwise_maxpool
from .operation_util import get_pad_values_from_input
from .scaling import quantise_scale
from .shape4d import Shape4D
@@ -460,6 +461,161 @@ def convert_resize_to_upscale_and_average_pool(op):
return op
+def convert_argmax_to_depthwise_conv_and_max_pool(op, arch, nng):
+ """
+ Convert ArgMax to DWConv2D->MaxPool->DWConv2D, see details below.
+
+ Example:
+ arr = [4, [00000100,
+ 6, = 00000110, # <-- This is the largest value, so we're expecting argmax(arr) = 1
+ 5] 00000101]
+
+ Use 16-bit precision and shift all values 7 bits to the left:
+ Shifted_arr = [0000001000000000,
+ 0000001100000000,
+ 0000001010000000]
+
+ Add "(c - 1) - index of channel" to each channel:
+ Shifted_arr_plus_reverse_idx = [0000001000000010, (+2)
+ 0000001100000001, (+1)
+ 0000001010000000] (+0)
+
+ The index is reversed since ArgMax selects the lowest index if the maximum value is found at several indices.
+ The index will act as a tie-breaker between channels with equal values, and since we want the smallest channel
+ index to be chosen we reverse the index before the maxpool and then subtract the result from the number of
+ channels minus one after the maxpool to get the correct index.
+
+ Find the maximum value in the array:
+ val = max(shifted_arr_plus_reverse_idx) = 0000001100000001
+
+ Subtract the reverse index (the 7 lowest bits of val) from the number of channels minus one:
+ idx = (c-1) - reverse_idx = 2 - 1 = 1
+
+ Extract the 7 lowest bits using a LUT to cut off the 9 most significant bits:
+ idx = LUT(val) = 0000000000000001 = 1
+ """
+
+ if op.type == Op.ArgMax:
+ ifm, ofm = op.inputs[0], op.outputs[0]
+ identity_quant = QuantizationParameters()
+ identity_quant.zero_point = 0
+ identity_quant.scale_f32 = 1.0
+ if ofm.quantization is None:
+ ofm.quantization = identity_quant
+ # Add last dimension to ofm shape
+ ofm.shape += [1]
+ ofm.ops = []
+
+ # Create 1x1 Depthwise convolution with 2**7 weights for each channel to convert precision to 16 bit and shift
+ # all values 7 bits to the left
+ # Set necessary depthwise attributes
+ dw_op_attrs = {
+ "padding": Padding.VALID,
+ "stride_h": 1,
+ "stride_w": 1,
+ "strides": (1, 1, 1, 1),
+ "depth_multiplier": 1,
+ "channel_multiplier": 1,
+ "dilation_h_factor": 1,
+ "dilation_w_factor": 1,
+ "dilation": (1, 1, 1, 1),
+ "explicit_padding": None,
+ }
+ op.name = "depthwise_conv_SHL_7"
+ op.type = Op.DepthwiseConv2DBias
+ op.attrs.update(dw_op_attrs)
+ n, h, w, c = ifm.shape
+ shape = [1, 1, 1, c]
+ kernel = np.dstack([2**7] * c)
+ op.inputs = []
+ op.add_input_tensor(ifm)
+ op.add_input_tensor(
+ create_const_tensor(
+ "weights",
+ shape,
+ DataType.uint8,
+ np.array(kernel).reshape(shape),
+ quantization=identity_quant,
+ ),
+ )
+ # Let the bias for each channel be the "reverse" index of the channel it is in, i.e. (c - 1) - channel_idx
+ reverse_idxs = list(reversed(range(c)))
+ bias_tensor = create_const_tensor(op.name + "_bias", [c], DataType.int64, reverse_idxs)
+ op.add_input_tensor(bias_tensor)
+
+ intermediate_tens = Tensor([n, h, w, c], DataType.int16, "int16_and_shifted_7_bits_left")
+ intermediate_tens.quantization = ifm.quantization
+ op.set_output_tensor(intermediate_tens)
+ op.set_ifm_ofm_shapes()
+ orig_ifm_shape = op.ifm_shapes[0]
+ DebugDatabase.add_optimised(op, op)
+
+ # To extract 7 least significant bits and swap reverse index back to real index using a LUT activation, we set
+ # the base value to c-1 and slope to -128. The 16-bit LUT uses a table of 32-bit values where the top 16 bits
+ # represent the slope and bottom 16 bits the base which are used to interpolate the activation value.
+ slope = (-128 & 0xFFFF) << 16 # Top 16 bits of 32 bit LUT table value
+ base = c - 1 # Bottom 16 bits of the LUT table value
+ lut_tensor = create_const_tensor(
+ "maxpool_LUT_extract_7_LSB",
+ [1, 1, 1, 512],
+ DataType.uint32,
+ [slope + base] * 512,
+ TensorPurpose.LUT,
+ )
+
+ # Split large feature maps into smaller chunks since the Depthwise Maxpool height dimension can overflow due to
+ # flattening the ifm to (H*W)xCx1
+ max_height = 2**16 // orig_ifm_shape.width
+ num_full_height_ops = orig_ifm_shape.height // max_height
+ last_op_height = orig_ifm_shape.height - max_height * num_full_height_ops
+ op_heights = [max_height] * num_full_height_ops
+ if last_op_height > 0:
+ op_heights.append(last_op_height)
+
+ # Create maxpool output tensor which is reshaped to 1x(H*W)x1x1. The product H*W might be larger than the
+ # maximum allowed height, but that's handled by reading and writing the data in chunks
+ maxpool_ofm = Tensor([1, orig_ifm_shape.height * orig_ifm_shape.width, 1, 1], DataType.int16, "argmax_maxpool")
+ maxpool_ofm.quantization = identity_quant
+
+ for op_idx, op_height in enumerate(op_heights):
+ maxpool_op = create_depthwise_maxpool(
+ f"dw_maxpool_{op_idx}", intermediate_tens, orig_ifm_shape, identity_quant
+ )
+ maxpool_op.outputs = [maxpool_ofm]
+ maxpool_ofm.ops.append(maxpool_op)
+ maxpool_op.ofm_shapes = [Shape4D(maxpool_ofm.shape)]
+ maxpool_op.set_activation_lut(lut_tensor)
+
+ # Set read and write shapes/offsets to read/write chunks of the IFM/OFM
+ maxpool_op.read_shapes[0] = Shape4D([1, op_height * orig_ifm_shape.width, orig_ifm_shape.depth, 1])
+ maxpool_op.read_offsets[0] = Shape4D([0, sum(op_heights[:op_idx]) * orig_ifm_shape.width, 0, 0])
+ maxpool_op.write_shape = Shape4D([1, op_height * orig_ifm_shape.width, 1, 1])
+ maxpool_op.write_offset = Shape4D([0, sum(op_heights[:op_idx]) * orig_ifm_shape.width, 0, 0])
+ DebugDatabase.add_optimised(op, maxpool_op)
+
+ # Convert output to OFM dtype and reshape back to original OFM shape with 1x1 DWConv
+ dw_conv = Operation(Op.DepthwiseConv2DBias, f"depthwise_conv_convert_to_32bit_{op_idx}")
+ dw_conv.attrs.update(dw_op_attrs)
+ dw_conv.inputs = [maxpool_op.ofm]
+ dw_conv.add_input_tensor(
+ create_const_tensor(
+ "weights",
+ [1, 1, 1, 1],
+ DataType.uint8,
+ np.array([1]).reshape([1, 1, 1, 1]),
+ quantization=identity_quant,
+ ),
+ )
+ dw_conv.add_input_tensor(create_const_tensor(dw_conv.name + "_bias", [1], DataType.int64, [0]))
+ ofm.ops.append(dw_conv)
+ dw_conv.outputs = [ofm]
+ dw_conv.ifm_shapes.append(Shape4D([1, orig_ifm_shape.height, orig_ifm_shape.width, 1]))
+ dw_conv.ofm_shapes.append(Shape4D(ofm.shape))
+ DebugDatabase.add_optimised(op, dw_conv)
+
+ return op
+
+
def convert_resizebilinear_to_depthwise_convolutions(op, half_pixel_centers=True):
def _compute_interpolation_values(index, input_size, output_size):
scale = input_size / output_size
@@ -1976,6 +2132,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
fixup_conv2d_backprop,
fixup_relus_with_differing_ifm_ofm_scaling,
reorder_depthwise_weights,
+ convert_argmax_to_depthwise_conv_and_max_pool,
fixup_resize,
fixup_bias_tensors,
fixup_asymmetric_weights,
diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py
index 8ec01737..98fe287d 100644
--- a/ethosu/vela/tflite_mapping.py
+++ b/ethosu/vela/tflite_mapping.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -749,7 +749,7 @@ builtin_operator_map = {
BuiltinOperator.ARG_MAX: (
Op.ArgMax,
OptionsSerializer("ArgMaxOptions", (("output_type", datatype_deserialize, datatype_serialize),)),
- TFLITE_NO_INDICES,
+ TFLITE_IFM_INDICES,
),
BuiltinOperator.MINIMUM: (Op.Minimum, OptionsSerializer("MaximumMinimumOptions"), TFLITE_IFM_IFM2_INDICES),
BuiltinOperator.LESS: (Op.Less, OptionsSerializer("LessOptions"), TFLITE_NO_INDICES),
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index 9f53a1e6..495d71a6 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -77,7 +77,9 @@ class TFLiteSemantic:
)
binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
- shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV, Op.Mean, Op.ExpandDims, Op.Quantize))
+ shapeless_input_ops = binary_elem_wise_main_ops | set(
+ (Op.Split, Op.SplitV, Op.Mean, Op.ExpandDims, Op.Quantize, Op.ArgMax)
+ )
reshape_ops = set(
(
Op.Reshape,
@@ -187,6 +189,9 @@ class TFLiteSemantic:
self.specific_constraints[Op.Mean].append(TFLiteSemantic.constraint_mean_input_dims)
self.specific_constraints[Op.Mean].append(TFLiteSemantic.constraint_mean_axis)
+ # ArgMax specific checks:
+ self.specific_constraints[Op.ArgMax].append(TFLiteSemantic.constraint_input_8bit)
+
def is_operator_semantic_valid(self, op):
ext_type = optype_to_builtintype(op.type)
@@ -226,6 +231,9 @@ class TFLiteSemantic:
TFLiteSemantic.constraint_tens_no_dynamic,
TFLiteSemantic.constraint_tens_output_scalar,
],
+ Op.ArgMax: [
+ TFLiteSemantic.constraint_tens_quant_none_check,
+ ],
}
return generic_constraints_exclude_list
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 26ccfeb6..fd9a9c20 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -81,6 +81,8 @@ class TFLiteSupportedOperators:
| fc_vector_products
# Mean (converts to depthwise conv)
| set((Op.Mean,))
+ # ArgMax (converts to depthwise conv and maxpool)
+ | set((Op.ArgMax,))
)
unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)
binary_elem_wise_min_max_ops = set(
@@ -106,15 +108,7 @@ class TFLiteSupportedOperators:
elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
pad_ops = set((Op.Pad,))
supported_int32_tensor_ops = (
- set(
- (
- Op.ReduceSum,
- Op.CLZ,
- Op.Shape,
- )
- )
- | binary_elem_wise_add_mul_sub
- | binary_elem_wise_shift_ops
+ set((Op.ReduceSum, Op.CLZ, Op.Shape, Op.ArgMax)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
)
relu_ops = set(
@@ -321,6 +315,11 @@ class TFLiteSupportedOperators:
# Reshape specific checks:
self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant)
+ # ArgMax specific checks:
+ self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_input_dimensions)
+ self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_axis)
+ self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_depth)
+
def is_operator_supported(self, op):
ext_type = optype_to_builtintype(op.type)
if op.type not in TFLiteSupportedOperators.supported_operators:
@@ -873,3 +872,25 @@ class TFLiteSupportedOperators:
extra = ", ".join(extra)
return valid, f"Op has non-const input(s): {extra}"
+
+ @staticmethod
+ def constraint_argmax_axis(op):
+ "Operation must be performed along the depth axis"
+ inp_dims = len(op.inputs[0].shape)
+ axis = op.inputs[1].values
+ return (
+ axis in (3, -1),
+ f"Axis is {axis} and number of input dimensions is {inp_dims}",
+ )
+
+ @staticmethod
+ def constraint_argmax_input_dimensions(op):
+ "Number of input dimensions must be 4"
+ inp_dims = len(op.inputs[0].shape)
+ return inp_dims == 4, f"Number of input dimensions is {inp_dims}"
+
+ @staticmethod
+ def constraint_argmax_depth(op):
+ "IFM depth must be no greater than 127"
+ ifm_depth = op.inputs[0].shape[-1]
+ return ifm_depth <= 127, f"IFM depth is {ifm_depth}"