diff options
author | Johan Alfven <johan.alfven@arm.com> | 2023-01-31 10:26:26 +0100 |
---|---|---|
committer | Johan Alfven <johan.alfven@arm.com> | 2023-02-06 10:09:58 +0100 |
commit | 12e481147de461e3ea63a8b1dcbc1b66b0fe8e6f (patch) | |
tree | bbbe6eadb20249fb5974e8753d324b66814d8184 | |
parent | 4b7179936d659a6d4abd6b7659c2cc05c5a845fb (diff) | |
download | ethos-u-vela-12e481147de461e3ea63a8b1dcbc1b66b0fe8e6f.tar.gz |
MLBEDSW-7284: MLCE: Fix assert for faulty Split op
- An assert in Vela is triggered when the number of splits does
not evenly divide the input.shape[axis] value and the split offsets
are calculated wrongly.
- The fix is to add the same constraints as in the reference kernel
and only run the Split op on the NPU when the criterias are fulfilled.
- Modified test to reflect the new constraints
- Updated SUPPORTED_OPS.md
Change-Id: I4103ff4a3fdf9a813f5fcb7f51081b859e611100
Signed-off-by: Johan Alfven <johan.alfven@arm.com>
-rw-r--r-- | SUPPORTED_OPS.md | 11 | ||||
-rw-r--r-- | ethosu/vela/test/test_tflite_model_semantic.py | 34 | ||||
-rw-r--r-- | ethosu/vela/tflite_model_semantic.py | 29 |
3 files changed, 71 insertions, 3 deletions
diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md index cee25e7b..43db4c5b 100644 --- a/SUPPORTED_OPS.md +++ b/SUPPORTED_OPS.md @@ -1,7 +1,7 @@ # Supported Ops This file was automatically generated by Vela using the `--supported-ops-report` parameter. -Vela version: `3.6.1.dev1+g30e5320.d20221207` +Vela version: `3.6.1.dev17+g859efbe.d20230203` This file complies with [**Gitiles Markdown syntax**](https://github.com/google/gitiles/blob/master/Documentation/markdown.md) @@ -47,7 +47,7 @@ Please check the supported operator list for your chosen runtime for further inf | SHAPE | [Generic](#tflite-generic-constraints) | | SLICE | [Generic](#tflite-generic-constraints) | | SOFTMAX | [Generic](#tflite-generic-constraints), [Specific](#tflite-softmax-constraints) | -| SPLIT | [Generic](#tflite-generic-constraints) | +| SPLIT | [Generic](#tflite-generic-constraints), [Specific](#tflite-split-constraints) | | SPLIT_V | [Generic](#tflite-generic-constraints), [Specific](#tflite-split_v-constraints) | | SQUEEZE | [Generic](#tflite-generic-constraints), [Specific](#tflite-squeeze-constraints) | | STRIDED_SLICE | [Generic](#tflite-generic-constraints), [Specific](#tflite-strided_slice-constraints) | @@ -286,6 +286,13 @@ This is a list of constraints that the SOFTMAX operator must satisfy in order to - IFM and OFM data types must match - Beta value needs to be positive +### TFLite SPLIT Constraints + +This is a list of constraints that the SPLIT operator must satisfy in order to be scheduled on the NPU. + +- Axis value must be in the range [-RANK(IFM) to +RANK(IFM)) +- Axis must be divisible by number of splits + ### TFLite SPLIT_V Constraints This is a list of constraints that the SPLIT_V operator must satisfy in order to be scheduled on the NPU. diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py index 2e0936d0..e26a327f 100644 --- a/ethosu/vela/test/test_tflite_model_semantic.py +++ b/ethosu/vela/test/test_tflite_model_semantic.py @@ -191,6 +191,40 @@ def test_constraint_beta_value_range(): assert semantic_checker.is_operator_semantic_valid(op) +def test_constraint_split_axis(): + # Axis value must be in the range [-<ifm_dimensions>, <ifm_dimensions>)" + attrs = {"num_splits": 2} + axis = create_const_tensor("axis", [1], DataType.int8, [3]) + ifm = Tensor([1, 1, 4], DataType.int8, "ifm") + ifm.quantization = testutil.default_quant_params() + ofm = Tensor([1, 1, 4], DataType.int8, "ofm") + ofm.quantization = testutil.default_quant_params() + op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs) + # Check invalid axis value + assert not semantic_checker.is_operator_semantic_valid(op) + # Check valid axis value + axis = create_const_tensor("axis", [1], DataType.int8, [-1]) + op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs) + assert semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_split_num_splits(): + # Check that split number is valid" + attrs = {"num_splits": 2} + axis = create_const_tensor("axis", [1], DataType.int8, [-1]) + ifm = Tensor([1, 1, 3], DataType.int8, "ifm") + ifm.quantization = testutil.default_quant_params() + ofm = Tensor([1, 1, 3], DataType.int8, "ofm") + ofm.quantization = testutil.default_quant_params() + op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs) + # Check invalid split number 2 + assert not semantic_checker.is_operator_semantic_valid(op) + # Check valid split number 3 + attrs = {"num_splits": 3} + op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs) + assert semantic_checker.is_operator_semantic_valid(op) + + def test_constraint_splitv_inferred(): # SplitV requires a maximum of one inferred shape (-1) qp = testutil.default_quant_params() diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py index dc3b8185..2851ab16 100644 --- a/ethosu/vela/tflite_model_semantic.py +++ b/ethosu/vela/tflite_model_semantic.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com> +# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com> # # SPDX-License-Identifier: Apache-2.0 # @@ -155,6 +155,10 @@ class TFLiteSemantic: self.specific_constraints[Op.Softmax].append(TFLiteSemantic.constraint_matching_in_out_types) self.specific_constraints[Op.Softmax].append(TFLiteSemantic.constraint_beta_value_range) + # Split specific checks: + self.specific_constraints[Op.Split].append(TFLiteSemantic.constraint_split_axis) + self.specific_constraints[Op.Split].append(TFLiteSemantic.constraint_split_num_splits) + # SplitV specific checks: self.specific_constraints[Op.SplitV].append(TFLiteSemantic.constraint_splitv_inferred) @@ -396,6 +400,29 @@ class TFLiteSemantic: return valid, f"Op has ifm_shape={ifm_shape} and ofm_shape={ofm_shape}" @staticmethod + def constraint_split_axis(op): + "Axis value must be in the range [-RANK(IFM) to +RANK(IFM))" + axis_tens = op.inputs[0] + input_tens = op.inputs[1] + dims = len(input_tens.shape) + axis = int(axis_tens.values) + axis += dims if axis < 0 else 0 + valid = 0 <= axis < dims + return valid, f"Op has ifm_dimensions={dims} and axis value is: {axis}" + + @staticmethod + def constraint_split_num_splits(op): + "Axis must be divisible by number of splits" + num_splits = op.attrs.get("num_splits") + axis_tens = op.inputs[0] + input_tens = op.inputs[1] + dims = len(input_tens.shape) + axis = int(axis_tens.values) + axis += dims if axis < 0 else 0 + valid = input_tens.shape[axis] % num_splits == 0 + return valid, f"Op has ifm shape={input_tens.shape} axis={axis} num_splits={num_splits}" + + @staticmethod def constraint_splitv_inferred(op): "Only one size is allowed to be inferred" sizes = op.inputs[1].values |