From 0bb7ad1e8c1a17ae15905152a60c2160bf998fde Mon Sep 17 00:00:00 2001 From: James Peet Date: Tue, 15 Feb 2022 15:07:54 +0000 Subject: MLBEDSW-5554: Constraints for single-axis mean operations on NPU - Combine two MEAN operator checks for single axis averages into one - Only apply that check if the single axis is the height dimension (previously checks were also applied to width averages) - Rephrase some MEAN operator constraint descriptions Signed-off-by: James Peet Change-Id: Ie0577f2b99aba1f3d6a4c39f8934eafe3813b736 --- ethosu/vela/tflite_supported_operators.py | 65 +++++++++++++------------------ 1 file changed, 27 insertions(+), 38 deletions(-) diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index 193a23ff..4d826770 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -209,8 +209,7 @@ class TFLiteSupportedOperators: self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product_avgpool) self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product) self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product_int8) - self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_depthwise_conv_height_single_axis) - self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_avgpool_height_single_axis) + self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_single_axis) # Reshape specific checks: self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant) @@ -637,7 +636,7 @@ class TFLiteSupportedOperators: @classmethod @docstring_format_args([mean_kernel_product_avgpool]) def constraint_mean_height_width_product_avgpool(cls, op): - """Product of height and width can be at most {}""" + """Product of height and width must be no greater than {}""" shape = op.inputs[0].shape hi = 0 if len(shape) < 4 else 1 h, w = shape[hi : hi + 2] @@ -647,8 +646,9 @@ class TFLiteSupportedOperators: @classmethod @docstring_format_args([mean_kernel_product]) def constraint_mean_height_width_product(cls, op): - """Product of height and width can be at most {} when IFM and OFM have different scale or zero point, - or keep_dims is True""" + """Product of height and width must be no greater than {} when: + IFM and OFM have different scale or zero point; or + 'keep_dims' is True""" ifmq, ofmq = op.ifm.quantization, op.ofm.quantization keep_dims = op.attrs.get("keep_dims") # doesn't apply, size is checked by constraint_mean_height_width_product_avgpool @@ -663,10 +663,11 @@ class TFLiteSupportedOperators: @classmethod @docstring_format_args([mean_kernel_product_int8]) def constraint_mean_height_width_product_int8(cls, op): - """Product of IFM height and width can be at most {} when the following are true: - IFM dimensions are 4, - Axis indices are 1 and 2, - keep_dims is set to True and + """Product of IFM height and width must be no greater than {} when: + The IFM shape has 4 dimensions; and + The axis indices specify reduction across 2 dimensions; and + The axis indices correspond to the width and height dimensions of the IFM; and + 'keep_dims' is True; and IFM datatype is int8""" shape = op.ifm.shape axis = int(op.inputs[1].values) if op.inputs[1].shape == [] else list(op.inputs[1].values) @@ -679,51 +680,39 @@ class TFLiteSupportedOperators: or axis not in ([1, 2], [2, 1]) ): return True, "" - hi = 0 if len(shape) < 4 else 1 - h, w = shape[hi : hi + 2] + h = shape[-3] + w = shape[-2] max_prod = cls.mean_kernel_product_int8 return h * w <= max_prod, f"Product of height and width is {h * w}" @classmethod - @docstring_format_args([dilated_height_range[1]]) - def constraint_depthwise_conv_height_single_axis(cls, op): - """Height can be at most {} for single axis when axis is 1.""" + @docstring_format_args([filter_height_range[1], dilated_height_range[1]]) + def constraint_mean_height_single_axis(cls, op): + """For single axis averages across the height dimension: + IFM height must be no greater than {} if the IFM and OFM scale and zero point match; otherwise + IFM height must be no greater than {} if the IFM and OFM scale or zero point do not match""" inp, axis = op.inputs if axis.shape == [] or axis.shape[0] == 1: # single axis axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0]) else: - # Multiple axes, constraint does not apply + # Multiple axes return True, "" - # Height and width axes have different index depending on dimensions shape = inp.shape - h = shape[0] if len(shape) < 4 else shape[1] - - # If quantization is the same across IFM and OFM op will become avgpool and this constraint does not apply. - ifm, ofm = op.get_ifm_ofm() - if check_quantized_tens_scaling_equal(ifm, ofm): + if len(shape) < 3: + # No height dimension present in IFM return True, "" - - return h <= 64 or axis != 1, f"Height is {h} and axis is {axis}." - - @classmethod - @docstring_format_args([filter_height_range[1]]) - def constraint_avgpool_height_single_axis(cls, op): - """Avgpool height can be at most {} for single axis when axis is 1.""" - inp, axis = op.inputs - if axis.shape == [] or axis.shape[0] == 1: # single axis - axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0]) - else: - # Multiple axes, constraint does not apply + if axis != len(shape) - 3: + # Not averaging across the height dimension return True, "" - # Height and width axes have different index depending on dimensions - shape = inp.shape - h = shape[0] if len(shape) < 4 else shape[1] + h = shape[axis] ifm, ofm = op.get_ifm_ofm() - scaling_equal = check_quantized_tens_scaling_equal(ifm, ofm) - return h <= 256 or axis != 1 or not scaling_equal, f"Height is {h} and axis is {axis}" + if check_quantized_tens_scaling_equal(ifm, ofm): + return h <= cls.filter_height_range[1], f"Height is {h}, IFM and OFM quantizations match" + else: + return h <= cls.dilated_height_range[1], f"Height is {h}, IFM and OFM quantizations do not match" @staticmethod def constraint_reshape_shape_constant(op): -- cgit v1.2.1