diff options
author | Alexander Hansson <Alexander.Hansson@arm.com> | 2023-06-27 12:36:25 +0000 |
---|---|---|
committer | Johan Alfven <johan.alfven@arm.com> | 2023-07-11 09:10:59 +0000 |
commit | 1d5e859973ff18f3e4285f0ca04251ca246a182c (patch) | |
tree | e25de299de17ac46269c003585f32b87cedfd137 /ethosu/vela/tflite_model_semantic.py | |
parent | ca9cc420984eba39b85885bf0d2d7b48bb920da9 (diff) | |
download | ethos-u-vela-1d5e859973ff18f3e4285f0ca04251ca246a182c.tar.gz |
MLBEDSW-7652: Add mean support for batch and channel when shape is 1
- Add support for batch and depth channels when shape is 1
- Refactor reshaping in convert_mean_to_depthwise_conv
Signed-off-by: Alexander Hansson <Alexander.Hansson@arm.com>
Change-Id: If663395934ab58c76ba92b6ebaaf484a389ae699
Diffstat (limited to 'ethosu/vela/tflite_model_semantic.py')
-rw-r--r-- | ethosu/vela/tflite_model_semantic.py | 38 |
1 files changed, 30 insertions, 8 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py index 444c04ad..56dce14f 100644 --- a/ethosu/vela/tflite_model_semantic.py +++ b/ethosu/vela/tflite_model_semantic.py @@ -696,14 +696,36 @@ class TFLiteSemantic: @staticmethod def constraint_mean_axis(op): - "Axis indices must correspond to height and width axes" - dims = len(op.inputs[0].shape) - axis = int(op.inputs[1].values) if op.inputs[1].shape == [] else list(op.inputs[1].values) - if dims == 2 or dims == 3: - valid = axis in (0, 1, [0], [1], [0, 1], [1, 0]) - elif dims == 4: - valid = axis in (1, 2, [1], [2], [1, 2], [2, 1]) - return valid, f"Axis is {axis}" + """Requirements for axis parameter: + When IFM tensor is 2D: + - Reduction in both axes is supported. + When IFM tensor is 3D or 4D: + - Reduction in Batch axis is only supported if batch size is 1. + - Reduction in both Height and Width axes is supported. + - Reduction in Depth axis is only supported if depth is 1.""" + input_shape = op.inputs[0].shape + dims = len(input_shape) + if op.inputs[1].shape == []: + axis = [int(op.inputs[1].values)] + else: + axis = list(op.inputs[1].values) + valid = True + + for ax in axis: + if ax < 0 or ax >= dims: + return False, "Axis parameter is out of bounds. axis: {axis}, dims: {dims}. " + elif dims == 3: + # depth is only supported if size is 1 + if ax == 2 and input_shape[ax] != 1: + valid = False + break + else: # 4D + # batch and depth are only supported if sizes are 1 + if ax in [0, 3] and input_shape[ax] != 1: + valid = False + break + + return valid, f"Shape is {input_shape}, Axis is {axis}." @staticmethod def constraint_matching_in_out_quant(op): |