author     Fredrik Svedberg <fredrik.svedberg@arm.com>    2023-04-11 22:35:04 +0200
committer  Fredrik Svedberg <fredrik.svedberg@arm.com>    2023-04-17 14:16:44 +0200
commit     0ac0804e76e098695ee2b8a9e24e2f0a1efc324f (patch)
tree       9ccb766221987a415244079ed6c596a47d693b20 /ethosu/vela/test
parent     c1ad80b3a581dd39b39a112d6c2026f6560207a4 (diff)
MLBEDSW-7196 Add LSTM support
Added int8 and int16 UNIDIRECTIONAL_SEQUENCE_LSTM support.
The implementation does not include support for:
* CIFG
* Peephole
* Projection
* Normalisation

This change also:
* Removed the unused Op.BlockLSTM operation type.
* Removed the single-consumer limitation on placing the SplitSliceRead
  on the tensor consumer(s), provided all consumers fulfil the requirements.
* Added Op.VariableTensorWrite as an Operation.memory_function to ensure
  that writes to variable tensors:
  * Always use linear mode
  * Are not moved to fast scratch
  * Are not fused with other elementwise operation tensor ranges

Change-Id: Ief831738924ac3d1f2ba6d41f10bd6dc969911f3
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
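For context, a standalone sketch (not part of the patch) of how the new testutil.create_lstm_op helper and the checkers exercised by the tests below can be driven. The import paths and the no-argument constructors for TFLiteSemantic and TFLiteSupportedOperators are assumed from the existing test modules.

# Sketch only; paths/constructors assumed from the existing Vela test modules.
from ethosu.vela.data_type import DataType
from ethosu.vela.tflite_model_semantic import TFLiteSemantic
from ethosu.vela.tflite_supported_operators import TFLiteSupportedOperators
from ethosu.vela.test import testutil

semantic_checker = TFLiteSemantic()
support = TFLiteSupportedOperators()

# Build an int8 UNIDIRECTIONAL_SEQUENCE_LSTM op with 3 batches, 12 time steps,
# 24 input features and 20 output units, then run it through both checkers.
lstm_op = testutil.create_lstm_op(3, 12, 24, 20, DataType.int8)
assert semantic_checker.is_operator_semantic_valid(lstm_op)
assert support.is_operator_supported(lstm_op)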
Diffstat (limited to 'ethosu/vela/test')
-rw-r--r--  ethosu/vela/test/test_tflite_model_semantic.py      | 34
-rw-r--r--  ethosu/vela/test/test_tflite_supported_operators.py | 46
-rw-r--r--  ethosu/vela/test/testutil.py                         | 62
3 files changed, 141 insertions(+), 1 deletion(-)
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py
index fd23d042..d4c92553 100644
--- a/ethosu/vela/test/test_tflite_model_semantic.py
+++ b/ethosu/vela/test/test_tflite_model_semantic.py
@@ -576,3 +576,37 @@ def test_matching_in_out_quant():
     dim = create_const_tensor("expand_dims_dim", [], DataType.uint8, 0)
     op = testutil.create_op(Op.ExpandDims, [ifm, dim], ofm, set_ifm_ofm_shapes=False)
     assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_lstm_semantics():
+    # Test valid configurations
+    op = testutil.create_lstm_op(3, 12, 24, 20, DataType.int8)
+    assert semantic_checker.is_operator_semantic_valid(op)
+    assert semantic_checker.is_operator_semantic_valid(testutil.create_lstm_op(3, 12, 24, 20, DataType.int16))
+    # Test invalid datatype
+    assert not semantic_checker.is_operator_semantic_valid(testutil.create_lstm_op(3, 12, 24, 20, DataType.uint8))
+    # Test invalid shape
+    ifm_shape = op.ifm.shape
+    ofm_shape = op.ofm.shape
+    op.ifm.shape = [12, 24]
+    assert not semantic_checker.is_operator_semantic_valid(op)
+    op.ifm.shape = ifm_shape
+    op.ofm.shape = [12, 20]
+    assert not semantic_checker.is_operator_semantic_valid(op)
+    op.ofm.shape = ofm_shape
+    # Test invalid number of intermediates
+    intermediate = op.intermediates.pop()
+    assert not semantic_checker.is_operator_semantic_valid(op)
+    op.intermediates.append(intermediate)
+    op.intermediates.append(intermediate)
+    assert not semantic_checker.is_operator_semantic_valid(op)
+    op.intermediates.pop()
+    # Test invalid number of inputs
+    input = op.inputs.pop()
+    assert not semantic_checker.is_operator_semantic_valid(op)
+    op.inputs.append(input)
+    op.inputs.append(input)
+    assert not semantic_checker.is_operator_semantic_valid(op)
+    op.inputs.pop()
+    # Test restored valid configuration
+    assert semantic_checker.is_operator_semantic_valid(op)
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index 2713adf9..04f10e9a 100644
--- a/ethosu/vela/test/test_tflite_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -623,3 +623,49 @@ def test_mean_hw_product_avgpool():
     assert support.is_operator_supported(op)
     op = create_mean([1, 200, 200, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
     assert not support.is_operator_supported(op)
+
+
+def test_lstm_support():
+    # Test valid configuration
+    op = testutil.create_lstm_op(3, 12, 24, 20, DataType.int8)
+    assert support.is_operator_supported(op)
+    # Test CIFG not supported
+    input_to_input_weights, recurrent_to_input_weights = op.inputs[1], op.inputs[5]
+    op.inputs[1] = None
+    assert not support.is_operator_supported(op)
+    op.inputs[1] = input_to_input_weights
+    op.inputs[5] = None
+    assert not support.is_operator_supported(op)
+    op.inputs[5] = recurrent_to_input_weights
+    # Test Peephole not supported
+    op.inputs[9] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[9] = None
+    op.inputs[10] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[10] = None
+    op.inputs[11] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[11] = None
+    # Test Projection not supported
+    op.inputs[16] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[16] = None
+    op.inputs[17] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[17] = None
+    # Test Normalisation not supported
+    op.inputs[20] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[20] = None
+    op.inputs[21] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[21] = None
+    op.inputs[22] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[22] = None
+    op.inputs[23] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[23] = None
+    # Test restored valid configuration
+    assert support.is_operator_supported(op)
diff --git a/ethosu/vela/test/testutil.py b/ethosu/vela/test/testutil.py
index 88fc8747..e08bde24 100644
--- a/ethosu/vela/test/testutil.py
+++ b/ethosu/vela/test/testutil.py
@@ -103,7 +103,10 @@ def create_op_with_quant_tensors(
 def create_op(op_type, inputs, output, attrs=None, set_ifm_ofm_shapes=True):
     op = Operation(op_type, output.name + "_op")
     for input in inputs:
-        op.add_input_tensor(input)
+        if input:  # Add regular tensor input
+            op.add_input_tensor(input)
+        else:  # Add optional (None) inputs for operators with sparse input positioning
+            op.inputs.append(input)
     op.set_output_tensor(output)
     if attrs is not None:
         op.attrs = attrs
@@ -112,6 +115,63 @@ def create_op(op_type, inputs, output, attrs=None, set_ifm_ofm_shapes=True):
     return op


+def create_lstm_op(batches, times, features, outputs, datatype):
+    input_shape = [batches, times, features]
+    output_shape = [batches, times, outputs]
+    weight_shape = [features, outputs]
+    state_shape = [batches, outputs]
+    bias_shape = [outputs]
+    ifm = Tensor(input_shape, datatype, "in")
+    ifm.quantization = default_quant_params()
+    ofm = Tensor(output_shape, datatype, "out")
+    ofm.quantization = default_quant_params()
+    bias_dtype = DataType.int64 if datatype == DataType.int16 else DataType.int32
+    bias = create_const_tensor("bias", bias_shape, bias_dtype, [0] * outputs)
+    weight_q = default_quant_params()
+    weight = create_const_tensor("weight", weight_shape, DataType.int8, np.ones(weight_shape), quantization=weight_q)
+    output_state = Tensor(state_shape, datatype, "output_state")
+    output_state.quantization = default_quant_params()
+    output_state.is_variable = True
+    cell_state = Tensor(state_shape, DataType.int16, "cell_state")
+    cell_state.quantization = default_quant_params()
+    cell_state.is_variable = True
+    intermediate = Tensor([], DataType.float32, "intermediate")
+    hidden_scale_intermediate = Tensor([], datatype, "effective_hidden_scale_intermediate")
+    hidden_scale_intermediate.quantization = default_quant_params()
+    peephole = None
+    projection = None
+    normalisation = None
+    inputs = [
+        ifm,
+        weight,
+        weight,
+        weight,
+        weight,
+        weight,
+        weight,
+        weight,
+        weight,
+        peephole,
+        peephole,
+        peephole,
+        bias,
+        bias,
+        bias,
+        bias,
+        projection,
+        projection,
+        output_state,
+        cell_state,
+        normalisation,
+        normalisation,
+        normalisation,
+        normalisation,
+    ]
+    op = create_op(Op.UnidirectionalSequenceLstm, inputs, ofm)
+    op.intermediates = [intermediate, intermediate, intermediate, intermediate, hidden_scale_intermediate]
+    return op
+
+
 def create_subgraph(op_list):
     # Creates subgraph using the given list of operations
     sg = Subgraph()
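As a closing illustration (not part of the patch), a minimal sketch of the sparse-input handling added to create_op above: optional LSTM inputs passed as None keep their positional slot instead of being dropped. Import paths are assumed from the test modules.

# Sketch only; import paths assumed from the existing Vela test modules.
from ethosu.vela.data_type import DataType
from ethosu.vela.test import testutil

# create_lstm_op passes None for the unused peephole, projection and
# normalisation inputs; create_op now appends those None entries so every
# input keeps its TFLite positional index (24 inputs in total).
op = testutil.create_lstm_op(2, 4, 8, 8, DataType.int8)
assert len(op.inputs) == 24
assert op.inputs[9] is None  # peephole weights slot stays empty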