From 7a661257b6adad0c8f53e32b42ced56a1e7d952f Mon Sep 17 00:00:00 2001 From: Benjamin Klimczak Date: Thu, 2 Feb 2023 14:02:05 +0000 Subject: MLIA-769 Expand use of target/backend registries - Use the target/backend registries to avoid hard-coded names. - Cache target profiles to avoid re-loading them Change-Id: I474b7c9ef23894e1d8a3ea06d13a37652054c62e --- src/mlia/api.py | 21 ++------- src/mlia/backend/corstone/__init__.py | 30 ++++++------- src/mlia/backend/manager.py | 13 ++++++ src/mlia/cli/command_validators.py | 20 ++++----- src/mlia/cli/commands.py | 2 +- src/mlia/cli/config.py | 69 ----------------------------- src/mlia/cli/helpers.py | 24 +++++++++- src/mlia/cli/main.py | 15 +++++-- src/mlia/cli/options.py | 43 +++++++++++++----- src/mlia/target/config.py | 71 +++++++++++++++--------------- src/mlia/target/cortex_a/__init__.py | 12 ++++- src/mlia/target/cortex_a/advisor.py | 4 +- src/mlia/target/ethos_u/__init__.py | 19 +++++++- src/mlia/target/ethos_u/advisor.py | 4 +- src/mlia/target/ethos_u/config.py | 22 ++++++++- src/mlia/target/registry.py | 68 +++++++++++++++++++++++++++- src/mlia/target/tosa/__init__.py | 12 ++++- src/mlia/target/tosa/advisor.py | 4 +- tests/conftest.py | 7 ++- tests/test_api.py | 8 +--- tests/test_cli_command_validators.py | 57 +++++++++++------------- tests/test_cli_commands.py | 8 +++- tests/test_cli_config.py | 48 -------------------- tests/test_cli_helpers.py | 11 +++++ tests/test_target_config.py | 63 +++++++++++++++----------- tests/test_target_ethos_u_data_analysis.py | 9 ++-- tests/test_target_ethos_u_reporters.py | 5 ++- tests/test_target_registry.py | 57 ++++++++++++++++++++++++ tests_e2e/test_e2e.py | 2 +- 29 files changed, 432 insertions(+), 296 deletions(-) delete mode 100644 src/mlia/cli/config.py delete mode 100644 tests/test_cli_config.py diff --git a/src/mlia/api.py b/src/mlia/api.py index fd5fc13..7adae48 100644 --- a/src/mlia/api.py +++ b/src/mlia/api.py @@ -10,10 +10,8 @@ from typing import Any from mlia.core.advisor import InferenceAdvisor from mlia.core.common import AdviceCategory from mlia.core.context import ExecutionContext -from mlia.target.config import get_target -from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor -from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor -from mlia.target.tosa.advisor import configure_and_get_tosa_advisor +from mlia.target.registry import profile +from mlia.target.registry import registry as target_registry logger = logging.getLogger(__name__) @@ -84,19 +82,8 @@ def get_advisor( **extra_args: Any, ) -> InferenceAdvisor: """Find appropriate advisor for the target.""" - target_factories = { - "ethos-u55": configure_and_get_ethosu_advisor, - "ethos-u65": configure_and_get_ethosu_advisor, - "tosa": configure_and_get_tosa_advisor, - "cortex-a": configure_and_get_cortexa_advisor, - } - - try: - target = get_target(target_profile) - factory_function = target_factories[target] - except KeyError as err: - raise Exception(f"Unsupported profile {target_profile}") from err - + target = profile(target_profile).target + factory_function = target_registry.items[target].advisor_factory_func return factory_function( context, target_profile, diff --git a/src/mlia/backend/corstone/__init__.py b/src/mlia/backend/corstone/__init__.py index 36f74ee..b59ab65 100644 --- a/src/mlia/backend/corstone/__init__.py +++ b/src/mlia/backend/corstone/__init__.py @@ -7,24 +7,20 @@ from mlia.backend.config import System from mlia.backend.registry import registry from mlia.core.common import AdviceCategory -registry.register( - "Corstone-300", - BackendConfiguration( - supported_advice=[AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION], - supported_systems=[System.LINUX_AMD64], - backend_type=BackendType.CUSTOM, - ), -) -registry.register( - "Corstone-310", - BackendConfiguration( - supported_advice=[AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION], - supported_systems=[System.LINUX_AMD64], - backend_type=BackendType.CUSTOM, - ), -) +# List of mutually exclusive Corstone backends ordered by priority +CORSTONE_PRIORITY = ("Corstone-310", "Corstone-300") + +for corstone_name in CORSTONE_PRIORITY: + registry.register( + corstone_name, + BackendConfiguration( + supported_advice=[AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION], + supported_systems=[System.LINUX_AMD64], + backend_type=BackendType.CUSTOM, + ), + ) def is_corstone_backend(backend_name: str) -> bool: """Check if backend belongs to Corstone.""" - return backend_name in ["Corstone-300", "Corstone-310"] + return backend_name in CORSTONE_PRIORITY diff --git a/src/mlia/backend/manager.py b/src/mlia/backend/manager.py index b0fa919..d953b2d 100644 --- a/src/mlia/backend/manager.py +++ b/src/mlia/backend/manager.py @@ -9,11 +9,13 @@ from abc import abstractmethod from pathlib import Path from typing import Callable +from mlia.backend.config import BackendType from mlia.backend.corstone.install import get_corstone_installations from mlia.backend.install import DownloadAndInstall from mlia.backend.install import Installation from mlia.backend.install import InstallationType from mlia.backend.install import InstallFromPath +from mlia.backend.registry import registry as backend_registry from mlia.backend.tosa_checker.install import get_tosa_backend_installation from mlia.core.errors import ConfigurationError from mlia.core.errors import InternalError @@ -279,3 +281,14 @@ def get_installation_manager(noninteractive: bool = False) -> InstallationManage backends.append(get_tosa_backend_installation()) return DefaultInstallationManager(backends, noninteractive=noninteractive) + + +def get_available_backends() -> list[str]: + """Return list of the available backends.""" + manager = get_installation_manager() + available_backends = [ + backend + for backend, cfg in backend_registry.items.items() + if cfg.type == BackendType.BUILTIN or manager.backend_installed(backend) + ] + return available_backends diff --git a/src/mlia/cli/command_validators.py b/src/mlia/cli/command_validators.py index 23101e0..a0f5433 100644 --- a/src/mlia/cli/command_validators.py +++ b/src/mlia/cli/command_validators.py @@ -7,8 +7,8 @@ import argparse import logging import sys -from mlia.cli.config import get_default_backends_dict -from mlia.target.config import get_target +from mlia.target.registry import default_backends +from mlia.target.registry import get_target from mlia.target.registry import supported_backends logger = logging.getLogger(__name__) @@ -26,22 +26,18 @@ def validate_backend( target = get_target(target_profile) if not backend: - return get_default_backends_dict()[target] + return default_backends(target) - compatible_backends = supported_backends(target) + compatible_backends = list(map(normalize_string, supported_backends(target))) + backends = {normalize_string(b): b for b in backend} - nor_backend = list(map(normalize_string, backend)) - nor_compat_backend = list(map(normalize_string, compatible_backends)) - - incompatible_backends = [ - backend[i] for i, x in enumerate(nor_backend) if x not in nor_compat_backend - ] + incompatible_backends = [b for b in backends if b not in compatible_backends] # Throw an error if any unsupported backends are used if incompatible_backends: raise argparse.ArgumentError( None, - f"{', '.join(incompatible_backends)} backend not supported " - f"with target-profile {target_profile}.", + f"Backend {', '.join(backends[b] for b in incompatible_backends)} " + f"not supported with target-profile {target_profile}.", ) return backend diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py index c17d571..27f5b2b 100644 --- a/src/mlia/cli/commands.py +++ b/src/mlia/cli/commands.py @@ -23,9 +23,9 @@ from pathlib import Path from mlia.api import ExecutionContext from mlia.api import get_advice +from mlia.backend.manager import get_installation_manager from mlia.cli.command_validators import validate_backend from mlia.cli.command_validators import validate_check_target_profile -from mlia.cli.config import get_installation_manager from mlia.cli.options import parse_optimization_parameters from mlia.utils.console import create_section_header diff --git a/src/mlia/cli/config.py b/src/mlia/cli/config.py deleted file mode 100644 index 433300c..0000000 --- a/src/mlia/cli/config.py +++ /dev/null @@ -1,69 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Environment configuration functions.""" -from __future__ import annotations - -import logging - -from mlia.backend.manager import get_installation_manager -from mlia.target.registry import all_supported_backends - -logger = logging.getLogger(__name__) - -DEFAULT_PRUNING_TARGET = 0.5 -DEFAULT_CLUSTERING_TARGET = 32 - - -def get_available_backends() -> list[str]: - """Return list of the available backends.""" - available_backends = ["Vela", "ArmNNTFLiteDelegate"] - - # Add backends using backend manager - manager = get_installation_manager() - available_backends.extend( - backend - for backend in all_supported_backends() - if manager.backend_installed(backend) - ) - - return available_backends - - -# List of mutually exclusive Corstone backends ordered by priority -_CORSTONE_EXCLUSIVE_PRIORITY = ("Corstone-310", "Corstone-300") -_NON_ETHOS_U_BACKENDS = ("tosa-checker", "ArmNNTFLiteDelegate") - - -def get_ethos_u_default_backends(backends: list[str]) -> list[str]: - """Get Ethos-U default backends for evaluation.""" - return [x for x in backends if x not in _NON_ETHOS_U_BACKENDS] - - -def get_default_backends() -> list[str]: - """Get default backends for evaluation.""" - backends = get_available_backends() - - # Filter backends to only include one Corstone backend - for corstone in _CORSTONE_EXCLUSIVE_PRIORITY: - if corstone in backends: - backends = [ - backend - for backend in backends - if backend == corstone or backend not in _CORSTONE_EXCLUSIVE_PRIORITY - ] - break - - return backends - - -def get_default_backends_dict() -> dict[str, list[str]]: - """Return default backends for all targets.""" - default_backends = get_default_backends() - ethos_u_defaults = get_ethos_u_default_backends(default_backends) - - return { - "ethos-u55": ethos_u_defaults, - "ethos-u65": ethos_u_defaults, - "tosa": ["tosa-checker"], - "cortex-a": ["ArmNNTFLiteDelegate"], - } diff --git a/src/mlia/cli/helpers.py b/src/mlia/cli/helpers.py index ac64581..576670b 100644 --- a/src/mlia/cli/helpers.py +++ b/src/mlia/cli/helpers.py @@ -1,14 +1,19 @@ # SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 -"""Module for various helper classes.""" +"""Module for various helpers.""" from __future__ import annotations +from pathlib import Path +from shutil import copy from typing import Any +from typing import cast from mlia.cli.options import get_target_profile_opts from mlia.core.helpers import ActionResolver from mlia.nn.tensorflow.optimizations.select import OptimizationSettings from mlia.nn.tensorflow.utils import is_keras_model +from mlia.target.config import get_builtin_profile_path +from mlia.target.config import is_builtin_profile from mlia.utils.types import is_list_of @@ -108,3 +113,20 @@ class CLIActionResolver(ActionResolver): model_path = self.args.get("model") return model_path, device_opts + + +def copy_profile_file_to_output_dir( + target_profile: str | Path, output_dir: str | Path +) -> bool: + """Copy the target profile file to the output directory.""" + profile_file_path = ( + get_builtin_profile_path(cast(str, target_profile)) + if is_builtin_profile(target_profile) + else Path(target_profile) + ) + output_file_path = f"{output_dir}/{profile_file_path.stem}.toml" + try: + copy(profile_file_path, output_file_path) + return True + except OSError as err: + raise RuntimeError("Failed to copy profile file:", err.strerror) from err diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py index 793e155..b3a9d4c 100644 --- a/src/mlia/cli/main.py +++ b/src/mlia/cli/main.py @@ -18,6 +18,7 @@ from mlia.cli.commands import check from mlia.cli.commands import optimize from mlia.cli.common import CommandInfo from mlia.cli.helpers import CLIActionResolver +from mlia.cli.helpers import copy_profile_file_to_output_dir from mlia.cli.options import add_backend_install_options from mlia.cli.options import add_backend_options from mlia.cli.options import add_backend_uninstall_options @@ -30,11 +31,11 @@ from mlia.cli.options import add_output_directory from mlia.cli.options import add_output_options from mlia.cli.options import add_target_options from mlia.cli.options import get_output_format +from mlia.core.common import AdviceCategory from mlia.core.context import ExecutionContext from mlia.core.errors import ConfigurationError from mlia.core.errors import InternalError from mlia.core.logging import setup_logging -from mlia.target.config import copy_profile_file_to_output_dir from mlia.target.registry import table as target_table @@ -59,7 +60,13 @@ def get_commands() -> list[CommandInfo]: [ add_output_directory, add_model_options, - add_target_options, + partial( + add_target_options, + supported_advice=[ + AdviceCategory.COMPATIBILITY, + AdviceCategory.PERFORMANCE, + ], + ), add_backend_options, add_check_category_options, add_output_options, @@ -72,7 +79,9 @@ def get_commands() -> list[CommandInfo]: [ add_output_directory, add_keras_model_options, - partial(add_target_options, profiles_to_skip=["tosa", "cortex-a"]), + partial( + add_target_options, supported_advice=[AdviceCategory.OPTIMIZATION] + ), partial( add_backend_options, backends_to_skip=["tosa-checker", "ArmNNTFLiteDelegate"], diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py index 421533a..8cd2935 100644 --- a/src/mlia/cli/options.py +++ b/src/mlia/cli/options.py @@ -7,13 +7,17 @@ import argparse from pathlib import Path from typing import Any from typing import Callable +from typing import Sequence from mlia.backend.corstone import is_corstone_backend -from mlia.cli.config import DEFAULT_CLUSTERING_TARGET -from mlia.cli.config import DEFAULT_PRUNING_TARGET -from mlia.cli.config import get_available_backends +from mlia.backend.manager import get_available_backends +from mlia.core.common import AdviceCategory from mlia.core.typing import OutputFormat -from mlia.target.config import get_builtin_supported_profile_names +from mlia.target.registry import builtin_profile_names +from mlia.target.registry import registry as target_registry + +DEFAULT_PRUNING_TARGET = 0.5 +DEFAULT_CLUSTERING_TARGET = 32 def add_check_category_options(parser: argparse.ArgumentParser) -> None: @@ -31,22 +35,39 @@ def add_check_category_options(parser: argparse.ArgumentParser) -> None: def add_target_options( parser: argparse.ArgumentParser, - profiles_to_skip: list[str] | None = None, + supported_advice: Sequence[AdviceCategory] | None = None, required: bool = True, ) -> None: """Add target specific options.""" - target_profiles = get_builtin_supported_profile_names() - if profiles_to_skip: - target_profiles = [tp for tp in target_profiles if tp not in profiles_to_skip] - - default_target_profile = "ethos-u55-256" + target_profiles = builtin_profile_names() + + if supported_advice: + + def is_advice_supported(profile: str, advice: Sequence[AdviceCategory]) -> bool: + """ + Collect all target profiles that support the advice. + + This means target profiles that... + - have the right target prefix, e.g. "ethos-u55..." to avoid loading + all target profiles + - support any of the required advice + """ + for target, info in target_registry.items.items(): + if profile.startswith(target): + return any(info.is_supported(adv) for adv in advice) + return False + + target_profiles = [ + profile + for profile in target_profiles + if is_advice_supported(profile, supported_advice) + ] target_group = parser.add_argument_group("target options") target_group.add_argument( "-t", "--target-profile", required=required, - default=default_target_profile, help="Built-in target profile or path to the custom target profile. " f"Built-in target profiles are {', '.join(target_profiles)}. " "Target profile that will set the target options " diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py index bf603dd..eb7ecff 100644 --- a/src/mlia/target/config.py +++ b/src/mlia/target/config.py @@ -6,9 +6,10 @@ from __future__ import annotations from abc import ABC from abc import abstractmethod from dataclasses import dataclass +from functools import lru_cache from pathlib import Path -from shutil import copy from typing import Any +from typing import Callable from typing import cast from typing import TypeVar @@ -19,23 +20,20 @@ except ModuleNotFoundError: from mlia.backend.registry import registry as backend_registry from mlia.core.common import AdviceCategory +from mlia.core.advisor import InferenceAdvisor from mlia.utils.filesystem import get_mlia_target_profiles_dir -def get_profile_file(target_profile: str | Path) -> Path: - """Get the target profile toml file.""" - if not target_profile: - raise Exception("Target profile is not provided.") +def get_builtin_profile_path(target_profile: str) -> Path: + """ + Construct the path to the built-in target profile file. - profile_file = Path(get_mlia_target_profiles_dir() / f"{target_profile}.toml") - if not profile_file.is_file(): - profile_file = Path(target_profile) - - if not profile_file.exists(): - raise Exception(f"File not found: {profile_file}.") - return profile_file + No checks are performed. + """ + return get_mlia_target_profiles_dir() / f"{target_profile}.toml" +@lru_cache def load_profile(path: str | Path) -> dict[str, Any]: """Get settings for the provided target profile.""" with open(path, "rb") as file: @@ -55,24 +53,12 @@ def get_builtin_supported_profile_names() -> list[str]: ) -def get_target(target_profile: str | Path) -> str: - """Return target for the provided target_profile.""" - profile_file = get_profile_file(target_profile) - profile = load_profile(profile_file) - return cast(str, profile["target"]) +BUILTIN_SUPPORTED_PROFILE_NAMES = get_builtin_supported_profile_names() -def copy_profile_file_to_output_dir( - target_profile: str | Path, output_dir: str | Path -) -> bool: - """Copy the target profile file to output directory.""" - profile_file_path = get_profile_file(target_profile) - output_file_path = f"{output_dir}/{profile_file_path.stem}.toml" - try: - copy(profile_file_path, output_file_path) - return True - except OSError as err: - raise RuntimeError("Failed to copy profile file:", err.strerror) from err +def is_builtin_profile(profile_name: str | Path) -> bool: + """Check if the given profile name belongs to a built-in profile.""" + return profile_name in BUILTIN_SUPPORTED_PROFILE_NAMES T = TypeVar("T", bound="TargetProfile") @@ -88,21 +74,29 @@ class TargetProfile(ABC): @classmethod def load(cls: type[T], path: str | Path) -> T: """Load and verify a target profile from file and return new instance.""" - profile = load_profile(path) + profile_data = load_profile(path) try: - new_instance = cls(**profile) + new_instance = cls.load_json_data(profile_data) except KeyError as ex: raise KeyError(f"Missing key in file {path}.") from ex - new_instance.verify() + return new_instance + @classmethod + def load_json_data(cls: type[T], profile_data: dict) -> T: + """Load a target profile from the JSON data.""" + new_instance = cls(**profile_data) + new_instance.verify() return new_instance @classmethod - def load_profile(cls: type[T], target_profile: str) -> T: - """Load a target profile by name.""" - profile_file = get_profile_file(target_profile) + def load_profile(cls: type[T], target_profile: str | Path) -> T: + """Load a target profile from built-in target profile name or file path.""" + if is_builtin_profile(target_profile): + profile_file = get_builtin_profile_path(cast(str, target_profile)) + else: + profile_file = Path(target_profile) return cls.load(profile_file) def save(self, path: str | Path) -> None: @@ -125,6 +119,9 @@ class TargetInfo: """Collect information about supported targets.""" supported_backends: list[str] + default_backends: list[str] + advisor_factory_func: Callable[..., InferenceAdvisor] + target_profile_cls: type[TargetProfile] def __str__(self) -> str: """List supported backends.""" @@ -135,7 +132,8 @@ class TargetInfo: ) -> bool: """Check if any of the supported backends support this kind of advice.""" return any( - backend_registry.items[name].is_supported(advice, check_system) + name in backend_registry.items + and backend_registry.items[name].is_supported(advice, check_system) for name in self.supported_backends ) @@ -146,5 +144,6 @@ class TargetInfo: return [ name for name in self.supported_backends - if backend_registry.items[name].is_supported(advice, check_system) + if name in backend_registry.items + and backend_registry.items[name].is_supported(advice, check_system) ] diff --git a/src/mlia/target/cortex_a/__init__.py b/src/mlia/target/cortex_a/__init__.py index f686bfc..87f268a 100644 --- a/src/mlia/target/cortex_a/__init__.py +++ b/src/mlia/target/cortex_a/__init__.py @@ -1,7 +1,17 @@ # SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Cortex-A target module.""" +from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor +from mlia.target.cortex_a.config import CortexAConfiguration from mlia.target.registry import registry from mlia.target.registry import TargetInfo -registry.register("cortex-a", TargetInfo(["ArmNNTFLiteDelegate"])) +registry.register( + "cortex-a", + TargetInfo( + supported_backends=["ArmNNTFLiteDelegate"], + default_backends=["ArmNNTFLiteDelegate"], + advisor_factory_func=configure_and_get_cortexa_advisor, + target_profile_cls=CortexAConfiguration, + ), +) diff --git a/src/mlia/target/cortex_a/advisor.py b/src/mlia/target/cortex_a/advisor.py index 518c9f1..a093784 100644 --- a/src/mlia/target/cortex_a/advisor.py +++ b/src/mlia/target/cortex_a/advisor.py @@ -5,6 +5,7 @@ from __future__ import annotations from pathlib import Path from typing import Any +from typing import cast from mlia.core.advice_generation import AdviceProducer from mlia.core.advisor import DefaultInferenceAdvisor @@ -21,6 +22,7 @@ from mlia.target.cortex_a.data_analysis import CortexADataAnalyzer from mlia.target.cortex_a.data_collection import CortexAOperatorCompatibility from mlia.target.cortex_a.events import CortexAAdvisorStartedEvent from mlia.target.cortex_a.handlers import CortexAEventHandler +from mlia.target.registry import profile class CortexAInferenceAdvisor(DefaultInferenceAdvisor): @@ -59,7 +61,7 @@ class CortexAInferenceAdvisor(DefaultInferenceAdvisor): return [ CortexAAdvisorStartedEvent( - model, CortexAConfiguration.load_profile(target_profile) + model, cast(CortexAConfiguration, profile(target_profile)) ), ] diff --git a/src/mlia/target/ethos_u/__init__.py b/src/mlia/target/ethos_u/__init__.py index d53be53..6b6777d 100644 --- a/src/mlia/target/ethos_u/__init__.py +++ b/src/mlia/target/ethos_u/__init__.py @@ -1,8 +1,23 @@ # SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Ethos-U target module.""" +from mlia.backend.corstone import CORSTONE_PRIORITY +from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor +from mlia.target.ethos_u.config import EthosUConfiguration +from mlia.target.ethos_u.config import get_default_ethos_u_backends from mlia.target.registry import registry from mlia.target.registry import TargetInfo -registry.register("ethos-u55", TargetInfo(["Vela", "Corstone-300", "Corstone-310"])) -registry.register("ethos-u65", TargetInfo(["Vela", "Corstone-300", "Corstone-310"])) +SUPPORTED_BACKENDS_PRIORITY = ["Vela", *CORSTONE_PRIORITY] + + +for ethos_u in ("ethos-u55", "ethos-u65"): + registry.register( + ethos_u, + TargetInfo( + supported_backends=SUPPORTED_BACKENDS_PRIORITY, + default_backends=get_default_ethos_u_backends(SUPPORTED_BACKENDS_PRIORITY), + advisor_factory_func=configure_and_get_ethosu_advisor, + target_profile_cls=EthosUConfiguration, + ), + ) diff --git a/src/mlia/target/ethos_u/advisor.py b/src/mlia/target/ethos_u/advisor.py index 225fd87..5f23fdd 100644 --- a/src/mlia/target/ethos_u/advisor.py +++ b/src/mlia/target/ethos_u/advisor.py @@ -5,6 +5,7 @@ from __future__ import annotations from pathlib import Path from typing import Any +from typing import cast from mlia.core.advice_generation import AdviceProducer from mlia.core.advisor import DefaultInferenceAdvisor @@ -25,6 +26,7 @@ from mlia.target.ethos_u.data_collection import EthosUOptimizationPerformance from mlia.target.ethos_u.data_collection import EthosUPerformance from mlia.target.ethos_u.events import EthosUAdvisorStartedEvent from mlia.target.ethos_u.handlers import EthosUEventHandler +from mlia.target.registry import profile from mlia.utils.types import is_list_of @@ -96,7 +98,7 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor): def _get_device_cfg(self, context: Context) -> EthosUConfiguration: """Get device configuration.""" target_profile = self.get_target_profile(context) - return EthosUConfiguration.load_profile(target_profile) + return cast(EthosUConfiguration, profile(target_profile)) def _get_optimization_settings(self, context: Context) -> list[list[dict]]: """Get optimization settings.""" diff --git a/src/mlia/target/ethos_u/config.py b/src/mlia/target/ethos_u/config.py index eb5691d..d1a2c7a 100644 --- a/src/mlia/target/ethos_u/config.py +++ b/src/mlia/target/ethos_u/config.py @@ -6,12 +6,13 @@ from __future__ import annotations import logging from typing import Any +from mlia.backend.corstone import is_corstone_backend +from mlia.backend.manager import get_available_backends from mlia.backend.vela.compiler import resolve_compiler_config from mlia.backend.vela.compiler import VelaCompilerOptions from mlia.target.config import TargetProfile from mlia.utils.filesystem import get_vela_config - logger = logging.getLogger(__name__) @@ -67,3 +68,22 @@ class EthosUConfiguration(TargetProfile): def __repr__(self) -> str: """Return string representation.""" return f"" + + +def get_default_ethos_u_backends( + supported_backends_priority_order: list[str], +) -> list[str]: + """Return default backends for Ethos-U targets.""" + available_backends = get_available_backends() + + default_backends = [] + corstone_added = False + for backend in supported_backends_priority_order: + if backend not in available_backends: + continue + if is_corstone_backend(backend): + if corstone_added: + continue # only add one Corstone backend + corstone_added = True + default_backends.append(backend) + return default_backends diff --git a/src/mlia/target/registry.py b/src/mlia/target/registry.py index 4870fc8..9fccecb 100644 --- a/src/mlia/target/registry.py +++ b/src/mlia/target/registry.py @@ -3,17 +3,78 @@ """Target module.""" from __future__ import annotations +from functools import lru_cache +from pathlib import Path +from typing import cast + from mlia.backend.config import BackendType from mlia.backend.manager import get_installation_manager from mlia.backend.registry import registry as backend_registry from mlia.core.common import AdviceCategory from mlia.core.reporting import Column from mlia.core.reporting import Table +from mlia.target.config import BUILTIN_SUPPORTED_PROFILE_NAMES +from mlia.target.config import get_builtin_profile_path +from mlia.target.config import is_builtin_profile +from mlia.target.config import load_profile from mlia.target.config import TargetInfo +from mlia.target.config import TargetProfile from mlia.utils.registry import Registry + +class TargetRegistry(Registry[TargetInfo]): + """Registry for targets.""" + + def register(self, name: str, item: TargetInfo) -> bool: + """Register an item: returns `False` if already registered.""" + assert all( + backend in backend_registry.items for backend in item.supported_backends + ) + return super().register(name, item) + + # All supported targets are required to be registered here. -registry = Registry[TargetInfo]() +registry = TargetRegistry() + + +def builtin_profile_names() -> list[str]: + """Return a list of built-in profile names (not file paths).""" + return BUILTIN_SUPPORTED_PROFILE_NAMES + + +@lru_cache +def profile(target_profile: str | Path) -> TargetProfile: + """Get the target profile data (built-in or custom file).""" + if not target_profile: + raise ValueError("No valid target profile was provided.") + if is_builtin_profile(target_profile): + profile_file = get_builtin_profile_path(cast(str, target_profile)) + profile_ = create_target_profile(profile_file) + else: + profile_file = Path(target_profile) + if profile_file.is_file(): + profile_ = create_target_profile(profile_file) + else: + raise ValueError( + f"Profile '{target_profile}' is neither a valid built-in " + "target profile name or a valid file path." + ) + + return profile_ + + +def get_target(target_profile: str | Path) -> str: + """Return target for the provided target_profile.""" + return profile(target_profile).target + + +@lru_cache +def create_target_profile(path: Path) -> TargetProfile: + """Create a new instance of a TargetProfile from the file.""" + profile_data = load_profile(path) + target = profile_data["target"] + target_info = registry.items[target] + return target_info.target_profile_cls.load_json_data(profile_data) def supported_advice(target: str) -> list[AdviceCategory]: @@ -29,6 +90,11 @@ def supported_backends(target: str) -> list[str]: return registry.items[target].filter_supported_backends(check_system=False) +def default_backends(target: str) -> list[str]: + """Get a list of default backends for the given target.""" + return registry.items[target].default_backends + + def get_backend_to_supported_targets() -> dict[str, list]: """Get a dict that maps a list of supported targets given backend.""" targets = dict(registry.items) diff --git a/src/mlia/target/tosa/__init__.py b/src/mlia/target/tosa/__init__.py index 06bf1a9..3830ce5 100644 --- a/src/mlia/target/tosa/__init__.py +++ b/src/mlia/target/tosa/__init__.py @@ -3,5 +3,15 @@ """TOSA target module.""" from mlia.target.registry import registry from mlia.target.registry import TargetInfo +from mlia.target.tosa.advisor import configure_and_get_tosa_advisor +from mlia.target.tosa.config import TOSAConfiguration -registry.register("tosa", TargetInfo(["tosa-checker"])) +registry.register( + "tosa", + TargetInfo( + supported_backends=["tosa-checker"], + default_backends=["tosa-checker"], + advisor_factory_func=configure_and_get_tosa_advisor, + target_profile_cls=TOSAConfiguration, + ), +) diff --git a/src/mlia/target/tosa/advisor.py b/src/mlia/target/tosa/advisor.py index 5588d0f..5fb18ed 100644 --- a/src/mlia/target/tosa/advisor.py +++ b/src/mlia/target/tosa/advisor.py @@ -5,6 +5,7 @@ from __future__ import annotations from pathlib import Path from typing import Any +from typing import cast from mlia.core.advice_generation import AdviceCategory from mlia.core.advice_generation import AdviceProducer @@ -17,6 +18,7 @@ from mlia.core.data_collection import DataCollector from mlia.core.events import Event from mlia.core.metadata import MLIAMetadata from mlia.core.metadata import ModelMetadata +from mlia.target.registry import profile from mlia.target.tosa.advice_generation import TOSAAdviceProducer from mlia.target.tosa.config import TOSAConfiguration from mlia.target.tosa.data_analysis import TOSADataAnalyzer @@ -66,7 +68,7 @@ class TOSAInferenceAdvisor(DefaultInferenceAdvisor): return [ TOSAAdvisorStartedEvent( model, - TOSAConfiguration.load_profile(target_profile), + cast(TOSAConfiguration, profile(target_profile)), MetadataDisplay( TOSAMetadata("tosa-checker"), MLIAMetadata("mlia"), diff --git a/tests/conftest.py b/tests/conftest.py index d797869..55b296f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -83,8 +83,11 @@ def fixture_test_models_path( save_tflite_model(tflite_model, tflite_model_path) tflite_vela_model = tmp_path / "test_model_vela.tflite" - device = EthosUConfiguration.load_profile("ethos-u55-256") - optimize_model(tflite_model_path, device.compiler_options, tflite_vela_model) + + target_profile = EthosUConfiguration.load_profile("ethos-u55-256") + optimize_model( + tflite_model_path, target_profile.compiler_options, tflite_vela_model + ) tf.saved_model.save(keras_model, str(tmp_path / "tf_model_test_model")) diff --git a/tests/test_api.py b/tests/test_api.py index 251d5ac..b40c55b 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -19,12 +19,8 @@ from mlia.target.tosa.advisor import TOSAInferenceAdvisor def test_get_advice_no_target_provided(test_keras_model: Path) -> None: """Test getting advice when no target provided.""" - with pytest.raises(Exception, match="Target profile is not provided"): - get_advice( - None, # type:ignore - test_keras_model, - {"compatibility"}, - ) + with pytest.raises(Exception, match="No valid target profile was provided."): + get_advice(None, test_keras_model, {"compatibility"}) # type: ignore def test_get_advice_wrong_category(test_keras_model: Path) -> None: diff --git a/tests/test_cli_command_validators.py b/tests/test_cli_command_validators.py index cd048ee..29813f4 100644 --- a/tests/test_cli_command_validators.py +++ b/tests/test_cli_command_validators.py @@ -5,7 +5,6 @@ from __future__ import annotations import argparse from contextlib import ExitStack -from unittest.mock import MagicMock import pytest @@ -93,72 +92,66 @@ def test_validate_check_target_profile( @pytest.mark.parametrize( - "input_target_profile, input_backends, throws_exception," - "exception_message, output_backends", + ( + "input_target_profile", + "input_backends", + "throws_exception", + "exception_message", + "output_backends", + ), [ [ "tosa", - ["Vela"], - True, - "Vela backend not supported with target-profile tosa.", + ["tosa-checker"], + False, None, + ["tosa-checker"], ], [ "tosa", - ["Corstone-300, Vela"], + ["Corstone-310"], True, - "Corstone-300, Vela backend not supported with target-profile tosa.", + "Backend Corstone-310 not supported with target-profile tosa.", None, ], [ "cortex-a", - ["Corstone-310", "tosa-checker"], - True, - "Corstone-310, tosa-checker backend not supported " - "with target-profile cortex-a.", + ["ArmNNTFLiteDelegate"], + False, None, + ["ArmNNTFLiteDelegate"], ], [ - "ethos-u55-256", - ["tosa-checker", "Corstone-310"], + "cortex-a", + ["tosa-checker"], True, - "tosa-checker backend not supported with target-profile ethos-u55-256.", + "Backend tosa-checker not supported with target-profile cortex-a.", None, ], - ["tosa", None, False, None, ["tosa-checker"]], - ["cortex-a", None, False, None, ["ArmNNTFLiteDelegate"]], - ["tosa", ["tosa-checker"], False, None, ["tosa-checker"]], - ["cortex-a", ["ArmNNTFLiteDelegate"], False, None, ["ArmNNTFLiteDelegate"]], [ "ethos-u55-256", - ["Vela", "Corstone-300"], + ["Vela", "Corstone-310"], False, None, - ["Vela", "Corstone-300"], + ["Vela", "Corstone-310"], ], [ - "ethos-u55-256", - None, - False, + "ethos-u65-256", + ["Vela", "Corstone-310", "tosa-checker"], + True, + "Backend tosa-checker not supported with target-profile ethos-u65-256.", None, - ["Vela", "Corstone-300"], ], ], ) def test_validate_backend( - monkeypatch: pytest.MonkeyPatch, input_target_profile: str, - input_backends: list[str] | None, + input_backends: list[str], throws_exception: bool, exception_message: str, output_backends: list[str] | None, ) -> None: """Test backend validation with target-profiles and backends.""" - monkeypatch.setattr( - "mlia.cli.config.get_available_backends", - MagicMock(return_value=["Vela", "Corstone-300"]), - ) - exit_stack = ExitStack() if throws_exception: exit_stack.enter_context( diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py index 61cc5a6..f3213c4 100644 --- a/tests/test_cli_commands.py +++ b/tests/test_cli_commands.py @@ -33,7 +33,13 @@ def test_performance_unknown_target( sample_context: ExecutionContext, test_tflite_model: Path ) -> None: """Test that command should fail if unknown target passed.""" - with pytest.raises(Exception, match=r"File not found:*"): + with pytest.raises( + Exception, + match=( + r"Profile 'unknown' is neither a valid built-in target profile " + r"name or a valid file path." + ), + ): check( sample_context, model=str(test_tflite_model), diff --git a/tests/test_cli_config.py b/tests/test_cli_config.py deleted file mode 100644 index 8494d73..0000000 --- a/tests/test_cli_config.py +++ /dev/null @@ -1,48 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Tests for cli.config module.""" -from __future__ import annotations - -from unittest.mock import MagicMock - -import pytest - -from mlia.cli.config import get_default_backends - - -@pytest.mark.parametrize( - "available_backends, expected_default_backends", - [ - [["Vela"], ["Vela"]], - [["Corstone-300"], ["Corstone-300"]], - [["Corstone-310"], ["Corstone-310"]], - [["Corstone-300", "Corstone-310"], ["Corstone-310"]], - [["Vela", "Corstone-300", "Corstone-310"], ["Vela", "Corstone-310"]], - [ - ["Vela", "Corstone-300", "Corstone-310", "New backend"], - ["Vela", "Corstone-310", "New backend"], - ], - [ - ["Vela", "Corstone-300", "New backend"], - ["Vela", "Corstone-300", "New backend"], - ], - [["ArmNNTFLiteDelegate"], ["ArmNNTFLiteDelegate"]], - [["tosa-checker"], ["tosa-checker"]], - [ - ["ArmNNTFLiteDelegate", "Corstone-300"], - ["ArmNNTFLiteDelegate", "Corstone-300"], - ], - ], -) -def test_get_default_backends( - monkeypatch: pytest.MonkeyPatch, - available_backends: list[str], - expected_default_backends: list[str], -) -> None: - """Test function get_default backends.""" - monkeypatch.setattr( - "mlia.cli.config.get_available_backends", - MagicMock(return_value=available_backends), - ) - - assert get_default_backends() == expected_default_backends diff --git a/tests/test_cli_helpers.py b/tests/test_cli_helpers.py index 8f7e4b0..6d19207 100644 --- a/tests/test_cli_helpers.py +++ b/tests/test_cli_helpers.py @@ -3,11 +3,13 @@ """Tests for the helper classes.""" from __future__ import annotations +from pathlib import Path from typing import Any import pytest from mlia.cli.helpers import CLIActionResolver +from mlia.cli.helpers import copy_profile_file_to_output_dir from mlia.nn.tensorflow.optimizations.select import OptimizationSettings @@ -139,3 +141,12 @@ class TestCliActionResolver: """Test checking operator compatibility info.""" resolver = CLIActionResolver(args) assert resolver.check_operator_compatibility() == expected_result + + +def test_copy_profile_file_to_output_dir(tmp_path: Path) -> None: + """Test if the profile file is copied into the output directory.""" + test_target_profile_name = "ethos-u55-128" + test_file_path = Path(f"{tmp_path}/{test_target_profile_name}.toml") + + copy_profile_file_to_output_dir(test_target_profile_name, tmp_path) + assert Path.is_file(test_file_path) diff --git a/tests/test_target_config.py b/tests/test_target_config.py index c6235a5..368d394 100644 --- a/tests/test_target_config.py +++ b/tests/test_target_config.py @@ -3,35 +3,28 @@ """Tests for the backend config module.""" from __future__ import annotations -from pathlib import Path - import pytest from mlia.backend.config import BackendConfiguration from mlia.backend.config import BackendType from mlia.backend.config import System from mlia.core.common import AdviceCategory -from mlia.target.config import copy_profile_file_to_output_dir +from mlia.target.config import BUILTIN_SUPPORTED_PROFILE_NAMES +from mlia.target.config import get_builtin_profile_path from mlia.target.config import get_builtin_supported_profile_names -from mlia.target.config import get_profile_file +from mlia.target.config import is_builtin_profile from mlia.target.config import load_profile from mlia.target.config import TargetInfo from mlia.target.config import TargetProfile +from mlia.target.cortex_a.advisor import CortexAInferenceAdvisor +from mlia.target.cortex_a.config import CortexAConfiguration from mlia.utils.registry import Registry -def test_copy_profile_file_to_output_dir(tmp_path: Path) -> None: - """Test if the profile file is copied into the output directory.""" - test_target_profile_name = "ethos-u55-128" - test_file_path = Path(f"{tmp_path}/{test_target_profile_name}.toml") - - copy_profile_file_to_output_dir(test_target_profile_name, tmp_path) - assert Path.is_file(test_file_path) - - -def test_get_builtin_supported_profile_names() -> None: - """Test profile names getter.""" - assert get_builtin_supported_profile_names() == [ +def test_builtin_supported_profile_names() -> None: + """Test built-in profile names.""" + assert BUILTIN_SUPPORTED_PROFILE_NAMES == get_builtin_supported_profile_names() + assert BUILTIN_SUPPORTED_PROFILE_NAMES == [ "cortex-a", "ethos-u55-128", "ethos-u55-256", @@ -39,23 +32,24 @@ def test_get_builtin_supported_profile_names() -> None: "ethos-u65-512", "tosa", ] + for profile_name in BUILTIN_SUPPORTED_PROFILE_NAMES: + assert is_builtin_profile(profile_name) + profile_file = get_builtin_profile_path(profile_name) + assert profile_file.is_file() -def test_get_profile_file() -> None: - """Test function 'get_profile_file'.""" - profile_file = get_profile_file("cortex-a") +def test_builtin_profile_files() -> None: + """Test function 'get_bulitin_profile_file'.""" + profile_file = get_builtin_profile_path("cortex-a") assert profile_file.is_file() - assert profile_file == get_profile_file(profile_file) - with pytest.raises(Exception): - get_profile_file("UNKNOWN") - with pytest.raises(Exception): - get_profile_file("") + profile_file = get_builtin_profile_path("UNKNOWN_FILE_THAT_DOES_NOT_EXIST") + assert not profile_file.exists() def test_load_profile() -> None: """Test getting profile data.""" - profile_file = get_profile_file("ethos-u55-256") + profile_file = get_builtin_profile_path("ethos-u55-256") assert load_profile(profile_file) == { "target": "ethos-u55", "mac": 256, @@ -80,6 +74,9 @@ def test_target_profile() -> None: profile = MyTargetProfile("AnyTarget") assert profile.target == "AnyTarget" + profile = MyTargetProfile.load_json_data({"target": "MySuperTarget"}) + assert profile.target == "MySuperTarget" + profile = MyTargetProfile("") with pytest.raises(ValueError): profile.verify() @@ -101,7 +98,12 @@ def test_target_info( supported: bool, ) -> None: """Test the class 'TargetInfo'.""" - info = TargetInfo(["backend"]) + info = TargetInfo( + ["backend"], + ["backend"], + CortexAInferenceAdvisor, + CortexAConfiguration, + ) backend_registry = Registry[BackendConfiguration]() backend_registry.register( @@ -116,3 +118,12 @@ def test_target_info( assert info.is_supported(advice, check_system) == supported assert bool(info.filter_supported_backends(advice, check_system)) == supported + + info = TargetInfo( + ["unknown_backend"], + ["unknown_backend"], + CortexAInferenceAdvisor, + CortexAConfiguration, + ) + assert not info.is_supported(advice, check_system) + assert not info.filter_supported_backends(advice, check_system) diff --git a/tests/test_target_ethos_u_data_analysis.py b/tests/test_target_ethos_u_data_analysis.py index e919f5d..8e63946 100644 --- a/tests/test_target_ethos_u_data_analysis.py +++ b/tests/test_target_ethos_u_data_analysis.py @@ -3,6 +3,8 @@ """Tests for Ethos-U data analysis module.""" from __future__ import annotations +from typing import cast + import pytest from mlia.backend.vela.compat import NpuSupported @@ -23,6 +25,7 @@ from mlia.target.ethos_u.performance import MemoryUsage from mlia.target.ethos_u.performance import NPUCycles from mlia.target.ethos_u.performance import OptimizationPerformanceMetrics from mlia.target.ethos_u.performance import PerformanceMetrics +from mlia.target.registry import profile def test_perf_metrics_diff() -> None: @@ -84,7 +87,7 @@ def test_perf_metrics_diff() -> None: [ OptimizationPerformanceMetrics( PerformanceMetrics( - EthosUConfiguration.load_profile("ethos-u55-256"), + cast(EthosUConfiguration, profile("ethos-u55-256")), NPUCycles(1, 2, 3, 4, 5, 6), # memory metrics are in kilobytes MemoryUsage(*[i * 1024 for i in range(1, 6)]), # type: ignore @@ -95,7 +98,7 @@ def test_perf_metrics_diff() -> None: OptimizationSettings("pruning", 0.5, None), ], PerformanceMetrics( - EthosUConfiguration.load_profile("ethos-u55-256"), + cast(EthosUConfiguration, profile("ethos-u55-256")), NPUCycles(1, 2, 3, 4, 5, 6), # memory metrics are in kilobytes MemoryUsage( @@ -127,7 +130,7 @@ def test_perf_metrics_diff() -> None: [ OptimizationPerformanceMetrics( PerformanceMetrics( - EthosUConfiguration.load_profile("ethos-u55-256"), + cast(EthosUConfiguration, profile("ethos-u55-256")), NPUCycles(1, 2, 3, 4, 5, 6), # memory metrics are in kilobytes MemoryUsage(*[i * 1024 for i in range(1, 6)]), # type: ignore diff --git a/tests/test_target_ethos_u_reporters.py b/tests/test_target_ethos_u_reporters.py index 0c5764e..9707dff 100644 --- a/tests/test_target_ethos_u_reporters.py +++ b/tests/test_target_ethos_u_reporters.py @@ -3,6 +3,8 @@ """Tests for reports module.""" from __future__ import annotations +from typing import cast + import pytest from mlia.backend.vela.compat import NpuSupported @@ -12,6 +14,7 @@ from mlia.core.reporting import Table from mlia.target.ethos_u.config import EthosUConfiguration from mlia.target.ethos_u.reporters import report_device_details from mlia.target.ethos_u.reporters import report_operators +from mlia.target.registry import profile from mlia.utils.console import remove_ascii_codes @@ -118,7 +121,7 @@ def test_report_operators( "device, expected_plain_text, expected_json_dict", [ [ - EthosUConfiguration.load_profile("ethos-u55-256"), + cast(EthosUConfiguration, profile("ethos-u55-256")), """Device information: Target ethos-u55 MAC 256 diff --git a/tests/test_target_registry.py b/tests/test_target_registry.py index 5012148..2cbd97d 100644 --- a/tests/test_target_registry.py +++ b/tests/test_target_registry.py @@ -6,7 +6,11 @@ from __future__ import annotations import pytest from mlia.core.common import AdviceCategory +from mlia.target.config import get_builtin_profile_path from mlia.target.registry import all_supported_backends +from mlia.target.registry import default_backends +from mlia.target.registry import is_supported +from mlia.target.registry import profile from mlia.target.registry import registry from mlia.target.registry import supported_advice from mlia.target.registry import supported_backends @@ -56,6 +60,23 @@ def test_supported_advice( assert all(advice in supported for advice in expected_advices) +@pytest.mark.parametrize( + ("backend", "target", "expected_result"), + ( + ("ArmNNTFLiteDelegate", None, True), + ("ArmNNTFLiteDelegate", "cortex-a", True), + ("ArmNNTFLiteDelegate", "tosa", False), + ("Corstone-310", None, True), + ("Corstone-310", "ethos-u55", True), + ("Corstone-310", "ethos-u65", True), + ("Corstone-310", "cortex-a", False), + ), +) +def test_is_supported(backend: str, target: str | None, expected_result: bool) -> None: + """Test function is_supported().""" + assert is_supported(backend, target) == expected_result + + @pytest.mark.parametrize( ("target_name", "expected_backends"), ( @@ -92,3 +113,39 @@ def test_all_supported_backends() -> None: "Corstone-310", "Corstone-300", } + + +@pytest.mark.parametrize( + ("target", "expected_default_backends", "is_subset_only"), + [ + ["cortex-a", ["ArmNNTFLiteDelegate"], False], + ["tosa", ["tosa-checker"], False], + ["ethos-u55", ["Vela"], True], + ["ethos-u65", ["Vela"], True], + ], +) +def test_default_backends( + target: str, + expected_default_backends: list[str], + is_subset_only: bool, +) -> None: + """Test function default_backends().""" + if is_subset_only: + assert set(expected_default_backends).issubset(default_backends(target)) + else: + assert default_backends(target) == expected_default_backends + + +@pytest.mark.parametrize( + "target_profile", ("cortex-a", "tosa", "ethos-u55-128", "ethos-u65-256") +) +def test_profile(target_profile: str) -> None: + """Test function profile().""" + # Test loading by built-in profile name + cfg = profile(target_profile) + assert target_profile.startswith(cfg.target) + + # Test loading the file directly + profile_file = get_builtin_profile_path(target_profile) + cfg = profile(profile_file) + assert target_profile.startswith(cfg.target) diff --git a/tests_e2e/test_e2e.py b/tests_e2e/test_e2e.py index d09d0ab..37a1833 100644 --- a/tests_e2e/test_e2e.py +++ b/tests_e2e/test_e2e.py @@ -20,7 +20,7 @@ from typing import Sequence import pytest -from mlia.cli.config import get_available_backends +from mlia.backend.manager import get_available_backends from mlia.cli.main import get_commands from mlia.cli.main import get_possible_command_names from mlia.cli.main import init_parser -- cgit v1.2.1