aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test/test_tflite_reader.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/test/test_tflite_reader.py')
-rw-r--r--ethosu/vela/test/test_tflite_reader.py43
1 files changed, 43 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_tflite_reader.py b/ethosu/vela/test/test_tflite_reader.py
index 1ba07423..d63c0007 100644
--- a/ethosu/vela/test/test_tflite_reader.py
+++ b/ethosu/vela/test/test_tflite_reader.py
@@ -15,6 +15,9 @@
# limitations under the License.
# Description:
# Contains unit tests for tflite_reader
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
import pytest
from ethosu.vela.tflite_reader import TFLiteSubgraph
@@ -35,3 +38,43 @@ class TestTFLiteSubgraph:
def test_len1_array_to_scalar(self, test_input, expected):
output = TFLiteSubgraph.len1_array_to_scalar(test_input)
assert output == expected
+
+ parse_op_testdata = [
+ # op_type, opt_serializer, inputs, output, expected
+ ("FullyConnected", None, [0, 1, 2], 3, 3), # FC
+ ("FullyConnected", None, [0, 1, -1], 3, 3), # FC disabled Bias
+ ("FullyConnected", None, [0, 1], 3, 3), # FC no Bias
+ ("Conv2D", None, [2, 1, 3], 0, 3), # Conv2D
+ ("Conv2DBackprop", None, [0, 1, 2, 3], 4, 4), # TransposeConv
+ ("Conv2DBackprop", None, [0, 1, 2], 4, 4), # TransposeConv no Bias
+ pytest.param("Conv2D", None, [0, -1, 1], 3, 3, marks=pytest.mark.xfail), # Conv2D no Weights
+ ]
+
+ @pytest.mark.parametrize("op_type, opt_serializer, inputs, output, expected", parse_op_testdata)
+ def test_parse_operator(self, op_type, opt_serializer, inputs, output, expected):
+ with patch.object(TFLiteSubgraph, "__init__", lambda self, graph, subraph: None):
+ # Mock a TFLiteSubGraph
+ sg = TFLiteSubgraph(None, None)
+ sg.graph = MagicMock()
+ sg.graph.operator_codes = [(op_type, opt_serializer)]
+
+ # Mock a couple of tensors
+ sg.tensors = [MagicMock() for _ in range(5)]
+ for i, tens in enumerate(sg.tensors):
+ tens.name = "tensor_{}".format(i)
+ tens.ops = []
+
+ # Mock op data
+ op_data = MagicMock()
+ op_data.OpcodeIndex.return_value = 0
+ op_data.InputsAsNumpy.return_value = inputs
+ op_data.OutputsAsNumpy.return_value = [output]
+
+ sg.parse_operator(0, op_data)
+
+ # Verify the created Operation
+ created_op = sg.tensors[output].ops[0]
+ assert created_op.type == op_type
+ assert len(created_op.inputs) == expected
+ assert created_op.outputs[0].name == "tensor_{}".format(output)
+ assert inputs[-1] != -1 or not created_op.inputs[-1]