diff options
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r-- | ethosu/vela/tflite_supported_operators.py | 42 |
1 files changed, 12 insertions, 30 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index 92a7f3c3..597e0a2c 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -843,13 +843,20 @@ class TFLiteSupportedOperators: @classmethod @docstring_format_args([mean_kernel_product_int8, mean_kernel_product_uint8, mean_kernel_product_int16]) def constraint_mean_height_width_product(cls, op): - """Product of height and width must be no greater than: + """Product of reduced axes must be no greater than: - {} for signed 8-bit inputs - {} for unsigned 8-bit inputs - {} for signed 16-bit inputs""" shape = op.inputs[0].shape - hi = 0 if len(shape) < 4 else 1 - h, w = shape[hi : hi + 2] + if op.inputs[1].shape == []: + axis = [int(op.inputs[1].values)] + else: + axis = list(op.inputs[1].values) + + # compute the product of the shape of all reduced axes + axis_shapes = [shape[ax] for ax in axis] + prod = np.prod(axis_shapes) + if op.ifm.dtype == DataType.int16: max_prod = cls.mean_kernel_product_int16 datatype = "int16" @@ -859,43 +866,18 @@ class TFLiteSupportedOperators: else: max_prod = cls.mean_kernel_product_int8 datatype = "int8" - return h * w <= max_prod, f"Datatype is {datatype}, product of height and width is {h * w}" + return prod <= max_prod, f"Datatype is {datatype}, product of axes is {prod}" @classmethod @docstring_format_args([mean_width_size]) def constraint_mean_width(cls, op): - """Width must be no greater than {}""" + """If Width axis is reduced its shape must be no greater than {}.""" shape = op.inputs[0].shape hi = 0 if len(shape) < 4 else 1 h, w = shape[hi : hi + 2] max_width = cls.mean_width_size return w <= max_width, f"Width is {w}" - @classmethod - @docstring_format_args([dilated_height_range[1]]) - def constraint_mean_height_single_axis(cls, op): - """For single axis averages across the height dimension: - IFM height must be no greater than {}""" - inp, axis = op.inputs - if axis.shape == [] or axis.shape[0] == 1: # single axis - axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0]) - else: - # Multiple axes - return True, "" - - shape = inp.shape - if len(shape) < 3: - # No height dimension present in IFM - return True, "" - if axis != len(shape) - 3: - # Not averaging across the height dimension - return True, "" - - h = shape[axis] - ifm, ofm = op.get_ifm_ofm() - - return h <= cls.dilated_height_range[1], f"Height is {h}" - @staticmethod def constraint_reshape_shape_constant(op): "Shape must be constant" |