aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test/test_tflite_model_semantic.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/test/test_tflite_model_semantic.py')
-rw-r--r--ethosu/vela/test/test_tflite_model_semantic.py34
1 files changed, 34 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py
index 2e0936d0..e26a327f 100644
--- a/ethosu/vela/test/test_tflite_model_semantic.py
+++ b/ethosu/vela/test/test_tflite_model_semantic.py
@@ -191,6 +191,40 @@ def test_constraint_beta_value_range():
assert semantic_checker.is_operator_semantic_valid(op)
+def test_constraint_split_axis():
+ # Axis value must be in the range [-<ifm_dimensions>, <ifm_dimensions>)"
+ attrs = {"num_splits": 2}
+ axis = create_const_tensor("axis", [1], DataType.int8, [3])
+ ifm = Tensor([1, 1, 4], DataType.int8, "ifm")
+ ifm.quantization = testutil.default_quant_params()
+ ofm = Tensor([1, 1, 4], DataType.int8, "ofm")
+ ofm.quantization = testutil.default_quant_params()
+ op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs)
+ # Check invalid axis value
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ # Check valid axis value
+ axis = create_const_tensor("axis", [1], DataType.int8, [-1])
+ op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs)
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_split_num_splits():
+ # Check that split number is valid"
+ attrs = {"num_splits": 2}
+ axis = create_const_tensor("axis", [1], DataType.int8, [-1])
+ ifm = Tensor([1, 1, 3], DataType.int8, "ifm")
+ ifm.quantization = testutil.default_quant_params()
+ ofm = Tensor([1, 1, 3], DataType.int8, "ofm")
+ ofm.quantization = testutil.default_quant_params()
+ op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs)
+ # Check invalid split number 2
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ # Check valid split number 3
+ attrs = {"num_splits": 3}
+ op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs)
+ assert semantic_checker.is_operator_semantic_valid(op)
+
+
def test_constraint_splitv_inferred():
# SplitV requires a maximum of one inferred shape (-1)
qp = testutil.default_quant_params()