From 4f728c04bcc90742d9d57b0e253be68a7251984f Mon Sep 17 00:00:00 2001 From: Dwight Lidman Date: Thu, 17 Dec 2020 15:14:45 +0100 Subject: MLBEDSW-1499: Add MEAN operator This commit adds support for the MEAN operator, with some caveats. Signed-off-by: Dwight Lidman Change-Id: I165cb26cb5aefd68e70d2cfc68291ccf7b778921 --- ethosu/vela/supported_operators.py | 93 +++++++++++++++++++++++++++++++++++++- 1 file changed, 92 insertions(+), 1 deletion(-) (limited to 'ethosu/vela/supported_operators.py') diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index 8b759beb..a82f8124 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -75,6 +75,8 @@ class SupportedOperators: | resizing_ops # FC layers | fc_vector_products + # Mean (converts to depthwise conv) + | set((Op.Mean,)) ) unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op) binary_elem_wise_min_max_ops = set((Op.Minimum, Op.Maximum,)) @@ -99,7 +101,7 @@ class SupportedOperators: split_ops = set((Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped, Op.Unpack,)) concat_ops = set((Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack,)) memory_only_ops = set((Op.Reshape, Op.QuantizedReshape,)) | concat_ops | split_ops - shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV,)) + shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV, Op.Mean)) per_axis_quant_ops = convolution_like_ops # per-axis/channel quantization only currently supported for conv ops supported_fused_activations = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.LUT,)) supported_operators = npu_pre_ops | mac_main_ops | elem_wise_main_ops | pad_ops | npu_post_ops | memory_only_ops @@ -118,6 +120,8 @@ class SupportedOperators: filter_range = (1, 8) filter_height_range = (1, 256) filter_product_range = (1, 256 * 256) + mean_kernel_product = 64 * 64 + mean_kernel_product_int8 = 16 * 16 # Supported consumers supported_pad_consumers = convolution_ops | depthwise_convolution_ops | pooling_ops @@ -268,6 +272,13 @@ class SupportedOperators: # HardSwish specific checks: self.specific_constraints[Op.HardSwish].append(SupportedOperators.constraint_input_8bit) self.specific_constraints[Op.HardSwish].append(SupportedOperators.constraint_matching_in_out_types) + # Mean specific checks: + self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_input_8bit) + self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_properties) + self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_input_dims) + self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_axis) + self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_height_width_product) + self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_height_width_product_int8) def is_operator_supported(self, op): ext_type = optype_to_builtintype(op.type) @@ -1077,3 +1088,83 @@ class SupportedOperators: if op.attrs.get("keep_num_dims"): valid = len(op.ifm.shape) == len(op.ofm.shape) return valid, f"Op has ifm shape={op.ifm.shape} and ofm shape={op.ofm.shape}" + + def constraint_mean_input_dims(op): + "Input tensor must be at least 2D" + dims = len(op.inputs[0].shape) + return 2 <= dims <= 4, f"Input is {dims}D" + + @staticmethod + def constraint_mean_axis(op): + "Axis indices must correspond to height and width axes" + dims = len(op.inputs[0].shape) + axis = op.inputs[1].values if op.inputs[1].shape == [] else list(op.inputs[1].values) + if dims == 2 or dims == 3: + valid = axis in (0, 1, [0, 1], [1, 0]) + elif dims == 4: + valid = axis in (1, 2, [1, 2], [2, 1]) + return valid, f"Axis is {axis}" + + @classmethod + @docstring_format_args([mean_kernel_product]) + def constraint_mean_height_width_product(cls, op): + "Product of height and width can be at most {}" + shape = op.inputs[0].shape + hi = 0 if len(shape) < 4 else 1 + h, w = shape[hi : hi + 2] + max_prod = cls.mean_kernel_product + return h * w <= max_prod, f"Product of height and width is {h * w}" + + @classmethod + @docstring_format_args([mean_kernel_product_int8]) + def constraint_mean_height_width_product_int8(cls, op): + """Product of IFM height and width can be at most {} when the following are true: + IFM dimensions are 4, + Axis indices are 1 and 2, + keep_dims is set to True and + IFM datatype is int8""" + shape = op.ifm.shape + axis = op.inputs[1].values if op.inputs[1].shape == [] else list(op.inputs[1].values) + if ( + len(shape) != 4 + or op.ifm.dtype != DataType.int8 + or not op.attrs.get("keep_dims") + or axis not in ([1, 2], [2, 1]) + ): + return True, "" + hi = 0 if len(shape) < 4 else 1 + h, w = shape[hi : hi + 2] + max_prod = cls.mean_kernel_product_int8 + return h * w <= max_prod, f"Product of height and width is {h * w}" + + @staticmethod + def constraint_mean_properties(op): + """Every constraint in either one (or both) of the following sets of constraints must be fulfilled: + Set A: + IFM dimensions are 4, + Axis indices are 1 and 2, + keep_dims is set to True + Set B: + IFM zero point and OFM zero point are the same, + IFM scale and OFM scale are the same""" + seta, setb = True, True + extra = [] + axis = op.inputs[1].values if op.inputs[1].shape == [] else list(op.inputs[1].values) + if len(op.ifm.shape) != 4: + seta = False + extra.append(f"IFM shape is {op.ifm.shape}") + if not any(np.array_equal(axis, ax) for ax in ([1, 2], [2, 1])): + seta = False + extra.append(f"Axis is {axis}") + if not op.attrs.get("keep_dims"): + seta = False + extra.append("keep_dims is False") + ifmq, ofmq = op.ifm.quantization, op.ofm.quantization + if ifmq.zero_point != ofmq.zero_point: + setb = False + extra.append("IFM zero point does not match OFM zero point") + if ifmq.scale_f32 != ofmq.scale_f32: + setb = False + extra.append("IFM scale does not match OFM scale") + extra = ", ".join(extra) + return seta or setb, f"The following constraints were not fulfilled: {extra}" -- cgit v1.2.1