aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r--ethosu/vela/tflite_supported_operators.py36
1 files changed, 30 insertions, 6 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index f965d2ba..92a7f3c3 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -191,7 +191,10 @@ class TFLiteSupportedOperators:
filter_range = (1, 8)
filter_height_range = (1, 256)
filter_product_range = (1, 256 * 256)
- mean_kernel_product = 64 * 64
+ mean_width_size = 64 * 64
+ mean_kernel_product_int8 = 2 ** (24)
+ mean_kernel_product_uint8 = 2 ** (23)
+ mean_kernel_product_int16 = 2 ** (16)
def __init__(self):
# Setup the generic constraints. Note: the order matters
@@ -311,7 +314,7 @@ class TFLiteSupportedOperators:
# Mean specific checks:
self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product)
- self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_single_axis)
+ self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_width)
# Reshape specific checks:
self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant)
@@ -838,14 +841,35 @@ class TFLiteSupportedOperators:
return valid, f"Op has ifm_shape={ifm_shape} and ifm2_shape={ifm2_shape}"
@classmethod
- @docstring_format_args([mean_kernel_product])
+ @docstring_format_args([mean_kernel_product_int8, mean_kernel_product_uint8, mean_kernel_product_int16])
def constraint_mean_height_width_product(cls, op):
- """Product of height and width must be no greater than {}"""
+ """Product of height and width must be no greater than:
+ - {} for signed 8-bit inputs
+ - {} for unsigned 8-bit inputs
+ - {} for signed 16-bit inputs"""
shape = op.inputs[0].shape
hi = 0 if len(shape) < 4 else 1
h, w = shape[hi : hi + 2]
- max_prod = cls.mean_kernel_product
- return h * w <= max_prod, f"Product of height and width is {h * w}"
+ if op.ifm.dtype == DataType.int16:
+ max_prod = cls.mean_kernel_product_int16
+ datatype = "int16"
+ elif op.ifm.dtype == DataType.uint8:
+ max_prod = cls.mean_kernel_product_uint8
+ datatype = "uint8"
+ else:
+ max_prod = cls.mean_kernel_product_int8
+ datatype = "int8"
+ return h * w <= max_prod, f"Datatype is {datatype}, product of height and width is {h * w}"
+
+ @classmethod
+ @docstring_format_args([mean_width_size])
+ def constraint_mean_width(cls, op):
+ """Width must be no greater than {}"""
+ shape = op.inputs[0].shape
+ hi = 0 if len(shape) < 4 else 1
+ h, w = shape[hi : hi + 2]
+ max_width = cls.mean_width_size
+ return w <= max_width, f"Width is {w}"
@classmethod
@docstring_format_args([dilated_height_range[1]])