diff options
author | Johan Alfven <johan.alfven@arm.com> | 2023-01-31 10:26:26 +0100 |
---|---|---|
committer | Johan Alfven <johan.alfven@arm.com> | 2023-02-06 10:09:58 +0100 |
commit | 12e481147de461e3ea63a8b1dcbc1b66b0fe8e6f (patch) | |
tree | bbbe6eadb20249fb5974e8753d324b66814d8184 /ethosu/vela/test/test_tflite_model_semantic.py | |
parent | 4b7179936d659a6d4abd6b7659c2cc05c5a845fb (diff) | |
download | ethos-u-vela-12e481147de461e3ea63a8b1dcbc1b66b0fe8e6f.tar.gz |
MLBEDSW-7284: MLCE: Fix assert for faulty Split op
- An assert in Vela is triggered when the number of splits does
not evenly divide the input.shape[axis] value and the split offsets
are calculated wrongly.
- The fix is to add the same constraints as in the reference kernel
and only run the Split op on the NPU when the criterias are fulfilled.
- Modified test to reflect the new constraints
- Updated SUPPORTED_OPS.md
Change-Id: I4103ff4a3fdf9a813f5fcb7f51081b859e611100
Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Diffstat (limited to 'ethosu/vela/test/test_tflite_model_semantic.py')
-rw-r--r-- | ethosu/vela/test/test_tflite_model_semantic.py | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py index 2e0936d0..e26a327f 100644 --- a/ethosu/vela/test/test_tflite_model_semantic.py +++ b/ethosu/vela/test/test_tflite_model_semantic.py @@ -191,6 +191,40 @@ def test_constraint_beta_value_range(): assert semantic_checker.is_operator_semantic_valid(op) +def test_constraint_split_axis(): + # Axis value must be in the range [-<ifm_dimensions>, <ifm_dimensions>)" + attrs = {"num_splits": 2} + axis = create_const_tensor("axis", [1], DataType.int8, [3]) + ifm = Tensor([1, 1, 4], DataType.int8, "ifm") + ifm.quantization = testutil.default_quant_params() + ofm = Tensor([1, 1, 4], DataType.int8, "ofm") + ofm.quantization = testutil.default_quant_params() + op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs) + # Check invalid axis value + assert not semantic_checker.is_operator_semantic_valid(op) + # Check valid axis value + axis = create_const_tensor("axis", [1], DataType.int8, [-1]) + op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs) + assert semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_split_num_splits(): + # Check that split number is valid" + attrs = {"num_splits": 2} + axis = create_const_tensor("axis", [1], DataType.int8, [-1]) + ifm = Tensor([1, 1, 3], DataType.int8, "ifm") + ifm.quantization = testutil.default_quant_params() + ofm = Tensor([1, 1, 3], DataType.int8, "ofm") + ofm.quantization = testutil.default_quant_params() + op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs) + # Check invalid split number 2 + assert not semantic_checker.is_operator_semantic_valid(op) + # Check valid split number 3 + attrs = {"num_splits": 3} + op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs) + assert semantic_checker.is_operator_semantic_valid(op) + + def test_constraint_splitv_inferred(): # SplitV requires a maximum of one inferred shape (-1) qp = testutil.default_quant_params() |