aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test/test_tflite_model_semantic.py
diff options
context:
space:
mode:
authorJohan Alfven <johan.alfven@arm.com>2023-01-31 10:26:26 +0100
committerJohan Alfven <johan.alfven@arm.com>2023-02-06 10:09:58 +0100
commit12e481147de461e3ea63a8b1dcbc1b66b0fe8e6f (patch)
treebbbe6eadb20249fb5974e8753d324b66814d8184 /ethosu/vela/test/test_tflite_model_semantic.py
parent4b7179936d659a6d4abd6b7659c2cc05c5a845fb (diff)
downloadethos-u-vela-12e481147de461e3ea63a8b1dcbc1b66b0fe8e6f.tar.gz
MLBEDSW-7284: MLCE: Fix assert for faulty Split op
- An assert in Vela is triggered when the number of splits does not evenly divide the input.shape[axis] value and the split offsets are calculated wrongly. - The fix is to add the same constraints as in the reference kernel and only run the Split op on the NPU when the criterias are fulfilled. - Modified test to reflect the new constraints - Updated SUPPORTED_OPS.md Change-Id: I4103ff4a3fdf9a813f5fcb7f51081b859e611100 Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Diffstat (limited to 'ethosu/vela/test/test_tflite_model_semantic.py')
-rw-r--r--ethosu/vela/test/test_tflite_model_semantic.py34
1 files changed, 34 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py
index 2e0936d0..e26a327f 100644
--- a/ethosu/vela/test/test_tflite_model_semantic.py
+++ b/ethosu/vela/test/test_tflite_model_semantic.py
@@ -191,6 +191,40 @@ def test_constraint_beta_value_range():
assert semantic_checker.is_operator_semantic_valid(op)
+def test_constraint_split_axis():
+ # Axis value must be in the range [-<ifm_dimensions>, <ifm_dimensions>)"
+ attrs = {"num_splits": 2}
+ axis = create_const_tensor("axis", [1], DataType.int8, [3])
+ ifm = Tensor([1, 1, 4], DataType.int8, "ifm")
+ ifm.quantization = testutil.default_quant_params()
+ ofm = Tensor([1, 1, 4], DataType.int8, "ofm")
+ ofm.quantization = testutil.default_quant_params()
+ op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs)
+ # Check invalid axis value
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ # Check valid axis value
+ axis = create_const_tensor("axis", [1], DataType.int8, [-1])
+ op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs)
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_split_num_splits():
+ # Check that split number is valid"
+ attrs = {"num_splits": 2}
+ axis = create_const_tensor("axis", [1], DataType.int8, [-1])
+ ifm = Tensor([1, 1, 3], DataType.int8, "ifm")
+ ifm.quantization = testutil.default_quant_params()
+ ofm = Tensor([1, 1, 3], DataType.int8, "ofm")
+ ofm.quantization = testutil.default_quant_params()
+ op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs)
+ # Check invalid split number 2
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ # Check valid split number 3
+ attrs = {"num_splits": 3}
+ op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs)
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+
def test_constraint_splitv_inferred():
# SplitV requires a maximum of one inferred shape (-1)
qp = testutil.default_quant_params()