From 12e481147de461e3ea63a8b1dcbc1b66b0fe8e6f Mon Sep 17 00:00:00 2001 From: Johan Alfven Date: Tue, 31 Jan 2023 10:26:26 +0100 Subject: MLBEDSW-7284: MLCE: Fix assert for faulty Split op - An assert in Vela is triggered when the number of splits does not evenly divide the input.shape[axis] value and the split offsets are calculated wrongly. - The fix is to add the same constraints as in the reference kernel and only run the Split op on the NPU when the criterias are fulfilled. - Modified test to reflect the new constraints - Updated SUPPORTED_OPS.md Change-Id: I4103ff4a3fdf9a813f5fcb7f51081b859e611100 Signed-off-by: Johan Alfven --- ethosu/vela/test/test_tflite_model_semantic.py | 34 ++++++++++++++++++++++++++ ethosu/vela/tflite_model_semantic.py | 29 +++++++++++++++++++++- 2 files changed, 62 insertions(+), 1 deletion(-) (limited to 'ethosu') diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py index 2e0936d0..e26a327f 100644 --- a/ethosu/vela/test/test_tflite_model_semantic.py +++ b/ethosu/vela/test/test_tflite_model_semantic.py @@ -191,6 +191,40 @@ def test_constraint_beta_value_range(): assert semantic_checker.is_operator_semantic_valid(op) +def test_constraint_split_axis(): + # Axis value must be in the range [-, )" + attrs = {"num_splits": 2} + axis = create_const_tensor("axis", [1], DataType.int8, [3]) + ifm = Tensor([1, 1, 4], DataType.int8, "ifm") + ifm.quantization = testutil.default_quant_params() + ofm = Tensor([1, 1, 4], DataType.int8, "ofm") + ofm.quantization = testutil.default_quant_params() + op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs) + # Check invalid axis value + assert not semantic_checker.is_operator_semantic_valid(op) + # Check valid axis value + axis = create_const_tensor("axis", [1], DataType.int8, [-1]) + op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs) + assert semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_split_num_splits(): + # Check that split number is valid" + attrs = {"num_splits": 2} + axis = create_const_tensor("axis", [1], DataType.int8, [-1]) + ifm = Tensor([1, 1, 3], DataType.int8, "ifm") + ifm.quantization = testutil.default_quant_params() + ofm = Tensor([1, 1, 3], DataType.int8, "ofm") + ofm.quantization = testutil.default_quant_params() + op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs) + # Check invalid split number 2 + assert not semantic_checker.is_operator_semantic_valid(op) + # Check valid split number 3 + attrs = {"num_splits": 3} + op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs) + assert semantic_checker.is_operator_semantic_valid(op) + + def test_constraint_splitv_inferred(): # SplitV requires a maximum of one inferred shape (-1) qp = testutil.default_quant_params() diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py index dc3b8185..2851ab16 100644 --- a/ethosu/vela/tflite_model_semantic.py +++ b/ethosu/vela/tflite_model_semantic.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates # # SPDX-License-Identifier: Apache-2.0 # @@ -155,6 +155,10 @@ class TFLiteSemantic: self.specific_constraints[Op.Softmax].append(TFLiteSemantic.constraint_matching_in_out_types) self.specific_constraints[Op.Softmax].append(TFLiteSemantic.constraint_beta_value_range) + # Split specific checks: + self.specific_constraints[Op.Split].append(TFLiteSemantic.constraint_split_axis) + self.specific_constraints[Op.Split].append(TFLiteSemantic.constraint_split_num_splits) + # SplitV specific checks: self.specific_constraints[Op.SplitV].append(TFLiteSemantic.constraint_splitv_inferred) @@ -395,6 +399,29 @@ class TFLiteSemantic: valid = ifm_shape == ofm_shape return valid, f"Op has ifm_shape={ifm_shape} and ofm_shape={ofm_shape}" + @staticmethod + def constraint_split_axis(op): + "Axis value must be in the range [-RANK(IFM) to +RANK(IFM))" + axis_tens = op.inputs[0] + input_tens = op.inputs[1] + dims = len(input_tens.shape) + axis = int(axis_tens.values) + axis += dims if axis < 0 else 0 + valid = 0 <= axis < dims + return valid, f"Op has ifm_dimensions={dims} and axis value is: {axis}" + + @staticmethod + def constraint_split_num_splits(op): + "Axis must be divisible by number of splits" + num_splits = op.attrs.get("num_splits") + axis_tens = op.inputs[0] + input_tens = op.inputs[1] + dims = len(input_tens.shape) + axis = int(axis_tens.values) + axis += dims if axis < 0 else 0 + valid = input_tens.shape[axis] % num_splits == 0 + return valid, f"Op has ifm shape={input_tens.shape} axis={axis} num_splits={num_splits}" + @staticmethod def constraint_splitv_inferred(op): "Only one size is allowed to be inferred" -- cgit v1.2.1