aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test/test_tflite_supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/test/test_tflite_supported_operators.py')
-rw-r--r--ethosu/vela/test/test_tflite_supported_operators.py46
1 files changed, 46 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index 2713adf9..04f10e9a 100644
--- a/ethosu/vela/test/test_tflite_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -623,3 +623,49 @@ def test_mean_hw_product_avgpool():
assert support.is_operator_supported(op)
op = create_mean([1, 200, 200, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
assert not support.is_operator_supported(op)
+
+
+def test_lstm_support():
+ # Test valid configuration
+ op = testutil.create_lstm_op(3, 12, 24, 20, DataType.int8)
+ assert support.is_operator_supported(op)
+ # Test CIFG not supported
+ input_to_input_weights, recurrent_to_input_weights = op.inputs[1], op.inputs[5]
+ op.inputs[1] = None
+ assert not support.is_operator_supported(op)
+ op.inputs[1] = input_to_input_weights
+ op.inputs[5] = None
+ assert not support.is_operator_supported(op)
+ op.inputs[5] = recurrent_to_input_weights
+ # Test Peephole not supported
+ op.inputs[9] = input_to_input_weights
+ assert not support.is_operator_supported(op)
+ op.inputs[9] = None
+ op.inputs[10] = input_to_input_weights
+ assert not support.is_operator_supported(op)
+ op.inputs[10] = None
+ op.inputs[11] = input_to_input_weights
+ assert not support.is_operator_supported(op)
+ op.inputs[11] = None
+ # Test Projection not supported
+ op.inputs[16] = input_to_input_weights
+ assert not support.is_operator_supported(op)
+ op.inputs[16] = None
+ op.inputs[17] = input_to_input_weights
+ assert not support.is_operator_supported(op)
+ op.inputs[17] = None
+ # Test Normalisation not supported
+ op.inputs[20] = input_to_input_weights
+ assert not support.is_operator_supported(op)
+ op.inputs[20] = None
+ op.inputs[21] = input_to_input_weights
+ assert not support.is_operator_supported(op)
+ op.inputs[21] = None
+ op.inputs[22] = input_to_input_weights
+ assert not support.is_operator_supported(op)
+ op.inputs[22] = None
+ op.inputs[23] = input_to_input_weights
+ assert not support.is_operator_supported(op)
+ op.inputs[23] = None
+ # Test restored valid configuration
+ assert support.is_operator_supported(op)