diff options
Diffstat (limited to 'ethosu/vela/test/test_tflite_supported_operators.py')
-rw-r--r-- | ethosu/vela/test/test_tflite_supported_operators.py | 46 |
1 files changed, 46 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py index 2713adf9..04f10e9a 100644 --- a/ethosu/vela/test/test_tflite_supported_operators.py +++ b/ethosu/vela/test/test_tflite_supported_operators.py @@ -623,3 +623,49 @@ def test_mean_hw_product_avgpool(): assert support.is_operator_supported(op) op = create_mean([1, 200, 200, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True}) assert not support.is_operator_supported(op) + + +def test_lstm_support(): + # Test valid configuration + op = testutil.create_lstm_op(3, 12, 24, 20, DataType.int8) + assert support.is_operator_supported(op) + # Test CIFG not supported + input_to_input_weights, recurrent_to_input_weights = op.inputs[1], op.inputs[5] + op.inputs[1] = None + assert not support.is_operator_supported(op) + op.inputs[1] = input_to_input_weights + op.inputs[5] = None + assert not support.is_operator_supported(op) + op.inputs[5] = recurrent_to_input_weights + # Test Peephole not supported + op.inputs[9] = input_to_input_weights + assert not support.is_operator_supported(op) + op.inputs[9] = None + op.inputs[10] = input_to_input_weights + assert not support.is_operator_supported(op) + op.inputs[10] = None + op.inputs[11] = input_to_input_weights + assert not support.is_operator_supported(op) + op.inputs[11] = None + # Test Projection not supported + op.inputs[16] = input_to_input_weights + assert not support.is_operator_supported(op) + op.inputs[16] = None + op.inputs[17] = input_to_input_weights + assert not support.is_operator_supported(op) + op.inputs[17] = None + # Test Normalisation not supported + op.inputs[20] = input_to_input_weights + assert not support.is_operator_supported(op) + op.inputs[20] = None + op.inputs[21] = input_to_input_weights + assert not support.is_operator_supported(op) + op.inputs[21] = None + op.inputs[22] = input_to_input_weights + assert not support.is_operator_supported(op) + op.inputs[22] = None + op.inputs[23] = input_to_input_weights + assert not support.is_operator_supported(op) + op.inputs[23] = None + # Test restored valid configuration + assert support.is_operator_supported(op) |