From 67e0d8f24fcb86115e834acd19dc57027b03ea4f Mon Sep 17 00:00:00 2001 From: Jacob Bohlin Date: Thu, 20 Aug 2020 10:53:02 +0200 Subject: MLBEDSW-2663: Handle optional tensors Includes a number of changes: * Handle non-existing optional inputs * Handle disabled optional inputs (-1 indexed) * Added unit tests for parsing operators * Add bias tensor to the different Convolutions + FullyConnected if it's missing. Signed-off-by: Jacob Bohlin Change-Id: Ib88d2b610314b1c886fc0aef4f9da87430ce6ae5 --- ethosu/vela/test/test_tflite_reader.py | 43 ++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) (limited to 'ethosu/vela/test/test_tflite_reader.py') diff --git a/ethosu/vela/test/test_tflite_reader.py b/ethosu/vela/test/test_tflite_reader.py index 1ba07423..d63c0007 100644 --- a/ethosu/vela/test/test_tflite_reader.py +++ b/ethosu/vela/test/test_tflite_reader.py @@ -15,6 +15,9 @@ # limitations under the License. # Description: # Contains unit tests for tflite_reader +from unittest.mock import MagicMock +from unittest.mock import patch + import pytest from ethosu.vela.tflite_reader import TFLiteSubgraph @@ -35,3 +38,43 @@ class TestTFLiteSubgraph: def test_len1_array_to_scalar(self, test_input, expected): output = TFLiteSubgraph.len1_array_to_scalar(test_input) assert output == expected + + parse_op_testdata = [ + # op_type, opt_serializer, inputs, output, expected + ("FullyConnected", None, [0, 1, 2], 3, 3), # FC + ("FullyConnected", None, [0, 1, -1], 3, 3), # FC disabled Bias + ("FullyConnected", None, [0, 1], 3, 3), # FC no Bias + ("Conv2D", None, [2, 1, 3], 0, 3), # Conv2D + ("Conv2DBackprop", None, [0, 1, 2, 3], 4, 4), # TransposeConv + ("Conv2DBackprop", None, [0, 1, 2], 4, 4), # TransposeConv no Bias + pytest.param("Conv2D", None, [0, -1, 1], 3, 3, marks=pytest.mark.xfail), # Conv2D no Weights + ] + + @pytest.mark.parametrize("op_type, opt_serializer, inputs, output, expected", parse_op_testdata) + def test_parse_operator(self, op_type, opt_serializer, inputs, output, expected): + with patch.object(TFLiteSubgraph, "__init__", lambda self, graph, subraph: None): + # Mock a TFLiteSubGraph + sg = TFLiteSubgraph(None, None) + sg.graph = MagicMock() + sg.graph.operator_codes = [(op_type, opt_serializer)] + + # Mock a couple of tensors + sg.tensors = [MagicMock() for _ in range(5)] + for i, tens in enumerate(sg.tensors): + tens.name = "tensor_{}".format(i) + tens.ops = [] + + # Mock op data + op_data = MagicMock() + op_data.OpcodeIndex.return_value = 0 + op_data.InputsAsNumpy.return_value = inputs + op_data.OutputsAsNumpy.return_value = [output] + + sg.parse_operator(0, op_data) + + # Verify the created Operation + created_op = sg.tensors[output].ops[0] + assert created_op.type == op_type + assert len(created_op.inputs) == expected + assert created_op.outputs[0].name == "tensor_{}".format(output) + assert inputs[-1] != -1 or not created_op.inputs[-1] -- cgit v1.2.1