aboutsummaryrefslogtreecommitdiff
path: root/tests/test_backend_registry.py
blob: 1f729b6d70e3e3e6f004762f0d3492a93e55a970 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the backend registry module."""
from __future__ import annotations

from functools import partial

import pytest

from mlia.backend.config import BackendType
from mlia.backend.config import System
from mlia.backend.registry import get_supported_backends
from mlia.backend.registry import get_supported_systems
from mlia.backend.registry import registry
from mlia.core.common import AdviceCategory


@pytest.mark.parametrize(
    ("backend", "advices", "systems", "type_"),
    (
        (
            "armnn-tflite-delegate",
            [AdviceCategory.COMPATIBILITY],
            None,
            BackendType.BUILTIN,
        ),
        (
            "corstone-300",
            [
                AdviceCategory.COMPATIBILITY,
                AdviceCategory.PERFORMANCE,
                AdviceCategory.OPTIMIZATION,
            ],
            [System.LINUX_AMD64, System.LINUX_AARCH64],
            BackendType.CUSTOM,
        ),
        (
            "corstone-310",
            [
                AdviceCategory.COMPATIBILITY,
                AdviceCategory.PERFORMANCE,
                AdviceCategory.OPTIMIZATION,
            ],
            [System.LINUX_AMD64, System.LINUX_AARCH64],
            BackendType.CUSTOM,
        ),
        (
            "tosa-checker",
            [AdviceCategory.COMPATIBILITY],
            [System.LINUX_AMD64],
            BackendType.WHEEL,
        ),
        (
            "vela",
            [
                AdviceCategory.COMPATIBILITY,
                AdviceCategory.PERFORMANCE,
                AdviceCategory.OPTIMIZATION,
            ],
            [
                System.LINUX_AMD64,
                System.LINUX_AARCH64,
                System.WINDOWS_AMD64,
                System.WINDOWS_AARCH64,
            ],
            BackendType.BUILTIN,
        ),
    ),
)
def test_backend_registry(
    backend: str,
    advices: list[AdviceCategory],
    systems: list[System] | None,
    type_: BackendType,
) -> None:
    """Test the backend registry.

    Check that ``backend`` is registered and that its configuration reports
    the expected supported advice categories, supported systems and backend
    type.

    Args:
        backend: Registry key of the backend under test.
        advices: Expected ``supported_advice`` entries (order-insensitive).
        systems: Expected ``supported_systems`` entries (order-insensitive),
            or ``None`` when the config's ``supported_systems`` is expected
            to be ``None``.
        type_: Expected ``BackendType`` of the backend config.
    """
    # Compare lists order-insensitively by sorting both sides by member name.
    sorted_by_name = partial(sorted, key=lambda x: x.name)

    assert backend in registry.items
    cfg = registry.items[backend]
    assert sorted_by_name(advices) == sorted_by_name(
        cfg.supported_advice
    ), f"Advices differs: {advices} != {cfg.supported_advice}"
    if systems is None:
        assert cfg.supported_systems is None
    else:
        assert cfg.supported_systems is not None
        # Report the system lists (not the advice lists) on mismatch.
        assert sorted_by_name(systems) == sorted_by_name(
            cfg.supported_systems
        ), f"Supported systems differs: {systems} != {cfg.supported_systems}"
    assert cfg.type == type_


def test_get_supported_backends() -> None:
    """Check that get_supported_backends returns exactly these names, in this order."""
    actual_backends = get_supported_backends()
    expected_backends = [
        "armnn-tflite-delegate",
        "corstone-300",
        "corstone-310",
        "tosa-checker",
        "vela",
    ]
    # List equality is order-sensitive, so this pins the ordering as well.
    assert actual_backends == expected_backends


def test_get_supported_systems() -> None:
    """Check that get_supported_systems returns a truthy (non-empty) result."""
    supported = get_supported_systems()
    assert supported