aboutsummaryrefslogtreecommitdiff
path: root/ethosu
diff options
context:
space:
mode:
authorJohan Alfven <johan.alfven@arm.com>2023-01-31 10:26:26 +0100
committerJohan Alfven <johan.alfven@arm.com>2023-02-06 10:09:58 +0100
commit12e481147de461e3ea63a8b1dcbc1b66b0fe8e6f (patch)
treebbbe6eadb20249fb5974e8753d324b66814d8184 /ethosu
parent4b7179936d659a6d4abd6b7659c2cc05c5a845fb (diff)
downloadethos-u-vela-12e481147de461e3ea63a8b1dcbc1b66b0fe8e6f.tar.gz
MLBEDSW-7284: MLCE: Fix assert for faulty Split op
- An assert in Vela is triggered when the number of splits does not evenly divide the input.shape[axis] value and the split offsets are calculated wrongly. - The fix is to add the same constraints as in the reference kernel and only run the Split op on the NPU when the criterias are fulfilled. - Modified test to reflect the new constraints - Updated SUPPORTED_OPS.md Change-Id: I4103ff4a3fdf9a813f5fcb7f51081b859e611100 Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Diffstat (limited to 'ethosu')
-rw-r--r--ethosu/vela/test/test_tflite_model_semantic.py34
-rw-r--r--ethosu/vela/tflite_model_semantic.py29
2 files changed, 62 insertions, 1 deletions
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py
index 2e0936d0..e26a327f 100644
--- a/ethosu/vela/test/test_tflite_model_semantic.py
+++ b/ethosu/vela/test/test_tflite_model_semantic.py
@@ -191,6 +191,40 @@ def test_constraint_beta_value_range():
assert semantic_checker.is_operator_semantic_valid(op)
+def test_constraint_split_axis():
+ # Axis value must be in the range [-<ifm_dimensions>, <ifm_dimensions>)"
+ attrs = {"num_splits": 2}
+ axis = create_const_tensor("axis", [1], DataType.int8, [3])
+ ifm = Tensor([1, 1, 4], DataType.int8, "ifm")
+ ifm.quantization = testutil.default_quant_params()
+ ofm = Tensor([1, 1, 4], DataType.int8, "ofm")
+ ofm.quantization = testutil.default_quant_params()
+ op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs)
+ # Check invalid axis value
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ # Check valid axis value
+ axis = create_const_tensor("axis", [1], DataType.int8, [-1])
+ op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs)
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_split_num_splits():
+ # Check that split number is valid"
+ attrs = {"num_splits": 2}
+ axis = create_const_tensor("axis", [1], DataType.int8, [-1])
+ ifm = Tensor([1, 1, 3], DataType.int8, "ifm")
+ ifm.quantization = testutil.default_quant_params()
+ ofm = Tensor([1, 1, 3], DataType.int8, "ofm")
+ ofm.quantization = testutil.default_quant_params()
+ op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs)
+ # Check invalid split number 2
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ # Check valid split number 3
+ attrs = {"num_splits": 3}
+ op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs)
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+
def test_constraint_splitv_inferred():
# SplitV requires a maximum of one inferred shape (-1)
qp = testutil.default_quant_params()
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index dc3b8185..2851ab16 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -155,6 +155,10 @@ class TFLiteSemantic:
self.specific_constraints[Op.Softmax].append(TFLiteSemantic.constraint_matching_in_out_types)
self.specific_constraints[Op.Softmax].append(TFLiteSemantic.constraint_beta_value_range)
+ # Split specific checks:
+ self.specific_constraints[Op.Split].append(TFLiteSemantic.constraint_split_axis)
+ self.specific_constraints[Op.Split].append(TFLiteSemantic.constraint_split_num_splits)
+
# SplitV specific checks:
self.specific_constraints[Op.SplitV].append(TFLiteSemantic.constraint_splitv_inferred)
@@ -396,6 +400,29 @@ class TFLiteSemantic:
return valid, f"Op has ifm_shape={ifm_shape} and ofm_shape={ofm_shape}"
@staticmethod
+ def constraint_split_axis(op):
+ "Axis value must be in the range [-RANK(IFM) to +RANK(IFM))"
+ axis_tens = op.inputs[0]
+ input_tens = op.inputs[1]
+ dims = len(input_tens.shape)
+ axis = int(axis_tens.values)
+ axis += dims if axis < 0 else 0
+ valid = 0 <= axis < dims
+ return valid, f"Op has ifm_dimensions={dims} and axis value is: {axis}"
+
+ @staticmethod
+ def constraint_split_num_splits(op):
+ "Axis must be divisible by number of splits"
+ num_splits = op.attrs.get("num_splits")
+ axis_tens = op.inputs[0]
+ input_tens = op.inputs[1]
+ dims = len(input_tens.shape)
+ axis = int(axis_tens.values)
+ axis += dims if axis < 0 else 0
+ valid = input_tens.shape[axis] % num_splits == 0
+ return valid, f"Op has ifm shape={input_tens.shape} axis={axis} num_splits={num_splits}"
+
+ @staticmethod
def constraint_splitv_inferred(op):
"Only one size is allowed to be inferred"
sizes = op.inputs[1].values