# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the backend config module.""" from __future__ import annotations from pathlib import Path import pytest from mlia.backend.config import BackendConfiguration from mlia.backend.config import BackendType from mlia.backend.config import System from mlia.core.common import AdviceCategory from mlia.target.config import copy_profile_file_to_output_dir from mlia.target.config import get_builtin_supported_profile_names from mlia.target.config import get_profile_file from mlia.target.config import load_profile from mlia.target.config import TargetInfo from mlia.target.config import TargetProfile from mlia.utils.registry import Registry def test_copy_profile_file_to_output_dir(tmp_path: Path) -> None: """Test if the profile file is copied into the output directory.""" test_target_profile_name = "ethos-u55-128" test_file_path = Path(f"{tmp_path}/{test_target_profile_name}.toml") copy_profile_file_to_output_dir(test_target_profile_name, tmp_path) assert Path.is_file(test_file_path) def test_get_builtin_supported_profile_names() -> None: """Test profile names getter.""" assert get_builtin_supported_profile_names() == [ "cortex-a", "ethos-u55-128", "ethos-u55-256", "ethos-u65-256", "ethos-u65-512", "tosa", ] def test_get_profile_file() -> None: """Test function 'get_profile_file'.""" profile_file = get_profile_file("cortex-a") assert profile_file.is_file() assert profile_file == get_profile_file(profile_file) with pytest.raises(Exception): get_profile_file("UNKNOWN") with pytest.raises(Exception): get_profile_file("") def test_load_profile() -> None: """Test getting profile data.""" profile_file = get_profile_file("ethos-u55-256") assert load_profile(profile_file) == { "target": "ethos-u55", "mac": 256, "memory_mode": "Shared_Sram", "system_config": "Ethos_U55_High_End_Embedded", } with pytest.raises(Exception, match=r"No such file or directory: 'unknown'"): load_profile("unknown") def test_target_profile() -> None: """Test the class 'TargetProfile'.""" class MyTargetProfile(TargetProfile): """Test class deriving from TargetProfile.""" def verify(self) -> None: super().verify() assert self.target profile = MyTargetProfile("AnyTarget") assert profile.target == "AnyTarget" profile = MyTargetProfile("") with pytest.raises(ValueError): profile.verify() @pytest.mark.parametrize( ("advice", "check_system", "supported"), ( (None, False, True), (None, True, True), (AdviceCategory.COMPATIBILITY, True, True), (AdviceCategory.OPTIMIZATION, True, False), ), ) def test_target_info( monkeypatch: pytest.MonkeyPatch, advice: AdviceCategory | None, check_system: bool, supported: bool, ) -> None: """Test the class 'TargetInfo'.""" info = TargetInfo(["backend"]) backend_registry = Registry[BackendConfiguration]() backend_registry.register( "backend", BackendConfiguration( [AdviceCategory.COMPATIBILITY], [System.CURRENT], BackendType.BUILTIN, ), ) monkeypatch.setattr("mlia.target.config.backend_registry", backend_registry) assert info.is_supported(advice, check_system) == supported assert bool(info.filter_supported_backends(advice, check_system)) == supported