aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJames Peet <james.peet@arm.com>2022-02-15 15:07:54 +0000
committerJames Peet <james.peet@arm.com>2022-02-15 15:16:07 +0000
commit0bb7ad1e8c1a17ae15905152a60c2160bf998fde (patch)
tree5b79474b2e7039d408398a1034560e15dbb794aa
parent1b9218e8ea1e2c2a1e01894ba8fe59cfc978cf55 (diff)
downloadethos-u-vela-0bb7ad1e8c1a17ae15905152a60c2160bf998fde.tar.gz
MLBEDSW-5554: Constraints for single-axis mean operations on NPU
- Combine two MEAN operator checks for single axis averages into one - Only apply that check if the single axis is the height dimension (previously checks were also applied to width averages) - Rephrase some MEAN operator constraint descriptions Signed-off-by: James Peet <james.peet@arm.com> Change-Id: Ie0577f2b99aba1f3d6a4c39f8934eafe3813b736
-rw-r--r--ethosu/vela/tflite_supported_operators.py65
1 files changed, 27 insertions, 38 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 193a23f..4d82677 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -209,8 +209,7 @@ class TFLiteSupportedOperators:
self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product_avgpool)
self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product)
self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product_int8)
- self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_depthwise_conv_height_single_axis)
- self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_avgpool_height_single_axis)
+ self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_single_axis)
# Reshape specific checks:
self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant)
@@ -637,7 +636,7 @@ class TFLiteSupportedOperators:
@classmethod
@docstring_format_args([mean_kernel_product_avgpool])
def constraint_mean_height_width_product_avgpool(cls, op):
- """Product of height and width can be at most {}"""
+ """Product of height and width must be no greater than {}"""
shape = op.inputs[0].shape
hi = 0 if len(shape) < 4 else 1
h, w = shape[hi : hi + 2]
@@ -647,8 +646,9 @@ class TFLiteSupportedOperators:
@classmethod
@docstring_format_args([mean_kernel_product])
def constraint_mean_height_width_product(cls, op):
- """Product of height and width can be at most {} when IFM and OFM have different scale or zero point,
- or keep_dims is True"""
+ """Product of height and width must be no greater than {} when:
+ IFM and OFM have different scale or zero point; or
+ 'keep_dims' is True"""
ifmq, ofmq = op.ifm.quantization, op.ofm.quantization
keep_dims = op.attrs.get("keep_dims")
# doesn't apply, size is checked by constraint_mean_height_width_product_avgpool
@@ -663,10 +663,11 @@ class TFLiteSupportedOperators:
@classmethod
@docstring_format_args([mean_kernel_product_int8])
def constraint_mean_height_width_product_int8(cls, op):
- """Product of IFM height and width can be at most {} when the following are true:
- IFM dimensions are 4,
- Axis indices are 1 and 2,
- keep_dims is set to True and
+ """Product of IFM height and width must be no greater than {} when:
+ The IFM shape has 4 dimensions; and
+ The axis indices specify reduction across 2 dimensions; and
+ The axis indices correspond to the width and height dimensions of the IFM; and
+ 'keep_dims' is True; and
IFM datatype is int8"""
shape = op.ifm.shape
axis = int(op.inputs[1].values) if op.inputs[1].shape == [] else list(op.inputs[1].values)
@@ -679,51 +680,39 @@ class TFLiteSupportedOperators:
or axis not in ([1, 2], [2, 1])
):
return True, ""
- hi = 0 if len(shape) < 4 else 1
- h, w = shape[hi : hi + 2]
+ h = shape[-3]
+ w = shape[-2]
max_prod = cls.mean_kernel_product_int8
return h * w <= max_prod, f"Product of height and width is {h * w}"
@classmethod
- @docstring_format_args([dilated_height_range[1]])
- def constraint_depthwise_conv_height_single_axis(cls, op):
- """Height can be at most {} for single axis when axis is 1."""
+ @docstring_format_args([filter_height_range[1], dilated_height_range[1]])
+ def constraint_mean_height_single_axis(cls, op):
+ """For single axis averages across the height dimension:
+ IFM height must be no greater than {} if the IFM and OFM scale and zero point match; otherwise
+ IFM height must be no greater than {} if the IFM and OFM scale or zero point do not match"""
inp, axis = op.inputs
if axis.shape == [] or axis.shape[0] == 1: # single axis
axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
else:
- # Multiple axes, constraint does not apply
+ # Multiple axes
return True, ""
- # Height and width axes have different index depending on dimensions
shape = inp.shape
- h = shape[0] if len(shape) < 4 else shape[1]
-
- # If quantization is the same across IFM and OFM op will become avgpool and this constraint does not apply.
- ifm, ofm = op.get_ifm_ofm()
- if check_quantized_tens_scaling_equal(ifm, ofm):
+ if len(shape) < 3:
+ # No height dimension present in IFM
return True, ""
-
- return h <= 64 or axis != 1, f"Height is {h} and axis is {axis}."
-
- @classmethod
- @docstring_format_args([filter_height_range[1]])
- def constraint_avgpool_height_single_axis(cls, op):
- """Avgpool height can be at most {} for single axis when axis is 1."""
- inp, axis = op.inputs
- if axis.shape == [] or axis.shape[0] == 1: # single axis
- axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
- else:
- # Multiple axes, constraint does not apply
+ if axis != len(shape) - 3:
+ # Not averaging across the height dimension
return True, ""
- # Height and width axes have different index depending on dimensions
- shape = inp.shape
- h = shape[0] if len(shape) < 4 else shape[1]
+ h = shape[axis]
ifm, ofm = op.get_ifm_ofm()
- scaling_equal = check_quantized_tens_scaling_equal(ifm, ofm)
- return h <= 256 or axis != 1 or not scaling_equal, f"Height is {h} and axis is {axis}"
+ if check_quantized_tens_scaling_equal(ifm, ofm):
+ return h <= cls.filter_height_range[1], f"Height is {h}, IFM and OFM quantizations match"
+ else:
+ return h <= cls.dilated_height_range[1], f"Height is {h}, IFM and OFM quantizations do not match"
@staticmethod
def constraint_reshape_shape_constant(op):