diff options
Diffstat (limited to 'ethosu/vela/test')
-rw-r--r-- | ethosu/vela/test/test_tflite_model_semantic.py | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py index 2e0936d0..e26a327f 100644 --- a/ethosu/vela/test/test_tflite_model_semantic.py +++ b/ethosu/vela/test/test_tflite_model_semantic.py @@ -191,6 +191,40 @@ def test_constraint_beta_value_range(): assert semantic_checker.is_operator_semantic_valid(op) +def test_constraint_split_axis(): + # Axis value must be in the range [-<ifm_dimensions>, <ifm_dimensions>)" + attrs = {"num_splits": 2} + axis = create_const_tensor("axis", [1], DataType.int8, [3]) + ifm = Tensor([1, 1, 4], DataType.int8, "ifm") + ifm.quantization = testutil.default_quant_params() + ofm = Tensor([1, 1, 4], DataType.int8, "ofm") + ofm.quantization = testutil.default_quant_params() + op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs) + # Check invalid axis value + assert not semantic_checker.is_operator_semantic_valid(op) + # Check valid axis value + axis = create_const_tensor("axis", [1], DataType.int8, [-1]) + op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs) + assert semantic_checker.is_operator_semantic_valid(op) + + +def test_constraint_split_num_splits(): + # Check that split number is valid" + attrs = {"num_splits": 2} + axis = create_const_tensor("axis", [1], DataType.int8, [-1]) + ifm = Tensor([1, 1, 3], DataType.int8, "ifm") + ifm.quantization = testutil.default_quant_params() + ofm = Tensor([1, 1, 3], DataType.int8, "ofm") + ofm.quantization = testutil.default_quant_params() + op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs) + # Check invalid split number 2 + assert not semantic_checker.is_operator_semantic_valid(op) + # Check valid split number 3 + attrs = {"num_splits": 3} + op = testutil.create_op(Op.Split, [axis, ifm], ofm, attrs) + assert semantic_checker.is_operator_semantic_valid(op) + + def test_constraint_splitv_inferred(): # SplitV requires a maximum of one inferred shape (-1) qp = testutil.default_quant_params() |