diff options
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r-- | ethosu/vela/tflite_supported_operators.py | 33 |
1 files changed, 27 insertions, 6 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index 597e0a2c..7d544004 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -191,7 +191,7 @@ class TFLiteSupportedOperators: filter_range = (1, 8) filter_height_range = (1, 256) filter_product_range = (1, 256 * 256) - mean_width_size = 64 * 64 + mean_reduced_axis_max_size = 64 * 64 mean_kernel_product_int8 = 2 ** (24) mean_kernel_product_uint8 = 2 ** (23) mean_kernel_product_int16 = 2 ** (16) @@ -315,6 +315,7 @@ class TFLiteSupportedOperators: # Mean specific checks: self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product) self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_width) + self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_depth) # Reshape specific checks: self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant) @@ -844,9 +845,9 @@ class TFLiteSupportedOperators: @docstring_format_args([mean_kernel_product_int8, mean_kernel_product_uint8, mean_kernel_product_int16]) def constraint_mean_height_width_product(cls, op): """Product of reduced axes must be no greater than: - - {} for signed 8-bit inputs - - {} for unsigned 8-bit inputs - - {} for signed 16-bit inputs""" + - {} for signed 8-bit inputs. + - {} for unsigned 8-bit inputs. + - {} for signed 16-bit inputs.""" shape = op.inputs[0].shape if op.inputs[1].shape == []: axis = [int(op.inputs[1].values)] @@ -869,15 +870,35 @@ class TFLiteSupportedOperators: return prod <= max_prod, f"Datatype is {datatype}, product of axes is {prod}" @classmethod - @docstring_format_args([mean_width_size]) + @docstring_format_args([mean_reduced_axis_max_size]) def constraint_mean_width(cls, op): """If Width axis is reduced its shape must be no greater than {}.""" shape = op.inputs[0].shape hi = 0 if len(shape) < 4 else 1 h, w = shape[hi : hi + 2] - max_width = cls.mean_width_size + max_width = cls.mean_reduced_axis_max_size return w <= max_width, f"Width is {w}" + @classmethod + @docstring_format_args([mean_reduced_axis_max_size]) + def constraint_mean_depth(cls, op): + """If Depth axis is reduced its shape must be no greater than {}.""" + max_depth = cls.mean_reduced_axis_max_size + shape = op.inputs[0].shape + + if op.inputs[1].shape == []: + axis = [int(op.inputs[1].values)] + else: + axis = list(op.inputs[1].values) + + depth_idx = len(shape) - 1 + + supported = True + if depth_idx in axis and shape[-1] > max_depth: + supported = False + + return supported, f"Depth is {shape[-1]}, shape is {shape}, axis is {axis}" + @staticmethod def constraint_reshape_shape_constant(op): "Shape must be constant" |