aboutsummaryrefslogtreecommitdiff
path: root/tests/test_target_config.py
blob: 66ebed68f526c429dbc032473af946d098aa468c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the target config module."""
from __future__ import annotations

import pytest

from mlia.backend.config import BackendConfiguration
from mlia.backend.config import BackendType
from mlia.backend.config import System
from mlia.core.common import AdviceCategory
from mlia.target.config import IPConfiguration
from mlia.target.config import TargetInfo
from mlia.utils.registry import Registry


def test_ip_config() -> None:
    """Check that 'IPConfiguration' keeps the target name it is built with."""
    target_name = "AnyTarget"
    ip_config = IPConfiguration(target_name)
    assert ip_config.target == target_name


@pytest.mark.parametrize(
    ("advice", "check_system", "supported"),
    [
        # No advice category requested: expected to be supported whether or
        # not the system check is enabled.
        (None, False, True),
        (None, True, True),
        # The stub backend registered below lists only OPERATORS, so that
        # category is expected to be supported and OPTIMIZATION is not.
        (AdviceCategory.OPERATORS, True, True),
        (AdviceCategory.OPTIMIZATION, True, False),
    ],
)
def test_target_info(
    monkeypatch: pytest.MonkeyPatch,
    advice: AdviceCategory | None,
    check_system: bool,
    supported: bool,
) -> None:
    """Check 'TargetInfo' support queries against a stubbed backend registry.

    A single builtin backend named "backend", declared for
    AdviceCategory.OPERATORS on System.CURRENT, is placed in a fresh registry
    that replaces 'mlia.target.config.backend_registry' via monkeypatch.
    Both 'is_supported' and 'filter_supported_backends' must then agree with
    the expected 'supported' flag.
    """
    target_info = TargetInfo(["backend"])

    stub_registry = Registry[BackendConfiguration]()
    stub_config = BackendConfiguration(
        [AdviceCategory.OPERATORS],
        [System.CURRENT],
        BackendType.BUILTIN,
    )
    stub_registry.register("backend", stub_config)
    monkeypatch.setattr("mlia.target.config.backend_registry", stub_registry)

    assert target_info.is_supported(advice, check_system) == supported
    # A non-empty result means at least one backend was accepted.
    backends = target_info.filter_supported_backends(advice, check_system)
    assert bool(backends) == supported