diff options
Diffstat (limited to 'ethosu/vela/test')
-rw-r--r-- | ethosu/vela/test/test_tflite_model_semantic.py | 34 | ||||
-rw-r--r-- | ethosu/vela/test/test_tflite_supported_operators.py | 46 | ||||
-rw-r--r-- | ethosu/vela/test/testutil.py | 62 |
3 files changed, 141 insertions, 1 deletions
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py index fd23d042..d4c92553 100644 --- a/ethosu/vela/test/test_tflite_model_semantic.py +++ b/ethosu/vela/test/test_tflite_model_semantic.py @@ -576,3 +576,37 @@ def test_matching_in_out_quant(): dim = create_const_tensor("expand_dims_dim", [], DataType.uint8, 0) op = testutil.create_op(Op.ExpandDims, [ifm, dim], ofm, set_ifm_ofm_shapes=False) assert not semantic_checker.is_operator_semantic_valid(op) + + +def test_lstm_semantics(): + # Test valid configurations + op = testutil.create_lstm_op(3, 12, 24, 20, DataType.int8) + assert semantic_checker.is_operator_semantic_valid(op) + assert semantic_checker.is_operator_semantic_valid(testutil.create_lstm_op(3, 12, 24, 20, DataType.int16)) + # Test invalid datatype + assert not semantic_checker.is_operator_semantic_valid(testutil.create_lstm_op(3, 12, 24, 20, DataType.uint8)) + # Test invalid shape + ifm_shape = op.ifm.shape + ofm_shape = op.ofm.shape + op.ifm.shape = [12, 24] + assert not semantic_checker.is_operator_semantic_valid(op) + op.ifm.shape = ifm_shape + op.ofm.shape = [12, 20] + assert not semantic_checker.is_operator_semantic_valid(op) + op.ofm.shape = ofm_shape + # Test invalid number of intermediates + intermediate = op.intermediates.pop() + assert not semantic_checker.is_operator_semantic_valid(op) + op.intermediates.append(intermediate) + op.intermediates.append(intermediate) + assert not semantic_checker.is_operator_semantic_valid(op) + op.intermediates.pop() + # Test invalid number of inputs + input = op.inputs.pop() + assert not semantic_checker.is_operator_semantic_valid(op) + op.inputs.append(input) + op.inputs.append(input) + assert not semantic_checker.is_operator_semantic_valid(op) + op.inputs.pop() + # Test restored valid configuration + assert semantic_checker.is_operator_semantic_valid(op) diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py index 2713adf9..04f10e9a 100644 --- a/ethosu/vela/test/test_tflite_supported_operators.py +++ b/ethosu/vela/test/test_tflite_supported_operators.py @@ -623,3 +623,49 @@ def test_mean_hw_product_avgpool(): assert support.is_operator_supported(op) op = create_mean([1, 200, 200, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True}) assert not support.is_operator_supported(op) + + +def test_lstm_support(): + # Test valid configuration + op = testutil.create_lstm_op(3, 12, 24, 20, DataType.int8) + assert support.is_operator_supported(op) + # Test CIFG not supported + input_to_input_weights, recurrent_to_input_weights = op.inputs[1], op.inputs[5] + op.inputs[1] = None + assert not support.is_operator_supported(op) + op.inputs[1] = input_to_input_weights + op.inputs[5] = None + assert not support.is_operator_supported(op) + op.inputs[5] = recurrent_to_input_weights + # Test Peephole not supported + op.inputs[9] = input_to_input_weights + assert not support.is_operator_supported(op) + op.inputs[9] = None + op.inputs[10] = input_to_input_weights + assert not support.is_operator_supported(op) + op.inputs[10] = None + op.inputs[11] = input_to_input_weights + assert not support.is_operator_supported(op) + op.inputs[11] = None + # Test Projection not supported + op.inputs[16] = input_to_input_weights + assert not support.is_operator_supported(op) + op.inputs[16] = None + op.inputs[17] = input_to_input_weights + assert not support.is_operator_supported(op) + op.inputs[17] = None + # Test Normalisation not supported + op.inputs[20] = input_to_input_weights + assert not support.is_operator_supported(op) + op.inputs[20] = None + op.inputs[21] = input_to_input_weights + assert not support.is_operator_supported(op) + op.inputs[21] = None + op.inputs[22] = input_to_input_weights + assert not support.is_operator_supported(op) + op.inputs[22] = None + op.inputs[23] = input_to_input_weights + assert not support.is_operator_supported(op) + op.inputs[23] = None + # Test restored valid configuration + assert support.is_operator_supported(op) diff --git a/ethosu/vela/test/testutil.py b/ethosu/vela/test/testutil.py index 88fc8747..e08bde24 100644 --- a/ethosu/vela/test/testutil.py +++ b/ethosu/vela/test/testutil.py @@ -103,7 +103,10 @@ def create_op_with_quant_tensors( def create_op(op_type, inputs, output, attrs=None, set_ifm_ofm_shapes=True): op = Operation(op_type, output.name + "_op") for input in inputs: - op.add_input_tensor(input) + if input: # Add regular tensor input + op.add_input_tensor(input) + else: # Add optional (None) inputs for operators with sparse input positioning + op.inputs.append(input) op.set_output_tensor(output) if attrs is not None: op.attrs = attrs @@ -112,6 +115,63 @@ def create_op(op_type, inputs, output, attrs=None, set_ifm_ofm_shapes=True): return op +def create_lstm_op(batches, times, features, outputs, datatype): + input_shape = [batches, times, features] + output_shape = [batches, times, outputs] + weight_shape = [features, outputs] + state_shape = [batches, outputs] + bias_shape = [outputs] + ifm = Tensor(input_shape, datatype, "in") + ifm.quantization = default_quant_params() + ofm = Tensor(output_shape, datatype, "out") + ofm.quantization = default_quant_params() + bias_dtype = DataType.int64 if datatype == DataType.int16 else DataType.int32 + bias = create_const_tensor("bias", bias_shape, bias_dtype, [0] * outputs) + weight_q = default_quant_params() + weight = create_const_tensor("weight", weight_shape, DataType.int8, np.ones(weight_shape), quantization=weight_q) + output_state = Tensor(state_shape, datatype, "output_state") + output_state.quantization = default_quant_params() + output_state.is_variable = True + cell_state = Tensor(state_shape, DataType.int16, "cell_state") + cell_state.quantization = default_quant_params() + cell_state.is_variable = True + intermediate = Tensor([], DataType.float32, "intermediate") + hidden_scale_intermediate = Tensor([], datatype, "effective_hidden_scale_intermediate") + hidden_scale_intermediate.quantization = default_quant_params() + peephole = None + projection = None + normalisation = None + inputs = [ + ifm, + weight, + weight, + weight, + weight, + weight, + weight, + weight, + weight, + peephole, + peephole, + peephole, + bias, + bias, + bias, + bias, + projection, + projection, + output_state, + cell_state, + normalisation, + normalisation, + normalisation, + normalisation, + ] + op = create_op(Op.UnidirectionalSequenceLstm, inputs, ofm) + op.intermediates = [intermediate, intermediate, intermediate, intermediate, hidden_scale_intermediate] + return op + + def create_subgraph(op_list): # Creates subgraph using the given list of operations sg = Subgraph() |