diff options
Diffstat (limited to 'ethosu/vela/test/test_tflite_model_semantic.py')
-rw-r--r-- | ethosu/vela/test/test_tflite_model_semantic.py | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py index fd23d042..d4c92553 100644 --- a/ethosu/vela/test/test_tflite_model_semantic.py +++ b/ethosu/vela/test/test_tflite_model_semantic.py @@ -576,3 +576,37 @@ def test_matching_in_out_quant(): dim = create_const_tensor("expand_dims_dim", [], DataType.uint8, 0) op = testutil.create_op(Op.ExpandDims, [ifm, dim], ofm, set_ifm_ofm_shapes=False) assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_lstm_semantics(): + # Test valid configurations + op = testutil.create_lstm_op(3, 12, 24, 20, DataType.int8) + assert semantic_checker.is_operator_semantic_valid(op) + assert semantic_checker.is_operator_semantic_valid(testutil.create_lstm_op(3, 12, 24, 20, DataType.int16)) + # Test invalid datatype + assert not semantic_checker.is_operator_semantic_valid(testutil.create_lstm_op(3, 12, 24, 20, DataType.uint8)) + # Test invalid shape + ifm_shape = op.ifm.shape + ofm_shape = op.ofm.shape + op.ifm.shape = [12, 24] + assert not semantic_checker.is_operator_semantic_valid(op) + op.ifm.shape = ifm_shape + op.ofm.shape = [12, 20] + assert not semantic_checker.is_operator_semantic_valid(op) + op.ofm.shape = ofm_shape + # Test invalid number of intermediates + intermediate = op.intermediates.pop() + assert not semantic_checker.is_operator_semantic_valid(op) + op.intermediates.append(intermediate) + op.intermediates.append(intermediate) + assert not semantic_checker.is_operator_semantic_valid(op) + op.intermediates.pop() + # Test invalid number of inputs + input = op.inputs.pop() + assert not semantic_checker.is_operator_semantic_valid(op) + op.inputs.append(input) + op.inputs.append(input) + assert not semantic_checker.is_operator_semantic_valid(op) + op.inputs.pop() + # Test restored valid configuration + assert semantic_checker.is_operator_semantic_valid(op) |