diff options
Diffstat (limited to 'ethosu/vela/tflite_model_semantic.py')
-rw-r--r-- | ethosu/vela/tflite_model_semantic.py | 38 |
1 files changed, 30 insertions, 8 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py index 444c04ad..56dce14f 100644 --- a/ethosu/vela/tflite_model_semantic.py +++ b/ethosu/vela/tflite_model_semantic.py @@ -696,14 +696,36 @@ class TFLiteSemantic: @staticmethod def constraint_mean_axis(op): - "Axis indices must correspond to height and width axes" - dims = len(op.inputs[0].shape) - axis = int(op.inputs[1].values) if op.inputs[1].shape == [] else list(op.inputs[1].values) - if dims == 2 or dims == 3: - valid = axis in (0, 1, [0], [1], [0, 1], [1, 0]) - elif dims == 4: - valid = axis in (1, 2, [1], [2], [1, 2], [2, 1]) - return valid, f"Axis is {axis}" + """Requirements for axis parameter: + When IFM tensor is 2D: + - Reduction in both axes is supported. + When IFM tensor is 3D or 4D: + - Reduction in Batch axis is only supported if batch size is 1. + - Reduction in both Height and Width axes is supported. + - Reduction in Depth axis is only supported if depth is 1.""" + input_shape = op.inputs[0].shape + dims = len(input_shape) + if op.inputs[1].shape == []: + axis = [int(op.inputs[1].values)] + else: + axis = list(op.inputs[1].values) + valid = True + + for ax in axis: + if ax < 0 or ax >= dims: + return False, "Axis parameter is out of bounds. axis: {axis}, dims: {dims}. " + elif dims == 3: + # depth is only supported if size is 1 + if ax == 2 and input_shape[ax] != 1: + valid = False + break + else: # 4D + # batch and depth are only supported if sizes are 1 + if ax in [0, 3] and input_shape[ax] != 1: + valid = False + break + + return valid, f"Shape is {input_shape}, Axis is {axis}." @staticmethod def constraint_matching_in_out_quant(op): |