diff options
Diffstat (limited to 'ethosu/vela/test/test_tflite_reader.py')
-rw-r--r-- | ethosu/vela/test/test_tflite_reader.py | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_tflite_reader.py b/ethosu/vela/test/test_tflite_reader.py index 1ba07423..d63c0007 100644 --- a/ethosu/vela/test/test_tflite_reader.py +++ b/ethosu/vela/test/test_tflite_reader.py @@ -15,6 +15,9 @@ # limitations under the License. # Description: # Contains unit tests for tflite_reader +from unittest.mock import MagicMock +from unittest.mock import patch + import pytest from ethosu.vela.tflite_reader import TFLiteSubgraph @@ -35,3 +38,43 @@ class TestTFLiteSubgraph: def test_len1_array_to_scalar(self, test_input, expected): output = TFLiteSubgraph.len1_array_to_scalar(test_input) assert output == expected + + parse_op_testdata = [ + # op_type, opt_serializer, inputs, output, expected + ("FullyConnected", None, [0, 1, 2], 3, 3), # FC + ("FullyConnected", None, [0, 1, -1], 3, 3), # FC disabled Bias + ("FullyConnected", None, [0, 1], 3, 3), # FC no Bias + ("Conv2D", None, [2, 1, 3], 0, 3), # Conv2D + ("Conv2DBackprop", None, [0, 1, 2, 3], 4, 4), # TransposeConv + ("Conv2DBackprop", None, [0, 1, 2], 4, 4), # TransposeConv no Bias + pytest.param("Conv2D", None, [0, -1, 1], 3, 3, marks=pytest.mark.xfail), # Conv2D no Weights + ] + + @pytest.mark.parametrize("op_type, opt_serializer, inputs, output, expected", parse_op_testdata) + def test_parse_operator(self, op_type, opt_serializer, inputs, output, expected): + with patch.object(TFLiteSubgraph, "__init__", lambda self, graph, subraph: None): + # Mock a TFLiteSubGraph + sg = TFLiteSubgraph(None, None) + sg.graph = MagicMock() + sg.graph.operator_codes = [(op_type, opt_serializer)] + + # Mock a couple of tensors + sg.tensors = [MagicMock() for _ in range(5)] + for i, tens in enumerate(sg.tensors): + tens.name = "tensor_{}".format(i) + tens.ops = [] + + # Mock op data + op_data = MagicMock() + op_data.OpcodeIndex.return_value = 0 + op_data.InputsAsNumpy.return_value = inputs + op_data.OutputsAsNumpy.return_value = [output] + + sg.parse_operator(0, op_data) + + # Verify the created Operation + created_op = sg.tensors[output].ops[0] + assert created_op.type == op_type + assert len(created_op.inputs) == expected + assert created_op.outputs[0].name == "tensor_{}".format(output) + assert inputs[-1] != -1 or not created_op.inputs[-1] |