diff options
author | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-01-11 12:32:02 +0000 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-02-08 15:23:29 +0000 |
commit | a4fb8c72f15146c95df16c25e75f03344e9814fd (patch) | |
tree | ce6d9cf39951a0c85d2773d436cc5010ecf78a8f | |
parent | 09ecc5c8acb758e8def33155feb746a34dd7b560 (diff) | |
download | mlia-a4fb8c72f15146c95df16c25e75f03344e9814fd.tar.gz |
MLIA-591 Create interface for target profiles
New class 'TargetProfile' is used to load and verify target profiles.
Change-Id: I76373a923e2e5f55c4e95860635afe9fc5627a5d
25 files changed, 251 insertions, 227 deletions
diff --git a/src/mlia/api.py b/src/mlia/api.py index 437c457..fd5fc13 100644 --- a/src/mlia/api.py +++ b/src/mlia/api.py @@ -10,10 +10,10 @@ from typing import Any from mlia.core.advisor import InferenceAdvisor from mlia.core.common import AdviceCategory from mlia.core.context import ExecutionContext +from mlia.target.config import get_target from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor from mlia.target.tosa.advisor import configure_and_get_tosa_advisor -from mlia.utils.filesystem import get_target logger = logging.getLogger(__name__) diff --git a/src/mlia/cli/command_validators.py b/src/mlia/cli/command_validators.py index 1974a1d..eb04192 100644 --- a/src/mlia/cli/command_validators.py +++ b/src/mlia/cli/command_validators.py @@ -8,8 +8,8 @@ import logging import sys from mlia.cli.config import get_default_backends +from mlia.target.config import get_target from mlia.target.registry import supported_backends -from mlia.utils.filesystem import get_target logger = logging.getLogger(__name__) diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py index 1b92958..dac8c82 100644 --- a/src/mlia/cli/options.py +++ b/src/mlia/cli/options.py @@ -13,7 +13,7 @@ from mlia.cli.config import DEFAULT_PRUNING_TARGET from mlia.cli.config import get_available_backends from mlia.cli.config import is_corstone_backend from mlia.core.typing import OutputFormat -from mlia.utils.filesystem import get_builtin_supported_profile_names +from mlia.target.config import get_builtin_supported_profile_names def add_check_category_options(parser: argparse.ArgumentParser) -> None: diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py index f257784..ec3fb4c 100644 --- a/src/mlia/target/config.py +++ b/src/mlia/target/config.py @@ -1,21 +1,110 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 -"""IP configuration module.""" +"""Target configuration module.""" from __future__ import annotations +from abc import ABC +from abc import abstractmethod from dataclasses import dataclass +from pathlib import Path +from typing import Any +from typing import cast +from typing import TypeVar + +try: + import tomllib +except ModuleNotFoundError: + import tomli as tomllib # type: ignore from mlia.backend.registry import registry as backend_registry from mlia.core.common import AdviceCategory +from mlia.utils.filesystem import get_mlia_target_profiles_dir + + +def get_profile_file(target_profile: str | Path) -> Path: + """Get the target profile toml file.""" + if not target_profile: + raise Exception("Target profile is not provided.") + + profile_file = Path(get_mlia_target_profiles_dir() / f"{target_profile}.toml") + if not profile_file.is_file(): + profile_file = Path(target_profile) + + if not profile_file.exists(): + raise Exception(f"File not found: {profile_file}.") + return profile_file + + +def load_profile(path: str | Path) -> dict[str, Any]: + """Get settings for the provided target profile.""" + with open(path, "rb") as file: + profile = tomllib.load(file) + + return cast(dict, profile) + + +def get_builtin_supported_profile_names() -> list[str]: + """Return list of default profiles in the target profiles directory.""" + return sorted( + [ + item.stem + for item in get_mlia_target_profiles_dir().iterdir() + if item.is_file() and item.suffix == ".toml" + ] + ) + +def get_target(target_profile: str | Path) -> str: + """Return target for the provided target_profile.""" + profile_file = get_profile_file(target_profile) + profile = load_profile(profile_file) + return cast(str, profile["target"]) -class IPConfiguration: # pylint: disable=too-few-public-methods - """Base class for IP configuration.""" + +T = TypeVar("T", bound="TargetProfile") + + +class TargetProfile(ABC): + """Base class for target profiles.""" def __init__(self, target: str) -> None: - """Init IP configuration instance.""" + """Init TargetProfile instance with the target name.""" self.target = target + @classmethod + def load(cls: type[T], path: str | Path) -> T: + """Load and verify a target profile from file and return new instance.""" + profile = load_profile(path) + + try: + new_instance = cls(**profile) + except KeyError as ex: + raise KeyError(f"Missing key in file {path}.") from ex + + new_instance.verify() + + return new_instance + + @classmethod + def load_profile(cls: type[T], target_profile: str) -> T: + """Load a target profile by name.""" + profile_file = get_profile_file(target_profile) + return cls.load(profile_file) + + def save(self, path: str | Path) -> None: + """Save this target profile to a file.""" + raise NotImplementedError("Saving target profiles is currently not supported.") + + @abstractmethod + def verify(self) -> None: + """ + Check that all attributes contain valid values etc. + + Raises a ValueError, if an issue is detected. + """ + if not self.target: + raise ValueError(f"Invalid target name: {self.target}") + @dataclass class TargetInfo: diff --git a/src/mlia/target/cortex_a/advisor.py b/src/mlia/target/cortex_a/advisor.py index 52af592..518c9f1 100644 --- a/src/mlia/target/cortex_a/advisor.py +++ b/src/mlia/target/cortex_a/advisor.py @@ -58,7 +58,9 @@ class CortexAInferenceAdvisor(DefaultInferenceAdvisor): target_profile = self.get_target_profile(context) return [ - CortexAAdvisorStartedEvent(model, CortexAConfiguration(target_profile)), + CortexAAdvisorStartedEvent( + model, CortexAConfiguration.load_profile(target_profile) + ), ] diff --git a/src/mlia/target/cortex_a/config.py b/src/mlia/target/cortex_a/config.py index b2b51ea..fd39e0a 100644 --- a/src/mlia/target/cortex_a/config.py +++ b/src/mlia/target/cortex_a/config.py @@ -1,20 +1,23 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Cortex-A configuration.""" from __future__ import annotations -from mlia.target.config import IPConfiguration -from mlia.utils.filesystem import get_profile +from typing import Any +from mlia.target.config import TargetProfile -class CortexAConfiguration(IPConfiguration): # pylint: disable=too-few-public-methods + +class CortexAConfiguration(TargetProfile): """Cortex-A configuration.""" - def __init__(self, target_profile: str) -> None: + def __init__(self, **kwargs: Any) -> None: """Init Cortex-A target configuration.""" - target_data = get_profile(target_profile) - - target = target_data["target"] - if target != "cortex-a": - raise Exception(f"Wrong target {target} for Cortex-A configuration") + target = kwargs["target"] super().__init__(target) + + def verify(self) -> None: + """Check the parameters.""" + super().verify() + if self.target != "cortex-a": + raise ValueError(f"Wrong target {self.target} for Cortex-A configuration.") diff --git a/src/mlia/target/ethos_u/advisor.py b/src/mlia/target/ethos_u/advisor.py index 937e91c..225fd87 100644 --- a/src/mlia/target/ethos_u/advisor.py +++ b/src/mlia/target/ethos_u/advisor.py @@ -19,7 +19,6 @@ from mlia.nn.tensorflow.utils import is_tflite_model from mlia.target.ethos_u.advice_generation import EthosUAdviceProducer from mlia.target.ethos_u.advice_generation import EthosUStaticAdviceProducer from mlia.target.ethos_u.config import EthosUConfiguration -from mlia.target.ethos_u.config import get_target from mlia.target.ethos_u.data_analysis import EthosUDataAnalyzer from mlia.target.ethos_u.data_collection import EthosUOperatorCompatibility from mlia.target.ethos_u.data_collection import EthosUOptimizationPerformance @@ -40,7 +39,7 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor): def get_collectors(self, context: Context) -> list[DataCollector]: """Return list of the data collectors.""" model = self.get_model(context) - device = self._get_device(context) + device = self._get_device_cfg(context) backends = self._get_backends(context) collectors: list[DataCollector] = [] @@ -88,17 +87,16 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor): def get_events(self, context: Context) -> list[Event]: """Return list of the startup events.""" model = self.get_model(context) - device = self._get_device(context) + device = self._get_device_cfg(context) return [ EthosUAdvisorStartedEvent(device=device, model=model), ] - def _get_device(self, context: Context) -> EthosUConfiguration: - """Get device.""" + def _get_device_cfg(self, context: Context) -> EthosUConfiguration: + """Get device configuration.""" target_profile = self.get_target_profile(context) - - return get_target(target_profile) + return EthosUConfiguration.load_profile(target_profile) def _get_optimization_settings(self, context: Context) -> list[list[dict]]: """Get optimization settings.""" diff --git a/src/mlia/target/ethos_u/config.py b/src/mlia/target/ethos_u/config.py index 8d8f481..eb5691d 100644 --- a/src/mlia/target/ethos_u/config.py +++ b/src/mlia/target/ethos_u/config.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Ethos-U configuration.""" from __future__ import annotations @@ -8,36 +8,49 @@ from typing import Any from mlia.backend.vela.compiler import resolve_compiler_config from mlia.backend.vela.compiler import VelaCompilerOptions -from mlia.target.config import IPConfiguration -from mlia.utils.filesystem import get_profile +from mlia.target.config import TargetProfile from mlia.utils.filesystem import get_vela_config logger = logging.getLogger(__name__) -class EthosUConfiguration(IPConfiguration): +class EthosUConfiguration(TargetProfile): """Ethos-U configuration.""" - def __init__(self, target_profile: str) -> None: + def __init__(self, **kwargs: Any) -> None: """Init Ethos-U target configuration.""" - target_data = get_profile(target_profile) - _check_target_data_complete(target_data) - - target = target_data["target"] + target = kwargs["target"] super().__init__(target) - mac = target_data["mac"] - _check_device_options_valid(target, mac) + mac = kwargs["mac"] self.mac = mac self.compiler_options = VelaCompilerOptions( - system_config=target_data["system_config"], - memory_mode=target_data["memory_mode"], + system_config=kwargs["system_config"], + memory_mode=kwargs["memory_mode"], config_files=str(get_vela_config()), accelerator_config=f"{self.target}-{mac}", # type: ignore ) + def verify(self) -> None: + """Check the parameters.""" + super().verify() + + target_mac_ranges = { + "ethos-u55": [32, 64, 128, 256], + "ethos-u65": [256, 512], + } + + if self.target not in target_mac_ranges: + raise ValueError(f"Unsupported target: {self.target}") + + target_mac_range = target_mac_ranges[self.target] + if self.mac not in target_mac_range: + raise ValueError( + f"Mac value for selected device should be in {target_mac_range}." + ) + @property def resolved_compiler_config(self) -> dict[str, Any]: """Resolve compiler configuration.""" @@ -54,37 +67,3 @@ class EthosUConfiguration(IPConfiguration): def __repr__(self) -> str: """Return string representation.""" return f"<Ethos-U configuration target={self.target}>" - - -def get_target(target_profile: str) -> EthosUConfiguration: - """Get target instance based on provided params.""" - if not target_profile: - raise Exception("No target profile given") - - return EthosUConfiguration(target_profile) - - -def _check_target_data_complete(target_data: dict[str, Any]) -> None: - """Check if profile contains all needed data.""" - mandatory_keys = {"target", "mac", "system_config", "memory_mode"} - missing_keys = sorted(mandatory_keys - target_data.keys()) - - if missing_keys: - raise Exception(f"Mandatory fields missing from target profile: {missing_keys}") - - -def _check_device_options_valid(target: str, mac: int) -> None: - """Check if mac is valid for selected device.""" - target_mac_ranges = { - "ethos-u55": [32, 64, 128, 256], - "ethos-u65": [256, 512], - } - - if target not in target_mac_ranges: - raise Exception(f"Unsupported target: {target}") - - target_mac_range = target_mac_ranges[target] - if mac not in target_mac_range: - raise Exception( - f"Mac value for selected device should be in {target_mac_range}" - ) diff --git a/src/mlia/target/tosa/advisor.py b/src/mlia/target/tosa/advisor.py index e8aad53..5588d0f 100644 --- a/src/mlia/target/tosa/advisor.py +++ b/src/mlia/target/tosa/advisor.py @@ -66,7 +66,7 @@ class TOSAInferenceAdvisor(DefaultInferenceAdvisor): return [ TOSAAdvisorStartedEvent( model, - TOSAConfiguration(target_profile), + TOSAConfiguration.load_profile(target_profile), MetadataDisplay( TOSAMetadata("tosa-checker"), MLIAMetadata("mlia"), diff --git a/src/mlia/target/tosa/config.py b/src/mlia/target/tosa/config.py index 22805b7..826e719 100644 --- a/src/mlia/target/tosa/config.py +++ b/src/mlia/target/tosa/config.py @@ -1,19 +1,21 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """TOSA target configuration.""" -from mlia.target.config import IPConfiguration -from mlia.utils.filesystem import get_profile +from typing import Any +from mlia.target.config import TargetProfile -class TOSAConfiguration(IPConfiguration): # pylint: disable=too-few-public-methods + +class TOSAConfiguration(TargetProfile): """TOSA configuration.""" - def __init__(self, target_profile: str) -> None: + def __init__(self, **kwargs: Any) -> None: """Init configuration.""" - target_data = get_profile(target_profile) - target = target_data["target"] - - if target != "tosa": - raise Exception(f"Wrong target {target} for TOSA configuration") - + target = kwargs["target"] super().__init__(target) + + def verify(self) -> None: + """Check the parameters.""" + super().verify() + if self.target != "tosa": + raise ValueError(f"Wrong target {self.target} for TOSA configuration.") diff --git a/src/mlia/utils/filesystem.py b/src/mlia/utils/filesystem.py index 4734a84..f92629b 100644 --- a/src/mlia/utils/filesystem.py +++ b/src/mlia/utils/filesystem.py @@ -11,16 +11,9 @@ from contextlib import contextmanager from pathlib import Path from tempfile import mkstemp from tempfile import TemporaryDirectory -from typing import Any -from typing import cast from typing import Generator from typing import Iterable -try: - import tomllib -except ModuleNotFoundError: - import tomli as tomllib # type: ignore - def get_mlia_resources() -> Path: """Get the path to the resources directory.""" @@ -39,50 +32,6 @@ def get_mlia_target_profiles_dir() -> Path: return get_mlia_resources() / "target_profiles" -def get_profile_toml_file(target_profile: str | Path) -> str | Path: - """Get the target profile toml file.""" - if not target_profile: - raise Exception("Target profile is not provided") - - profile_toml_file = Path(get_mlia_target_profiles_dir() / f"{target_profile}.toml") - if not profile_toml_file.is_file(): - profile_toml_file = Path(target_profile) - - if not profile_toml_file.exists(): - raise Exception(f"File not found: {profile_toml_file}.") - return profile_toml_file - - -def get_profile(target_profile: str | Path) -> dict[str, Any]: - """Get settings for the provided target profile.""" - if not target_profile: - raise Exception("Target profile is not provided") - - toml_file = get_profile_toml_file(target_profile) - - with open(toml_file, "rb") as file: - profile = tomllib.load(file) - - return cast(dict, profile) - - -def get_builtin_supported_profile_names() -> list[str]: - """Return list of default profiles in the target profiles directory.""" - return sorted( - [ - item.stem - for item in get_mlia_target_profiles_dir().iterdir() - if item.is_file() and item.suffix == ".toml" - ] - ) - - -def get_target(target_profile: str | Path) -> str: - """Return target for the provided target_profile.""" - profile = get_profile(target_profile) - return cast(str, profile["target"]) - - @contextmanager def temp_file(suffix: str | None = None) -> Generator[Path, None, None]: """Create temp file and remove it after.""" diff --git a/tests/conftest.py b/tests/conftest.py index b698a73..67549e7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -157,7 +157,7 @@ def fixture_test_models_path( save_tflite_model(tflite_model, tflite_model_path) tflite_vela_model = tmp_path / "test_model_vela.tflite" - device = EthosUConfiguration("ethos-u55-256") + device = EthosUConfiguration.load_profile("ethos-u55-256") optimize_model(tflite_model_path, device.compiler_options, tflite_vela_model) tf.saved_model.save(keras_model, str(tmp_path / "tf_model_test_model")) diff --git a/tests/test_backend_vela_compat.py b/tests/test_backend_vela_compat.py index a2e7f90..4653d7d 100644 --- a/tests/test_backend_vela_compat.py +++ b/tests/test_backend_vela_compat.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module vela/compat.""" from pathlib import Path @@ -55,7 +55,7 @@ from mlia.utils.filesystem import working_directory ) def test_operators(test_models_path: Path, model: str, expected_ops: Operators) -> None: """Test operators function.""" - device = EthosUConfiguration("ethos-u55-256") + device = EthosUConfiguration.load_profile("ethos-u55-256") operators = supported_operators(test_models_path / model, device.compiler_options) for expected, actual in zip(expected_ops.ops, operators.ops): diff --git a/tests/test_backend_vela_compiler.py b/tests/test_backend_vela_compiler.py index 20121d6..2d937ea 100644 --- a/tests/test_backend_vela_compiler.py +++ b/tests/test_backend_vela_compiler.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module vela/compiler.""" from pathlib import Path @@ -146,7 +146,9 @@ def test_vela_compiler_with_parameters(test_resources_path: Path) -> None: def test_compile_model(test_tflite_model: Path) -> None: """Test model optimization.""" - compiler = VelaCompiler(EthosUConfiguration("ethos-u55-256").compiler_options) + compiler = VelaCompiler( + EthosUConfiguration.load_profile("ethos-u55-256").compiler_options + ) optimized_model = compiler.compile_model(test_tflite_model) assert isinstance(optimized_model, OptimizedModel) @@ -156,7 +158,7 @@ def test_optimize_model(tmp_path: Path, test_tflite_model: Path) -> None: """Test model optimization and saving into file.""" tmp_file = tmp_path / "temp.tflite" - device = EthosUConfiguration("ethos-u55-256") + device = EthosUConfiguration.load_profile("ethos-u55-256") optimize_model(test_tflite_model, device.compiler_options, tmp_file.absolute()) assert tmp_file.is_file() diff --git a/tests/test_backend_vela_performance.py b/tests/test_backend_vela_performance.py index 34c11ab..569de61 100644 --- a/tests/test_backend_vela_performance.py +++ b/tests/test_backend_vela_performance.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module vela/performance.""" from pathlib import Path @@ -14,7 +14,7 @@ from mlia.target.ethos_u.config import EthosUConfiguration def test_estimate_performance(test_tflite_model: Path) -> None: """Test getting performance estimations.""" - device = EthosUConfiguration("ethos-u55-256") + device = EthosUConfiguration.load_profile("ethos-u55-256") perf_metrics = estimate_performance(test_tflite_model, device.compiler_options) assert isinstance(perf_metrics, PerformanceMetrics) @@ -24,7 +24,7 @@ def test_estimate_performance_already_optimized( tmp_path: Path, test_tflite_model: Path ) -> None: """Test that performance estimation should fail for already optimized model.""" - device = EthosUConfiguration("ethos-u55-256") + device = EthosUConfiguration.load_profile("ethos-u55-256") optimized_model_path = tmp_path / "optimized_model.tflite" @@ -41,7 +41,7 @@ def test_read_invalid_model(test_tflite_invalid_model: Path) -> None: with pytest.raises( Exception, match=f"Unable to read model {test_tflite_invalid_model}" ): - device = EthosUConfiguration("ethos-u55-256") + device = EthosUConfiguration.load_profile("ethos-u55-256") estimate_performance(test_tflite_invalid_model, device.compiler_options) @@ -58,7 +58,7 @@ def test_compile_invalid_model( with pytest.raises( Exception, match="Model could not be optimized with Vela compiler" ): - device = EthosUConfiguration("ethos-u55-256") + device = EthosUConfiguration.load_profile("ethos-u55-256") optimize_model(test_tflite_model, device.compiler_options, model_path) assert not model_path.exists() diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py index b65d90e..61cc5a6 100644 --- a/tests/test_cli_commands.py +++ b/tests/test_cli_commands.py @@ -79,7 +79,7 @@ def test_opt_valid_optimization_target( def mock_performance_estimation(monkeypatch: pytest.MonkeyPatch) -> None: """Mock performance estimation.""" metrics = PerformanceMetrics( - EthosUConfiguration("ethos-u55-256"), + EthosUConfiguration.load_profile("ethos-u55-256"), NPUCycles(1, 2, 3, 4, 5, 6), MemoryUsage(1, 2, 3, 4, 5), ) diff --git a/tests/test_target_config.py b/tests/test_target_config.py index 48f0a58..26f524e 100644 --- a/tests/test_target_config.py +++ b/tests/test_target_config.py @@ -9,15 +9,68 @@ from mlia.backend.config import BackendConfiguration from mlia.backend.config import BackendType from mlia.backend.config import System from mlia.core.common import AdviceCategory -from mlia.target.config import IPConfiguration +from mlia.target.config import get_builtin_supported_profile_names +from mlia.target.config import get_profile_file +from mlia.target.config import load_profile from mlia.target.config import TargetInfo +from mlia.target.config import TargetProfile from mlia.utils.registry import Registry -def test_ip_config() -> None: - """Test the class 'IPConfiguration'.""" - cfg = IPConfiguration("AnyTarget") - assert cfg.target == "AnyTarget" +def test_get_builtin_supported_profile_names() -> None: + """Test profile names getter.""" + assert get_builtin_supported_profile_names() == [ + "cortex-a", + "ethos-u55-128", + "ethos-u55-256", + "ethos-u65-256", + "ethos-u65-512", + "tosa", + ] + + +def test_get_profile_file() -> None: + """Test function 'get_profile_file'.""" + profile_file = get_profile_file("cortex-a") + assert profile_file.is_file() + assert profile_file == get_profile_file(profile_file) + + with pytest.raises(Exception): + get_profile_file("UNKNOWN") + with pytest.raises(Exception): + get_profile_file("") + + +def test_load_profile() -> None: + """Test getting profile data.""" + profile_file = get_profile_file("ethos-u55-256") + assert load_profile(profile_file) == { + "target": "ethos-u55", + "mac": 256, + "memory_mode": "Shared_Sram", + "system_config": "Ethos_U55_High_End_Embedded", + } + + with pytest.raises(Exception, match=r"No such file or directory: 'unknown'"): + load_profile("unknown") + + +def test_target_profile() -> None: + """Test the class 'TargetProfile'.""" + + class MyTargetProfile(TargetProfile): + """Test class deriving from TargetProfile.""" + + def verify(self) -> None: + super().verify() + assert self.target + + profile = MyTargetProfile("AnyTarget") + assert profile.target == "AnyTarget" + + profile = MyTargetProfile("") + with pytest.raises(ValueError): + profile.verify() @pytest.mark.parametrize( diff --git a/tests/test_target_cortex_a_reporters.py b/tests/test_target_cortex_a_reporters.py index 4b39aa1..c32ef7b 100644 --- a/tests/test_target_cortex_a_reporters.py +++ b/tests/test_target_cortex_a_reporters.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for Cortex-A reporters.""" from typing import Any @@ -18,7 +18,7 @@ from mlia.target.cortex_a.reporters import report_device def test_report_device() -> None: """Test function report_device().""" - report = report_device(CortexAConfiguration("cortex-a")) + report = report_device(CortexAConfiguration.load_profile("cortex-a")) assert report.to_plain_text() diff --git a/tests/test_target_ethos_u_config.py b/tests/test_target_ethos_u_config.py index 08a20ff..7f13b26 100644 --- a/tests/test_target_ethos_u_config.py +++ b/tests/test_target_ethos_u_config.py @@ -5,14 +5,11 @@ from __future__ import annotations from contextlib import ExitStack as does_not_raise from typing import Any -from unittest.mock import MagicMock import pytest from mlia.backend.vela.compiler import VelaCompilerOptions from mlia.target.ethos_u.config import EthosUConfiguration -from mlia.target.ethos_u.config import get_target -from mlia.utils.filesystem import get_vela_config def test_compiler_options_default_init() -> None: @@ -33,48 +30,28 @@ def test_compiler_options_default_init() -> None: def test_ethosu_target() -> None: """Test Ethos-U target configuration init.""" - default_config = EthosUConfiguration("ethos-u55-256") + default_config = EthosUConfiguration.load_profile("ethos-u55-256") assert default_config.target == "ethos-u55" assert default_config.mac == 256 assert default_config.compiler_options is not None -def test_get_target() -> None: - """Test function get_target.""" - with pytest.raises(Exception, match="No target profile given"): - get_target(None) # type: ignore - - with pytest.raises(Exception, match=r"File not found:*"): - get_target("unknown") - - u65_device = get_target("ethos-u65-512") - - assert isinstance(u65_device, EthosUConfiguration) - assert u65_device.target == "ethos-u65" - assert u65_device.mac == 512 - assert u65_device.compiler_options.accelerator_config == "ethos-u65-512" - assert u65_device.compiler_options.memory_mode == "Dedicated_Sram" - assert u65_device.compiler_options.config_files == str(get_vela_config()) - - @pytest.mark.parametrize( "profile_data, expected_error", [ [ {}, pytest.raises( - Exception, - match="Mandatory fields missing from target profile: " - r"\['mac', 'memory_mode', 'system_config', 'target'\]", + KeyError, + match=r"'target'", ), ], [ {"target": "ethos-u65", "mac": 512}, pytest.raises( - Exception, - match="Mandatory fields missing from target profile: " - r"\['memory_mode', 'system_config'\]", + KeyError, + match=r"'system_config'", ), ], [ @@ -114,12 +91,9 @@ def test_get_target() -> None: ], ) def test_ethosu_configuration( - monkeypatch: pytest.MonkeyPatch, profile_data: dict[str, Any], expected_error: Any + profile_data: dict[str, Any], expected_error: Any ) -> None: """Test creating Ethos-U configuration.""" - monkeypatch.setattr( - "mlia.target.ethos_u.config.get_profile", MagicMock(return_value=profile_data) - ) - with expected_error: - EthosUConfiguration("target") + cfg = EthosUConfiguration(**profile_data) + cfg.verify() diff --git a/tests/test_target_ethos_u_data_analysis.py b/tests/test_target_ethos_u_data_analysis.py index bac27ad..e919f5d 100644 --- a/tests/test_target_ethos_u_data_analysis.py +++ b/tests/test_target_ethos_u_data_analysis.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for Ethos-U data analysis module.""" from __future__ import annotations @@ -84,7 +84,7 @@ def test_perf_metrics_diff() -> None: [ OptimizationPerformanceMetrics( PerformanceMetrics( - EthosUConfiguration("ethos-u55-256"), + EthosUConfiguration.load_profile("ethos-u55-256"), NPUCycles(1, 2, 3, 4, 5, 6), # memory metrics are in kilobytes MemoryUsage(*[i * 1024 for i in range(1, 6)]), # type: ignore @@ -95,7 +95,7 @@ def test_perf_metrics_diff() -> None: OptimizationSettings("pruning", 0.5, None), ], PerformanceMetrics( - EthosUConfiguration("ethos-u55-256"), + EthosUConfiguration.load_profile("ethos-u55-256"), NPUCycles(1, 2, 3, 4, 5, 6), # memory metrics are in kilobytes MemoryUsage( @@ -127,7 +127,7 @@ def test_perf_metrics_diff() -> None: [ OptimizationPerformanceMetrics( PerformanceMetrics( - EthosUConfiguration("ethos-u55-256"), + EthosUConfiguration.load_profile("ethos-u55-256"), NPUCycles(1, 2, 3, 4, 5, 6), # memory metrics are in kilobytes MemoryUsage(*[i * 1024 for i in range(1, 6)]), # type: ignore diff --git a/tests/test_target_ethos_u_data_collection.py b/tests/test_target_ethos_u_data_collection.py index 2cf7482..829d2a7 100644 --- a/tests/test_target_ethos_u_data_collection.py +++ b/tests/test_target_ethos_u_data_collection.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the data collection module for Ethos-U.""" from pathlib import Path @@ -50,7 +50,7 @@ def test_operator_compatibility_collector( sample_context: Context, test_tflite_model: Path ) -> None: """Test operator compatibility data collector.""" - device = EthosUConfiguration("ethos-u55-256") + device = EthosUConfiguration.load_profile("ethos-u55-256") collector = EthosUOperatorCompatibility(test_tflite_model, device) collector.set_context(sample_context) @@ -63,7 +63,7 @@ def test_performance_collector( monkeypatch: pytest.MonkeyPatch, sample_context: Context, test_tflite_model: Path ) -> None: """Test performance data collector.""" - device = EthosUConfiguration("ethos-u55-256") + device = EthosUConfiguration.load_profile("ethos-u55-256") mock_performance_estimation(monkeypatch, device) @@ -81,7 +81,7 @@ def test_optimization_performance_collector( test_tflite_model: Path, ) -> None: """Test optimization performance data collector.""" - device = EthosUConfiguration("ethos-u55-256") + device = EthosUConfiguration.load_profile("ethos-u55-256") mock_performance_estimation(monkeypatch, device) collector = EthosUOptimizationPerformance( diff --git a/tests/test_target_ethos_u_reporters.py b/tests/test_target_ethos_u_reporters.py index bc764a0..0c5764e 100644 --- a/tests/test_target_ethos_u_reporters.py +++ b/tests/test_target_ethos_u_reporters.py @@ -118,7 +118,7 @@ def test_report_operators( "device, expected_plain_text, expected_json_dict", [ [ - EthosUConfiguration("ethos-u55-256"), + EthosUConfiguration.load_profile("ethos-u55-256"), """Device information: Target ethos-u55 MAC 256 diff --git a/tests/test_target_tosa_reporters.py b/tests/test_target_tosa_reporters.py index 59da270..43d2a56 100644 --- a/tests/test_target_tosa_reporters.py +++ b/tests/test_target_tosa_reporters.py @@ -18,7 +18,7 @@ from mlia.target.tosa.reporters import tosa_formatters def test_tosa_report_device() -> None: """Test function report_device().""" - report = report_device(TOSAConfiguration("tosa")) + report = report_device(TOSAConfiguration.load_profile("tosa")) assert report.to_plain_text() diff --git a/tests/test_utils_filesystem.py b/tests/test_utils_filesystem.py index 954f9e3..d0a6e6f 100644 --- a/tests/test_utils_filesystem.py +++ b/tests/test_utils_filesystem.py @@ -9,10 +9,8 @@ import pytest from mlia.utils.filesystem import all_files_exist from mlia.utils.filesystem import all_paths_valid from mlia.utils.filesystem import copy_all -from mlia.utils.filesystem import get_builtin_supported_profile_names from mlia.utils.filesystem import get_mlia_resources from mlia.utils.filesystem import get_mlia_target_profiles_dir -from mlia.utils.filesystem import get_profile from mlia.utils.filesystem import get_vela_config from mlia.utils.filesystem import sha256 from mlia.utils.filesystem import temp_directory @@ -36,31 +34,6 @@ def test_get_mlia_target_profiles() -> None: assert get_mlia_target_profiles_dir().is_dir() -def test_get_builtin_supported_profile_names() -> None: - """Test profile names getter.""" - assert get_builtin_supported_profile_names() == [ - "cortex-a", - "ethos-u55-128", - "ethos-u55-256", - "ethos-u65-256", - "ethos-u65-512", - "tosa", - ] - - -def test_get_profile() -> None: - """Test getting profile data.""" - assert get_profile("ethos-u55-256") == { - "target": "ethos-u55", - "mac": 256, - "memory_mode": "Shared_Sram", - "system_config": "Ethos_U55_High_End_Embedded", - } - - with pytest.raises(Exception, match=r"File not found:*"): - get_profile("unknown") - - @pytest.mark.parametrize("raise_exception", [True, False]) def test_temp_file(raise_exception: bool) -> None: """Test temp_file context manager.""" diff --git a/tests_e2e/test_e2e.py b/tests_e2e/test_e2e.py index 546bd7e..d09d0ab 100644 --- a/tests_e2e/test_e2e.py +++ b/tests_e2e/test_e2e.py @@ -24,7 +24,7 @@ from mlia.cli.config import get_available_backends from mlia.cli.main import get_commands from mlia.cli.main import get_possible_command_names from mlia.cli.main import init_parser -from mlia.utils.filesystem import get_builtin_supported_profile_names +from mlia.target.config import get_builtin_supported_profile_names from mlia.utils.types import is_list_of |