aboutsummaryrefslogtreecommitdiff
path: root/tests/test_target_registry.py
blob: ca1ad8234e625ff5a9ed3861419aa997de905935 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the target registry module."""
from __future__ import annotations

import pytest

from mlia.core.common import AdviceCategory
from mlia.target.config import get_builtin_profile_path
from mlia.target.registry import all_supported_backends
from mlia.target.registry import default_backends
from mlia.target.registry import is_supported
from mlia.target.registry import profile
from mlia.target.registry import registry
from mlia.target.registry import supported_advice
from mlia.target.registry import supported_backends
from mlia.target.registry import supported_targets


@pytest.mark.parametrize(
    "expected_target", ("cortex-a", "ethos-u55", "ethos-u65", "tosa")
)
def test_target_registry(expected_target: str) -> None:
    """Test the target registry."""
    assert expected_target in registry.items, (
        f"Expected target '{expected_target}' not contained in registered "
        f"targets '{registry.items.keys()}'."
    )


@pytest.mark.parametrize(
    ("target_name", "expected_advices"),
    (
        ("cortex-a", [AdviceCategory.COMPATIBILITY]),
        (
            "ethos-u55",
            [
                AdviceCategory.COMPATIBILITY,
                AdviceCategory.OPTIMIZATION,
                AdviceCategory.PERFORMANCE,
            ],
        ),
        (
            "ethos-u65",
            [
                AdviceCategory.COMPATIBILITY,
                AdviceCategory.OPTIMIZATION,
                AdviceCategory.PERFORMANCE,
            ],
        ),
        ("tosa", [AdviceCategory.COMPATIBILITY]),
    ),
)
def test_supported_advice(
    target_name: str, expected_advices: list[AdviceCategory]
) -> None:
    """Test function supported_advice()."""
    supported = supported_advice(target_name)
    assert all(advice in expected_advices for advice in supported)
    assert all(advice in supported for advice in expected_advices)


@pytest.mark.parametrize(
    ("backend", "target", "expected_result"),
    (
        ("armnn-tflite-delegate", None, True),
        ("armnn-tflite-delegate", "cortex-a", True),
        ("armnn-tflite-delegate", "tosa", False),
        ("corstone-310", None, True),
        ("corstone-310", "ethos-u55", True),
        ("corstone-310", "ethos-u65", True),
        ("corstone-310", "cortex-a", False),
    ),
)
def test_is_supported(backend: str, target: str | None, expected_result: bool) -> None:
    """Test function is_supported()."""
    assert is_supported(backend, target) == expected_result


@pytest.mark.parametrize(
    ("target_name", "expected_backends"),
    (
        ("cortex-a", ["armnn-tflite-delegate"]),
        ("ethos-u55", ["corstone-300", "corstone-310", "vela"]),
        ("ethos-u65", ["corstone-300", "corstone-310", "vela"]),
        ("tosa", ["tosa-checker"]),
    ),
)
def test_supported_backends(target_name: str, expected_backends: list[str]) -> None:
    """Test function supported_backends()."""
    assert sorted(expected_backends) == sorted(supported_backends(target_name))


@pytest.mark.parametrize(
    ("advice", "expected_targets"),
    (
        (AdviceCategory.COMPATIBILITY, ["cortex-a", "ethos-u55", "ethos-u65", "tosa"]),
        (AdviceCategory.OPTIMIZATION, ["ethos-u55", "ethos-u65"]),
        (AdviceCategory.PERFORMANCE, ["ethos-u55", "ethos-u65"]),
    ),
)
def test_supported_targets(advice: AdviceCategory, expected_targets: list[str]) -> None:
    """Test function supported_targets()."""
    assert sorted(expected_targets) == sorted(supported_targets(advice))


def test_all_supported_backends() -> None:
    """Test function all_supported_backends."""
    assert all_supported_backends() == {
        "vela",
        "tosa-checker",
        "armnn-tflite-delegate",
        "corstone-310",
        "corstone-300",
    }


@pytest.mark.parametrize(
    ("target", "expected_default_backends", "is_subset_only"),
    [
        ["cortex-a", ["armnn-tflite-delegate"], False],
        ["tosa", ["tosa-checker"], False],
        ["ethos-u55", ["vela"], True],
        ["ethos-u65", ["vela"], True],
    ],
)
def test_default_backends(
    target: str,
    expected_default_backends: list[str],
    is_subset_only: bool,
) -> None:
    """Test function default_backends()."""
    if is_subset_only:
        assert set(expected_default_backends).issubset(default_backends(target))
    else:
        assert default_backends(target) == expected_default_backends


@pytest.mark.parametrize(
    "target_profile", ("cortex-a", "tosa", "ethos-u55-128", "ethos-u65-256")
)
def test_profile(target_profile: str) -> None:
    """Test function profile()."""
    # Test loading by built-in profile name
    cfg = profile(target_profile)
    assert target_profile.startswith(cfg.target)

    # Test loading the file directly
    profile_file = get_builtin_profile_path(target_profile)
    cfg = profile(profile_file)
    assert target_profile.startswith(cfg.target)