aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test/test_tflite_model_semantic.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/test/test_tflite_model_semantic.py')
-rw-r--r--ethosu/vela/test/test_tflite_model_semantic.py34
1 files changed, 34 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py
index fd23d042..d4c92553 100644
--- a/ethosu/vela/test/test_tflite_model_semantic.py
+++ b/ethosu/vela/test/test_tflite_model_semantic.py
@@ -576,3 +576,37 @@ def test_matching_in_out_quant():
dim = create_const_tensor("expand_dims_dim", [], DataType.uint8, 0)
op = testutil.create_op(Op.ExpandDims, [ifm, dim], ofm, set_ifm_ofm_shapes=False)
assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_lstm_semantics():
+ # Test valid configurations
+ op = testutil.create_lstm_op(3, 12, 24, 20, DataType.int8)
+ assert semantic_checker.is_operator_semantic_valid(op)
+ assert semantic_checker.is_operator_semantic_valid(testutil.create_lstm_op(3, 12, 24, 20, DataType.int16))
+ # Test invalid datatype
+ assert not semantic_checker.is_operator_semantic_valid(testutil.create_lstm_op(3, 12, 24, 20, DataType.uint8))
+ # Test invalid shape
+ ifm_shape = op.ifm.shape
+ ofm_shape = op.ofm.shape
+ op.ifm.shape = [12, 24]
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op.ifm.shape = ifm_shape
+ op.ofm.shape = [12, 20]
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op.ofm.shape = ofm_shape
+ # Test invalid number of intermediates
+ intermediate = op.intermediates.pop()
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op.intermediates.append(intermediate)
+ op.intermediates.append(intermediate)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op.intermediates.pop()
+ # Test invalid number of inputs
+ input = op.inputs.pop()
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op.inputs.append(input)
+ op.inputs.append(input)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op.inputs.pop()
+ # Test restored valid configuration
+ assert semantic_checker.is_operator_semantic_valid(op)