diff options
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r-- | ethosu/vela/supported_operators.py | 23 |
1 files changed, 22 insertions, 1 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index 777e9c70..5bf2c459 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -122,6 +122,7 @@ class SupportedOperators: filter_product_range = (1, 256 * 256) mean_kernel_product = 64 * 64 mean_kernel_product_int8 = 16 * 16 + mean_kernel_product_avgpool = 256 * 256 # Supported consumers supported_pad_consumers = convolution_ops | depthwise_convolution_ops | pooling_ops @@ -272,6 +273,7 @@ class SupportedOperators: self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_input_8bit) self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_input_dims) self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_axis) + self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_height_width_product_avgpool) self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_height_width_product) self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_height_width_product_int8) @@ -1028,6 +1030,7 @@ class SupportedOperators: valid = len(op.ifm.shape) == len(op.ofm.shape) return valid, f"Op has ifm shape={op.ifm.shape} and ofm shape={op.ofm.shape}" + @staticmethod def constraint_mean_input_dims(op): "Input tensor must be at least 2D" dims = len(op.inputs[0].shape) @@ -1045,9 +1048,25 @@ class SupportedOperators: return valid, f"Axis is {axis}" @classmethod + @docstring_format_args([mean_kernel_product_avgpool]) + def constraint_mean_height_width_product_avgpool(cls, op): + """Product of height and width can be at most {}""" + shape = op.inputs[0].shape + hi = 0 if len(shape) < 4 else 1 + h, w = shape[hi : hi + 2] + max_prod = cls.mean_kernel_product_avgpool + return h * w <= max_prod, f"Product of height and width is {h * w}" + + @classmethod @docstring_format_args([mean_kernel_product]) def constraint_mean_height_width_product(cls, op): - "Product of height and width can be at most {}" + """Product of height and width can be at most {} when IFM and OFM have different scale or zero point, + or keep_dims is True""" + ifmq, ofmq = op.ifm.quantization, op.ofm.quantization + keep_dims = op.attrs.get("keep_dims") + # doesn't apply, size is checked by constraint_mean_height_width_product_avgpool + if not keep_dims and ifmq.scale_f32 == ofmq.scale_f32 and ifmq.zero_point == ofmq.zero_point: + return True, "" shape = op.inputs[0].shape hi = 0 if len(shape) < 4 else 1 h, w = shape[hi : hi + 2] @@ -1064,6 +1083,8 @@ class SupportedOperators: IFM datatype is int8""" shape = op.ifm.shape axis = op.inputs[1].values if op.inputs[1].shape == [] else list(op.inputs[1].values) + # doesn't apply, size is checked by constraint_mean_height_width_product_avgpool + # and constraint_mean_height_width_product if ( len(shape) != 4 or op.ifm.dtype != DataType.int8 |