diff options
author | Fredrik Svedberg <fredrik.svedberg@arm.com> | 2023-04-11 22:35:04 +0200 |
---|---|---|
committer | Fredrik Svedberg <fredrik.svedberg@arm.com> | 2023-04-17 14:16:44 +0200 |
commit | 0ac0804e76e098695ee2b8a9e24e2f0a1efc324f (patch) | |
tree | 9ccb766221987a415244079ed6c596a47d693b20 /ethosu/vela/test/test_tflite_model_semantic.py | |
parent | c1ad80b3a581dd39b39a112d6c2026f6560207a4 (diff) | |
download | ethos-u-vela-0ac0804e76e098695ee2b8a9e24e2f0a1efc324f.tar.gz |
MLBEDSW-7196 Add LSTM support
Added int8 and int16 UNIDIRECTIONAL_SEQUENCE_LSTM support.
The implementation does not include support for:
* CIFG
* Peephole
* Projection
* Normalisation
This change also:
* Removed unused Op.BlockLSTM operation type.
* Removed the only one consumer limitation on putting the SplitSliceRead
on the tensor consumer(s), if all consumers fullfills the requirements
* Added Op.VariableTensorWrite as a Operation.memory_function to make
sure writes to variable tensors:
* Always use linear mode
* Are not moved to fast scratch
* Are not fused with other elementwise operation tensor ranges
Change-Id: Ief831738924ac3d1f2ba6d41f10bd6dc969911f3
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Diffstat (limited to 'ethosu/vela/test/test_tflite_model_semantic.py')
-rw-r--r-- | ethosu/vela/test/test_tflite_model_semantic.py | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py index fd23d042..d4c92553 100644 --- a/ethosu/vela/test/test_tflite_model_semantic.py +++ b/ethosu/vela/test/test_tflite_model_semantic.py @@ -576,3 +576,37 @@ def test_matching_in_out_quant(): dim = create_const_tensor("expand_dims_dim", [], DataType.uint8, 0) op = testutil.create_op(Op.ExpandDims, [ifm, dim], ofm, set_ifm_ofm_shapes=False) assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_lstm_semantics(): + # Test valid configurations + op = testutil.create_lstm_op(3, 12, 24, 20, DataType.int8) + assert semantic_checker.is_operator_semantic_valid(op) + assert semantic_checker.is_operator_semantic_valid(testutil.create_lstm_op(3, 12, 24, 20, DataType.int16)) + # Test invalid datatype + assert not semantic_checker.is_operator_semantic_valid(testutil.create_lstm_op(3, 12, 24, 20, DataType.uint8)) + # Test invalid shape + ifm_shape = op.ifm.shape + ofm_shape = op.ofm.shape + op.ifm.shape = [12, 24] + assert not semantic_checker.is_operator_semantic_valid(op) + op.ifm.shape = ifm_shape + op.ofm.shape = [12, 20] + assert not semantic_checker.is_operator_semantic_valid(op) + op.ofm.shape = ofm_shape + # Test invalid number of intermediates + intermediate = op.intermediates.pop() + assert not semantic_checker.is_operator_semantic_valid(op) + op.intermediates.append(intermediate) + op.intermediates.append(intermediate) + assert not semantic_checker.is_operator_semantic_valid(op) + op.intermediates.pop() + # Test invalid number of inputs + input = op.inputs.pop() + assert not semantic_checker.is_operator_semantic_valid(op) + op.inputs.append(input) + op.inputs.append(input) + assert not semantic_checker.is_operator_semantic_valid(op) + op.inputs.pop() + # Test restored valid configuration + assert semantic_checker.is_operator_semantic_valid(op) |