1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
|
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the tflite_graph module."""
import json
from pathlib import Path
from mlia.nn.tensorflow.tflite_graph import Op
from mlia.nn.tensorflow.tflite_graph import parse_subgraphs
from mlia.nn.tensorflow.tflite_graph import TensorInfo
from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION
from mlia.nn.tensorflow.tflite_graph import TFL_OP
from mlia.nn.tensorflow.tflite_graph import TFL_TYPE
def test_tensor_info() -> None:
    """Check TensorInfo construction via kwargs, from_dict() and its repr."""
    # Direct construction: the instance attributes mirror the kwargs.
    direct_kwargs = {
        "name": "Test",
        "type": TFL_TYPE.INT8.name,
        "shape": (1,),
        "is_variable": False,
    }
    assert vars(TensorInfo(**direct_kwargs)) == direct_kwargs

    # from_dict(): the name is supplied as a list of character codes and the
    # type as a TFL_TYPE member; the resulting attributes hold the decoded
    # string name and the type's enum name.
    want = {
        "name": "Test2",
        "type": TFL_TYPE.FLOAT32.name,
        "shape": [2, 3],
        "is_variable": True,
    }
    raw = dict(want)
    raw["name"] = list(map(ord, want["name"]))
    raw["type"] = TFL_TYPE[want["type"]]
    parsed = TensorInfo.from_dict(raw)
    assert vars(parsed) == want

    # repr() must be JSON that decodes back to the same attributes,
    # and str() must produce a non-empty description.
    assert json.loads(repr(parsed)) == vars(parsed)
    assert str(parsed)
def test_op() -> None:
    """Check Op attribute storage and fused-activation name translation."""
    params = {
        "type": TFL_OP.CONV_2D.name,
        "builtin_options": {},
        "inputs": [],
        "outputs": [],
        "custom_type": None,
    }

    # With empty and with arbitrary builtin options, the instance
    # attributes mirror the constructor arguments exactly.
    for options in ({}, {"some_random_option": 3.14}):
        params["builtin_options"] = options
        assert vars(Op(**params)) == params

    # A fused activation given by its enum value is exposed by its enum name.
    relu = TFL_ACTIVATION_FUNCTION.RELU
    params["builtin_options"] = {"fused_activation_function": relu.value}
    conv_op = Op(**params)
    assert conv_op.builtin_options
    assert conv_op.builtin_options["fused_activation_function"] == relu.name

    # Both string forms must be non-empty.
    assert str(conv_op)
    assert repr(conv_op)
def test_parse_subgraphs(test_tflite_model: Path) -> None:
    """Test function 'parse_subgraphs'.

    The ``test_tflite_model`` fixture is defined elsewhere (presumably in a
    conftest.py). This test expects it to parse into exactly one subgraph of
    five operators. Each operator must have a type name known to ``TFL_OP``
    and at least one input and one output.
    """
    model = parse_subgraphs(test_tflite_model)
    assert len(model) == 1

    subgraph = model[0]
    assert len(subgraph) == 5

    for oper in subgraph:
        # Check the type name against the enum's member names directly.
        # Indexing (TFL_OP[oper.type]) would raise KeyError on an unknown
        # name instead of failing the assertion.
        assert (
            oper.type in TFL_OP.__members__
        ), f"unknown operator type: {oper.type!r}"
        assert len(oper.inputs) > 0
        assert len(oper.outputs) > 0
|