diff options
author | Alexander Hansson <Alexander.Hansson@arm.com> | 2023-06-30 15:41:13 +0000 |
---|---|---|
committer | Alexander Hansson <alexander.hansson@arm.com> | 2023-07-11 11:02:10 +0100 |
commit | da8741a14c3774d3161f59019d3003a2ee944400 (patch) | |
tree | cf4aa63fd073d0a3b0f99b4ea6716b4a8f1be10f /ethosu/vela/tflite_supported_operators.py | |
parent | 1d5e859973ff18f3e4285f0ca04251ca246a182c (diff) | |
download | ethos-u-vela-da8741a14c3774d3161f59019d3003a2ee944400.tar.gz |
MLBEDSW-7653: Extend Mean support for depth axis
If any of H,W axes have shape 1, the IFM can be reshaped to support
reduction over the depth axis.
Signed-off-by: Alexander Hansson <Alexander.Hansson@arm.com>
Change-Id: I432ff1c399b7cee4ca5f0a8f4461e9c0a936d804
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r-- | ethosu/vela/tflite_supported_operators.py | 33 |
1 files changed, 27 insertions, 6 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index 597e0a2c..7d544004 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -191,7 +191,7 @@ class TFLiteSupportedOperators: filter_range = (1, 8) filter_height_range = (1, 256) filter_product_range = (1, 256 * 256) - mean_width_size = 64 * 64 + mean_reduced_axis_max_size = 64 * 64 mean_kernel_product_int8 = 2 ** (24) mean_kernel_product_uint8 = 2 ** (23) mean_kernel_product_int16 = 2 ** (16) @@ -315,6 +315,7 @@ class TFLiteSupportedOperators: # Mean specific checks: self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product) self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_width) + self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_depth) # Reshape specific checks: self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant) @@ -844,9 +845,9 @@ class TFLiteSupportedOperators: @docstring_format_args([mean_kernel_product_int8, mean_kernel_product_uint8, mean_kernel_product_int16]) def constraint_mean_height_width_product(cls, op): """Product of reduced axes must be no greater than: - - {} for signed 8-bit inputs - - {} for unsigned 8-bit inputs - - {} for signed 16-bit inputs""" + - {} for signed 8-bit inputs. + - {} for unsigned 8-bit inputs. + - {} for signed 16-bit inputs.""" shape = op.inputs[0].shape if op.inputs[1].shape == []: axis = [int(op.inputs[1].values)] @@ -869,15 +870,35 @@ class TFLiteSupportedOperators: return prod <= max_prod, f"Datatype is {datatype}, product of axes is {prod}" @classmethod - @docstring_format_args([mean_width_size]) + @docstring_format_args([mean_reduced_axis_max_size]) def constraint_mean_width(cls, op): """If Width axis is reduced its shape must be no greater than {}.""" shape = op.inputs[0].shape hi = 0 if len(shape) < 4 else 1 h, w = shape[hi : hi + 2] - max_width = cls.mean_width_size + max_width = cls.mean_reduced_axis_max_size return w <= max_width, f"Width is {w}" + @classmethod + @docstring_format_args([mean_reduced_axis_max_size]) + def constraint_mean_depth(cls, op): + """If Depth axis is reduced its shape must be no greater than {}.""" + max_depth = cls.mean_reduced_axis_max_size + shape = op.inputs[0].shape + + if op.inputs[1].shape == []: + axis = [int(op.inputs[1].values)] + else: + axis = list(op.inputs[1].values) + + depth_idx = len(shape) - 1 + + supported = True + if depth_idx in axis and shape[-1] > max_depth: + supported = False + + return supported, f"Depth is {shape[-1]}, shape is {shape}, axis is {axis}" + @staticmethod def constraint_reshape_shape_constant(op): "Shape must be constant" |