aboutsummaryrefslogtreecommitdiff
path: root/tests/test_backend_tosa_compat.py
blob: 0b6eaf52ed5992e2ea9a8ba477e23a852afb7ae0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for TOSA compatibility."""
from __future__ import annotations

from pathlib import Path
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock

import pytest

from mlia.backend.errors import BackendUnavailableError
from mlia.backend.tosa_checker.compat import get_tosa_compatibility_info
from mlia.backend.tosa_checker.compat import Operator
from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo


def replace_get_tosa_checker_with_mock(
    monkeypatch: pytest.MonkeyPatch, mock: MagicMock | None
) -> None:
    """Patch ``get_tosa_checker`` so that every call hands back *mock*.

    The factory looked up in ``mlia.backend.tosa_checker.compat`` is swapped
    for a stub returning *mock* regardless of its arguments. Passing ``None``
    makes the factory report that no checker could be obtained.
    """
    target = "mlia.backend.tosa_checker.compat.get_tosa_checker"
    stub_factory = MagicMock(return_value=mock)
    # Dotted-string form of monkeypatch.setattr: (import path, new value).
    monkeypatch.setattr(target, stub_factory)


def test_compatibility_check_should_fail_if_checker_not_available(
    monkeypatch: pytest.MonkeyPatch, test_tflite_model: str | Path
) -> None:
    """Test that compatibility check should fail if TOSA checker is not available."""
    # Make the checker factory return None, standing in for an environment
    # where no TOSA checker can be obtained.
    replace_get_tosa_checker_with_mock(monkeypatch, None)

    expected_message = "Backend tosa-checker is not available"
    with pytest.raises(BackendUnavailableError, match=expected_message):
        get_tosa_compatibility_info(test_tflite_model)


@pytest.mark.parametrize(
    "is_tosa_compatible, operators, exception, expected_result",
    [
        # Compatible model, checker reports no operators.
        (
            True,
            [],
            None,
            TOSACompatibilityInfo(True, [], None, None, None),
        ),
        # Compatible model with a single compatible operator.
        (
            True,
            [
                SimpleNamespace(
                    location="op_location",
                    name="op_name",
                    is_tosa_compatible=True,
                )
            ],
            None,
            TOSACompatibilityInfo(
                True, [Operator("op_location", "op_name", True)], None, [], []
            ),
        ),
        # The per-operator query raises; the error is expected in the result.
        (
            False,
            [],
            ValueError("error_test"),
            TOSACompatibilityInfo(False, [], ValueError("error_test"), [], []),
        ),
    ],
)
def test_get_tosa_compatibility_info(
    monkeypatch: pytest.MonkeyPatch,
    test_tflite_model: str | Path,
    is_tosa_compatible: bool,
    operators: Any,
    exception: Exception | None,
    expected_result: TOSACompatibilityInfo,
) -> None:
    """Test getting TOSA compatibility information from a mocked checker.

    The mocked checker answers ``is_tosa_compatible()`` with
    *is_tosa_compatible* and ``_get_tosa_compatibility_for_ops()`` with
    *operators*. When *exception* is given, that second call raises it
    instead.
    """
    # pylint: disable=protected-access
    fake_checker = MagicMock()
    fake_checker.is_tosa_compatible.return_value = is_tosa_compatible

    ops_query = fake_checker._get_tosa_compatibility_for_ops
    ops_query.return_value = operators
    if exception is not None:
        # An exception side effect is raised instead of returning return_value.
        ops_query.side_effect = exception

    replace_get_tosa_checker_with_mock(monkeypatch, fake_checker)

    actual = get_tosa_compatibility_info(test_tflite_model)

    # Exception instances compare by identity, so compare their reprs.
    assert repr(actual.exception) == repr(expected_result.exception)
    assert actual.tosa_compatible == expected_result.tosa_compatible
    assert actual.operators == expected_result.operators