aboutsummaryrefslogtreecommitdiff
path: root/tests/aiet/test_check_model.py
blob: 4eafe59e7c06848073058ba9991fa6682fdf5044 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
# pylint: disable=redefined-outer-name,no-self-use
"""Module for testing check_model.py script."""
from pathlib import Path
from typing import Any

import pytest
from ethosu.vela.tflite.Model import Model
from ethosu.vela.tflite.OperatorCode import OperatorCode

from aiet.cli.common import InvalidTFLiteFileError
from aiet.cli.common import ModelOptimisedException
from aiet.resources.tools.vela.check_model import check_custom_codes_for_ethosu
from aiet.resources.tools.vela.check_model import check_model
from aiet.resources.tools.vela.check_model import get_custom_codes_from_operators
from aiet.resources.tools.vela.check_model import get_model_from_file
from aiet.resources.tools.vela.check_model import get_operators_from_model
from aiet.resources.tools.vela.check_model import is_vela_optimised


@pytest.fixture(scope="session")
def optimised_tflite_model(optimised_input_model_file: Path) -> Model:
    """Parse the Vela-optimised TFLite file into a Model, once per session.

    The path is supplied by the ``optimised_input_model_file`` fixture, which
    is defined outside this module (presumably in a conftest.py).
    """
    model = get_model_from_file(optimised_input_model_file)
    return model


@pytest.fixture(scope="session")
def non_optimised_tflite_model(
    non_optimised_input_model_file: Path,
) -> Model:
    """Return Model instance read from a TFLite file that is NOT Vela-optimised.

    Session-scoped so the file is parsed only once for every test using it.
    """
    return get_model_from_file(non_optimised_input_model_file)


class TestIsVelaOptimised:
    """Tests for the is_vela_optimised() predicate."""

    def test_return_true_when_input_is_optimised(
        self,
        optimised_tflite_model: Model,
    ) -> None:
        """A model that has already been through Vela is reported as optimised."""
        assert is_vela_optimised(optimised_tflite_model) is True

    def test_return_false_when_input_is_not_optimised(
        self,
        non_optimised_tflite_model: Model,
    ) -> None:
        """A plain TFLite model is reported as not optimised."""
        assert is_vela_optimised(non_optimised_tflite_model) is False


def test_get_operator_list_returns_correct_instances(
    optimised_tflite_model: Model,
) -> None:
    """Verify get_operators_from_model() yields OperatorCode instances."""
    # Materialise the result so the emptiness check below is meaningful even
    # if the function returns a lazy iterable rather than a list.
    operator_list = list(get_operators_from_model(optimised_tflite_model))

    # all() is vacuously True for an empty sequence, so first make sure the
    # model actually produced operators; otherwise the type check proves nothing.
    assert operator_list
    assert all(isinstance(operator, OperatorCode) for operator in operator_list)


class TestGetCustomCodesFromOperators:
    """Tests for get_custom_codes_from_operators()."""

    # Dotted path of the method patched in each test, so that operators report
    # a controlled custom code without needing a real model buffer.
    _CUSTOM_CODE_TARGET = "ethosu.vela.tflite.OperatorCode.OperatorCode.CustomCode"

    def test_returns_empty_list_when_input_operators_have_no_custom_codes(
        self, monkeypatch: Any
    ) -> None:
        """Operators whose CustomCode() is None contribute no entries."""
        monkeypatch.setattr(self._CUSTOM_CODE_TARGET, lambda _: None)
        operators = [OperatorCode() for _ in range(3)]

        assert get_custom_codes_from_operators(operators) == []

    def test_returns_custom_codes_when_input_operators_have_custom_codes(
        self, monkeypatch: Any
    ) -> None:
        """Each operator's CustomCode() bytes value appears in the result."""
        monkeypatch.setattr(self._CUSTOM_CODE_TARGET, lambda _: b"custom-code")
        operators = [OperatorCode() for _ in range(3)]

        assert get_custom_codes_from_operators(operators) == [b"custom-code"] * 3


@pytest.mark.parametrize(
    ("custom_codes", "expected_output"),
    (
        pytest.param([b"ethos-u", b"something else"], True),
        pytest.param([b"custom-code-1", b"custom-code-2"], False),
    ),
)
def test_check_list_for_ethosu(custom_codes: list, expected_output: bool) -> None:
    """Check that an 'ethos-u' entry among the custom codes is detected."""
    assert check_custom_codes_for_ethosu(custom_codes) is expected_output


class TestGetModelFromFile:
    """Tests for get_model_from_file()."""

    def test_error_raised_when_input_is_invalid_model_file(
        self,
        invalid_input_model_file: Path,
    ) -> None:
        """A file that is not a valid TFLite model raises InvalidTFLiteFileError."""
        with pytest.raises(InvalidTFLiteFileError):
            get_model_from_file(invalid_input_model_file)

    def test_model_instance_returned_when_input_is_valid_model_file(
        self,
        optimised_input_model_file: Path,
    ) -> None:
        """A valid TFLite file is parsed into a Model object."""
        assert isinstance(get_model_from_file(optimised_input_model_file), Model)


class TestCheckModel:
    """Test the check_model() function."""

    def test_check_model_with_non_optimised_input(
        self,
        non_optimised_input_model_file: Path,
    ) -> None:
        """Verify no error occurs for a valid, non-optimised input file."""
        check_model(non_optimised_input_model_file)

    def test_check_model_with_optimised_input(
        self,
        optimised_input_model_file: Path,
    ) -> None:
        """Verify that the right exception is raised with already optimised input."""
        with pytest.raises(ModelOptimisedException):
            check_model(optimised_input_model_file)

    def test_check_model_with_invalid_input(
        self,
        invalid_input_model_file: Path,
    ) -> None:
        """Verify that InvalidTFLiteFileError is raised with invalid input.

        Asserting on a bare ``Exception`` would also pass on unrelated bugs
        (TypeError, AttributeError, ...), so the expected error is pinned to the
        same type get_model_from_file() raises for this fixture.
        """
        with pytest.raises(InvalidTFLiteFileError):
            check_model(invalid_input_model_file)