aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r--ethosu/vela/tflite_supported_operators.py42
1 files changed, 12 insertions, 30 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 92a7f3c3..597e0a2c 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -843,13 +843,20 @@ class TFLiteSupportedOperators:
@classmethod
@docstring_format_args([mean_kernel_product_int8, mean_kernel_product_uint8, mean_kernel_product_int16])
def constraint_mean_height_width_product(cls, op):
- """Product of height and width must be no greater than:
+ """Product of reduced axes must be no greater than:
- {} for signed 8-bit inputs
- {} for unsigned 8-bit inputs
- {} for signed 16-bit inputs"""
shape = op.inputs[0].shape
- hi = 0 if len(shape) < 4 else 1
- h, w = shape[hi : hi + 2]
+ if op.inputs[1].shape == []:
+ axis = [int(op.inputs[1].values)]
+ else:
+ axis = list(op.inputs[1].values)
+
+ # compute the product of the shape of all reduced axes
+ axis_shapes = [shape[ax] for ax in axis]
+ prod = np.prod(axis_shapes)
+
if op.ifm.dtype == DataType.int16:
max_prod = cls.mean_kernel_product_int16
datatype = "int16"
@@ -859,43 +866,18 @@ class TFLiteSupportedOperators:
else:
max_prod = cls.mean_kernel_product_int8
datatype = "int8"
- return h * w <= max_prod, f"Datatype is {datatype}, product of height and width is {h * w}"
+ return prod <= max_prod, f"Datatype is {datatype}, product of axes is {prod}"
@classmethod
@docstring_format_args([mean_width_size])
def constraint_mean_width(cls, op):
- """Width must be no greater than {}"""
+ """If Width axis is reduced its shape must be no greater than {}."""
shape = op.inputs[0].shape
hi = 0 if len(shape) < 4 else 1
h, w = shape[hi : hi + 2]
max_width = cls.mean_width_size
return w <= max_width, f"Width is {w}"
- @classmethod
- @docstring_format_args([dilated_height_range[1]])
- def constraint_mean_height_single_axis(cls, op):
- """For single axis averages across the height dimension:
- IFM height must be no greater than {}"""
- inp, axis = op.inputs
- if axis.shape == [] or axis.shape[0] == 1: # single axis
- axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
- else:
- # Multiple axes
- return True, ""
-
- shape = inp.shape
- if len(shape) < 3:
- # No height dimension present in IFM
- return True, ""
- if axis != len(shape) - 3:
- # Not averaging across the height dimension
- return True, ""
-
- h = shape[axis]
- ifm, ofm = op.get_ifm_ofm()
-
- return h <= cls.dilated_height_range[1], f"Height is {h}"
-
@staticmethod
def constraint_reshape_shape_constant(op):
"Shape must be constant"