aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r--ethosu/vela/tflite_supported_operators.py43
1 files changed, 43 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 5c7fd517..60bc6fd0 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -209,6 +209,8 @@ class TFLiteSupportedOperators:
self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product_avgpool)
self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product)
self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product_int8)
+ self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_depthwise_conv_height_single_axis)
+ self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_avgpool_height_single_axis)
# Reshape specific checks:
self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant)
@@ -686,6 +688,47 @@ class TFLiteSupportedOperators:
max_prod = cls.mean_kernel_product_int8
return h * w <= max_prod, f"Product of height and width is {h * w}"
+ @classmethod
+ @docstring_format_args([dilated_height_range[1]])
+ def constraint_depthwise_conv_height_single_axis(cls, op):
+ """Height can be at most {} for single axis when axis is 1."""
+ inp, axis = op.inputs
+ if axis.shape == [] or axis.shape[0] == 1: # single axis
+ axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
+ else:
+ # Multiple axes, constraint does not apply
+ return True, ""
+
+ # Height and width axes have different index depending on dimensions
+ shape = inp.shape
+ h = shape[0] if len(shape) < 4 else shape[1]
+
+ # If quantization is the same across IFM and OFM op will become avgpool and this constraint does not apply.
+ ifm, ofm = op.get_ifm_ofm()
+ if check_quantized_tens_scaling_equal(ifm, ofm):
+ return True, ""
+
+ return h <= 64 or axis != 1, f"Height is {h} and axis is {axis}."
+
+ @classmethod
+ @docstring_format_args([filter_height_range[1]])
+ def constraint_avgpool_height_single_axis(cls, op):
+ """Avgpool height can be at most {} for single axis when axis is 1."""
+ inp, axis = op.inputs
+ if axis.shape == [] or axis.shape[0] == 1: # single axis
+ axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
+ else:
+ # Multiple axes, constraint does not apply
+ return True, ""
+
+ # Height and width axes have different index depending on dimensions
+ shape = inp.shape
+ h = shape[0] if len(shape) < 4 else shape[1]
+ ifm, ofm = op.get_ifm_ofm()
+ scaling_equal = check_quantized_tens_scaling_equal(ifm, ofm)
+
+ return h <= 256 or axis != 1 or not scaling_equal, f"Height is {h} and axis is {axis}"
+
@staticmethod
def constraint_reshape_shape_constant(op):
"Shape must be constant"