diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/mlia/api.py | 21 | ||||
-rw-r--r-- | src/mlia/backend/corstone/__init__.py | 30 | ||||
-rw-r--r-- | src/mlia/backend/manager.py | 13 | ||||
-rw-r--r-- | src/mlia/cli/command_validators.py | 20 | ||||
-rw-r--r-- | src/mlia/cli/commands.py | 2 | ||||
-rw-r--r-- | src/mlia/cli/config.py | 69 | ||||
-rw-r--r-- | src/mlia/cli/helpers.py | 24 | ||||
-rw-r--r-- | src/mlia/cli/main.py | 15 | ||||
-rw-r--r-- | src/mlia/cli/options.py | 43 | ||||
-rw-r--r-- | src/mlia/target/config.py | 71 | ||||
-rw-r--r-- | src/mlia/target/cortex_a/__init__.py | 12 | ||||
-rw-r--r-- | src/mlia/target/cortex_a/advisor.py | 4 | ||||
-rw-r--r-- | src/mlia/target/ethos_u/__init__.py | 19 | ||||
-rw-r--r-- | src/mlia/target/ethos_u/advisor.py | 4 | ||||
-rw-r--r-- | src/mlia/target/ethos_u/config.py | 22 | ||||
-rw-r--r-- | src/mlia/target/registry.py | 68 | ||||
-rw-r--r-- | src/mlia/target/tosa/__init__.py | 12 | ||||
-rw-r--r-- | src/mlia/target/tosa/advisor.py | 4 |
18 files changed, 277 insertions, 176 deletions
diff --git a/src/mlia/api.py b/src/mlia/api.py index fd5fc13..7adae48 100644 --- a/src/mlia/api.py +++ b/src/mlia/api.py @@ -10,10 +10,8 @@ from typing import Any from mlia.core.advisor import InferenceAdvisor from mlia.core.common import AdviceCategory from mlia.core.context import ExecutionContext -from mlia.target.config import get_target -from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor -from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor -from mlia.target.tosa.advisor import configure_and_get_tosa_advisor +from mlia.target.registry import profile +from mlia.target.registry import registry as target_registry logger = logging.getLogger(__name__) @@ -84,19 +82,8 @@ def get_advisor( **extra_args: Any, ) -> InferenceAdvisor: """Find appropriate advisor for the target.""" - target_factories = { - "ethos-u55": configure_and_get_ethosu_advisor, - "ethos-u65": configure_and_get_ethosu_advisor, - "tosa": configure_and_get_tosa_advisor, - "cortex-a": configure_and_get_cortexa_advisor, - } - - try: - target = get_target(target_profile) - factory_function = target_factories[target] - except KeyError as err: - raise Exception(f"Unsupported profile {target_profile}") from err - + target = profile(target_profile).target + factory_function = target_registry.items[target].advisor_factory_func return factory_function( context, target_profile, diff --git a/src/mlia/backend/corstone/__init__.py b/src/mlia/backend/corstone/__init__.py index 36f74ee..b59ab65 100644 --- a/src/mlia/backend/corstone/__init__.py +++ b/src/mlia/backend/corstone/__init__.py @@ -7,24 +7,20 @@ from mlia.backend.config import System from mlia.backend.registry import registry from mlia.core.common import AdviceCategory -registry.register( - "Corstone-300", - BackendConfiguration( - supported_advice=[AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION], - supported_systems=[System.LINUX_AMD64], - backend_type=BackendType.CUSTOM, - ), -) -registry.register( - "Corstone-310", - BackendConfiguration( - supported_advice=[AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION], - supported_systems=[System.LINUX_AMD64], - backend_type=BackendType.CUSTOM, - ), -) +# List of mutually exclusive Corstone backends ordered by priority +CORSTONE_PRIORITY = ("Corstone-310", "Corstone-300") + +for corstone_name in CORSTONE_PRIORITY: + registry.register( + corstone_name, + BackendConfiguration( + supported_advice=[AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION], + supported_systems=[System.LINUX_AMD64], + backend_type=BackendType.CUSTOM, + ), + ) def is_corstone_backend(backend_name: str) -> bool: """Check if backend belongs to Corstone.""" - return backend_name in ["Corstone-300", "Corstone-310"] + return backend_name in CORSTONE_PRIORITY diff --git a/src/mlia/backend/manager.py b/src/mlia/backend/manager.py index b0fa919..d953b2d 100644 --- a/src/mlia/backend/manager.py +++ b/src/mlia/backend/manager.py @@ -9,11 +9,13 @@ from abc import abstractmethod from pathlib import Path from typing import Callable +from mlia.backend.config import BackendType from mlia.backend.corstone.install import get_corstone_installations from mlia.backend.install import DownloadAndInstall from mlia.backend.install import Installation from mlia.backend.install import InstallationType from mlia.backend.install import InstallFromPath +from mlia.backend.registry import registry as backend_registry from mlia.backend.tosa_checker.install import get_tosa_backend_installation from mlia.core.errors import ConfigurationError from mlia.core.errors import InternalError @@ -279,3 +281,14 @@ def get_installation_manager(noninteractive: bool = False) -> InstallationManage backends.append(get_tosa_backend_installation()) return DefaultInstallationManager(backends, noninteractive=noninteractive) + + +def get_available_backends() -> list[str]: + """Return list of the available backends.""" + manager = get_installation_manager() + available_backends = [ + backend + for backend, cfg in backend_registry.items.items() + if cfg.type == BackendType.BUILTIN or manager.backend_installed(backend) + ] + return available_backends diff --git a/src/mlia/cli/command_validators.py b/src/mlia/cli/command_validators.py index 23101e0..a0f5433 100644 --- a/src/mlia/cli/command_validators.py +++ b/src/mlia/cli/command_validators.py @@ -7,8 +7,8 @@ import argparse import logging import sys -from mlia.cli.config import get_default_backends_dict -from mlia.target.config import get_target +from mlia.target.registry import default_backends +from mlia.target.registry import get_target from mlia.target.registry import supported_backends logger = logging.getLogger(__name__) @@ -26,22 +26,18 @@ def validate_backend( target = get_target(target_profile) if not backend: - return get_default_backends_dict()[target] + return default_backends(target) - compatible_backends = supported_backends(target) + compatible_backends = list(map(normalize_string, supported_backends(target))) + backends = {normalize_string(b): b for b in backend} - nor_backend = list(map(normalize_string, backend)) - nor_compat_backend = list(map(normalize_string, compatible_backends)) - - incompatible_backends = [ - backend[i] for i, x in enumerate(nor_backend) if x not in nor_compat_backend - ] + incompatible_backends = [b for b in backends if b not in compatible_backends] # Throw an error if any unsupported backends are used if incompatible_backends: raise argparse.ArgumentError( None, - f"{', '.join(incompatible_backends)} backend not supported " - f"with target-profile {target_profile}.", + f"Backend {', '.join(backends[b] for b in incompatible_backends)} " + f"not supported with target-profile {target_profile}.", ) return backend diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py index c17d571..27f5b2b 100644 --- a/src/mlia/cli/commands.py +++ b/src/mlia/cli/commands.py @@ -23,9 +23,9 @@ from pathlib import Path from mlia.api import ExecutionContext from mlia.api import get_advice +from mlia.backend.manager import get_installation_manager from mlia.cli.command_validators import validate_backend from mlia.cli.command_validators import validate_check_target_profile -from mlia.cli.config import get_installation_manager from mlia.cli.options import parse_optimization_parameters from mlia.utils.console import create_section_header diff --git a/src/mlia/cli/config.py b/src/mlia/cli/config.py deleted file mode 100644 index 433300c..0000000 --- a/src/mlia/cli/config.py +++ /dev/null @@ -1,69 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Environment configuration functions.""" -from __future__ import annotations - -import logging - -from mlia.backend.manager import get_installation_manager -from mlia.target.registry import all_supported_backends - -logger = logging.getLogger(__name__) - -DEFAULT_PRUNING_TARGET = 0.5 -DEFAULT_CLUSTERING_TARGET = 32 - - -def get_available_backends() -> list[str]: - """Return list of the available backends.""" - available_backends = ["Vela", "ArmNNTFLiteDelegate"] - - # Add backends using backend manager - manager = get_installation_manager() - available_backends.extend( - backend - for backend in all_supported_backends() - if manager.backend_installed(backend) - ) - - return available_backends - - -# List of mutually exclusive Corstone backends ordered by priority -_CORSTONE_EXCLUSIVE_PRIORITY = ("Corstone-310", "Corstone-300") -_NON_ETHOS_U_BACKENDS = ("tosa-checker", "ArmNNTFLiteDelegate") - - -def get_ethos_u_default_backends(backends: list[str]) -> list[str]: - """Get Ethos-U default backends for evaluation.""" - return [x for x in backends if x not in _NON_ETHOS_U_BACKENDS] - - -def get_default_backends() -> list[str]: - """Get default backends for evaluation.""" - backends = get_available_backends() - - # Filter backends to only include one Corstone backend - for corstone in _CORSTONE_EXCLUSIVE_PRIORITY: - if corstone in backends: - backends = [ - backend - for backend in backends - if backend == corstone or backend not in _CORSTONE_EXCLUSIVE_PRIORITY - ] - break - - return backends - - -def get_default_backends_dict() -> dict[str, list[str]]: - """Return default backends for all targets.""" - default_backends = get_default_backends() - ethos_u_defaults = get_ethos_u_default_backends(default_backends) - - return { - "ethos-u55": ethos_u_defaults, - "ethos-u65": ethos_u_defaults, - "tosa": ["tosa-checker"], - "cortex-a": ["ArmNNTFLiteDelegate"], - } diff --git a/src/mlia/cli/helpers.py b/src/mlia/cli/helpers.py index ac64581..576670b 100644 --- a/src/mlia/cli/helpers.py +++ b/src/mlia/cli/helpers.py @@ -1,14 +1,19 @@ # SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 -"""Module for various helper classes.""" +"""Module for various helpers.""" from __future__ import annotations +from pathlib import Path +from shutil import copy from typing import Any +from typing import cast from mlia.cli.options import get_target_profile_opts from mlia.core.helpers import ActionResolver from mlia.nn.tensorflow.optimizations.select import OptimizationSettings from mlia.nn.tensorflow.utils import is_keras_model +from mlia.target.config import get_builtin_profile_path +from mlia.target.config import is_builtin_profile from mlia.utils.types import is_list_of @@ -108,3 +113,20 @@ class CLIActionResolver(ActionResolver): model_path = self.args.get("model") return model_path, device_opts + + +def copy_profile_file_to_output_dir( + target_profile: str | Path, output_dir: str | Path +) -> bool: + """Copy the target profile file to the output directory.""" + profile_file_path = ( + get_builtin_profile_path(cast(str, target_profile)) + if is_builtin_profile(target_profile) + else Path(target_profile) + ) + output_file_path = f"{output_dir}/{profile_file_path.stem}.toml" + try: + copy(profile_file_path, output_file_path) + return True + except OSError as err: + raise RuntimeError("Failed to copy profile file:", err.strerror) from err diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py index 793e155..b3a9d4c 100644 --- a/src/mlia/cli/main.py +++ b/src/mlia/cli/main.py @@ -18,6 +18,7 @@ from mlia.cli.commands import check from mlia.cli.commands import optimize from mlia.cli.common import CommandInfo from mlia.cli.helpers import CLIActionResolver +from mlia.cli.helpers import copy_profile_file_to_output_dir from mlia.cli.options import add_backend_install_options from mlia.cli.options import add_backend_options from mlia.cli.options import add_backend_uninstall_options @@ -30,11 +31,11 @@ from mlia.cli.options import add_output_directory from mlia.cli.options import add_output_options from mlia.cli.options import add_target_options from mlia.cli.options import get_output_format +from mlia.core.common import AdviceCategory from mlia.core.context import ExecutionContext from mlia.core.errors import ConfigurationError from mlia.core.errors import InternalError from mlia.core.logging import setup_logging -from mlia.target.config import copy_profile_file_to_output_dir from mlia.target.registry import table as target_table @@ -59,7 +60,13 @@ def get_commands() -> list[CommandInfo]: [ add_output_directory, add_model_options, - add_target_options, + partial( + add_target_options, + supported_advice=[ + AdviceCategory.COMPATIBILITY, + AdviceCategory.PERFORMANCE, + ], + ), add_backend_options, add_check_category_options, add_output_options, @@ -72,7 +79,9 @@ def get_commands() -> list[CommandInfo]: [ add_output_directory, add_keras_model_options, - partial(add_target_options, profiles_to_skip=["tosa", "cortex-a"]), + partial( + add_target_options, supported_advice=[AdviceCategory.OPTIMIZATION] + ), partial( add_backend_options, backends_to_skip=["tosa-checker", "ArmNNTFLiteDelegate"], diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py index 421533a..8cd2935 100644 --- a/src/mlia/cli/options.py +++ b/src/mlia/cli/options.py @@ -7,13 +7,17 @@ import argparse from pathlib import Path from typing import Any from typing import Callable +from typing import Sequence from mlia.backend.corstone import is_corstone_backend -from mlia.cli.config import DEFAULT_CLUSTERING_TARGET -from mlia.cli.config import DEFAULT_PRUNING_TARGET -from mlia.cli.config import get_available_backends +from mlia.backend.manager import get_available_backends +from mlia.core.common import AdviceCategory from mlia.core.typing import OutputFormat -from mlia.target.config import get_builtin_supported_profile_names +from mlia.target.registry import builtin_profile_names +from mlia.target.registry import registry as target_registry + +DEFAULT_PRUNING_TARGET = 0.5 +DEFAULT_CLUSTERING_TARGET = 32 def add_check_category_options(parser: argparse.ArgumentParser) -> None: @@ -31,22 +35,39 @@ def add_check_category_options(parser: argparse.ArgumentParser) -> None: def add_target_options( parser: argparse.ArgumentParser, - profiles_to_skip: list[str] | None = None, + supported_advice: Sequence[AdviceCategory] | None = None, required: bool = True, ) -> None: """Add target specific options.""" - target_profiles = get_builtin_supported_profile_names() - if profiles_to_skip: - target_profiles = [tp for tp in target_profiles if tp not in profiles_to_skip] - - default_target_profile = "ethos-u55-256" + target_profiles = builtin_profile_names() + + if supported_advice: + + def is_advice_supported(profile: str, advice: Sequence[AdviceCategory]) -> bool: + """ + Collect all target profiles that support the advice. + + This means target profiles that... + - have the right target prefix, e.g. "ethos-u55..." to avoid loading + all target profiles + - support any of the required advice + """ + for target, info in target_registry.items.items(): + if profile.startswith(target): + return any(info.is_supported(adv) for adv in advice) + return False + + target_profiles = [ + profile + for profile in target_profiles + if is_advice_supported(profile, supported_advice) + ] target_group = parser.add_argument_group("target options") target_group.add_argument( "-t", "--target-profile", required=required, - default=default_target_profile, help="Built-in target profile or path to the custom target profile. " f"Built-in target profiles are {', '.join(target_profiles)}. " "Target profile that will set the target options " diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py index bf603dd..eb7ecff 100644 --- a/src/mlia/target/config.py +++ b/src/mlia/target/config.py @@ -6,9 +6,10 @@ from __future__ import annotations from abc import ABC from abc import abstractmethod from dataclasses import dataclass +from functools import lru_cache from pathlib import Path -from shutil import copy from typing import Any +from typing import Callable from typing import cast from typing import TypeVar @@ -19,23 +20,20 @@ except ModuleNotFoundError: from mlia.backend.registry import registry as backend_registry from mlia.core.common import AdviceCategory +from mlia.core.advisor import InferenceAdvisor from mlia.utils.filesystem import get_mlia_target_profiles_dir -def get_profile_file(target_profile: str | Path) -> Path: - """Get the target profile toml file.""" - if not target_profile: - raise Exception("Target profile is not provided.") +def get_builtin_profile_path(target_profile: str) -> Path: + """ + Construct the path to the built-in target profile file. - profile_file = Path(get_mlia_target_profiles_dir() / f"{target_profile}.toml") - if not profile_file.is_file(): - profile_file = Path(target_profile) - - if not profile_file.exists(): - raise Exception(f"File not found: {profile_file}.") - return profile_file + No checks are performed. + """ + return get_mlia_target_profiles_dir() / f"{target_profile}.toml" +@lru_cache def load_profile(path: str | Path) -> dict[str, Any]: """Get settings for the provided target profile.""" with open(path, "rb") as file: @@ -55,24 +53,12 @@ def get_builtin_supported_profile_names() -> list[str]: ) -def get_target(target_profile: str | Path) -> str: - """Return target for the provided target_profile.""" - profile_file = get_profile_file(target_profile) - profile = load_profile(profile_file) - return cast(str, profile["target"]) +BUILTIN_SUPPORTED_PROFILE_NAMES = get_builtin_supported_profile_names() -def copy_profile_file_to_output_dir( - target_profile: str | Path, output_dir: str | Path -) -> bool: - """Copy the target profile file to output directory.""" - profile_file_path = get_profile_file(target_profile) - output_file_path = f"{output_dir}/{profile_file_path.stem}.toml" - try: - copy(profile_file_path, output_file_path) - return True - except OSError as err: - raise RuntimeError("Failed to copy profile file:", err.strerror) from err +def is_builtin_profile(profile_name: str | Path) -> bool: + """Check if the given profile name belongs to a built-in profile.""" + return profile_name in BUILTIN_SUPPORTED_PROFILE_NAMES T = TypeVar("T", bound="TargetProfile") @@ -88,21 +74,29 @@ class TargetProfile(ABC): @classmethod def load(cls: type[T], path: str | Path) -> T: """Load and verify a target profile from file and return new instance.""" - profile = load_profile(path) + profile_data = load_profile(path) try: - new_instance = cls(**profile) + new_instance = cls.load_json_data(profile_data) except KeyError as ex: raise KeyError(f"Missing key in file {path}.") from ex - new_instance.verify() + return new_instance + @classmethod + def load_json_data(cls: type[T], profile_data: dict) -> T: + """Load a target profile from the JSON data.""" + new_instance = cls(**profile_data) + new_instance.verify() return new_instance @classmethod - def load_profile(cls: type[T], target_profile: str) -> T: - """Load a target profile by name.""" - profile_file = get_profile_file(target_profile) + def load_profile(cls: type[T], target_profile: str | Path) -> T: + """Load a target profile from built-in target profile name or file path.""" + if is_builtin_profile(target_profile): + profile_file = get_builtin_profile_path(cast(str, target_profile)) + else: + profile_file = Path(target_profile) return cls.load(profile_file) def save(self, path: str | Path) -> None: @@ -125,6 +119,9 @@ class TargetInfo: """Collect information about supported targets.""" supported_backends: list[str] + default_backends: list[str] + advisor_factory_func: Callable[..., InferenceAdvisor] + target_profile_cls: type[TargetProfile] def __str__(self) -> str: """List supported backends.""" @@ -135,7 +132,8 @@ class TargetInfo: ) -> bool: """Check if any of the supported backends support this kind of advice.""" return any( - backend_registry.items[name].is_supported(advice, check_system) + name in backend_registry.items + and backend_registry.items[name].is_supported(advice, check_system) for name in self.supported_backends ) @@ -146,5 +144,6 @@ class TargetInfo: return [ name for name in self.supported_backends - if backend_registry.items[name].is_supported(advice, check_system) + if name in backend_registry.items + and backend_registry.items[name].is_supported(advice, check_system) ] diff --git a/src/mlia/target/cortex_a/__init__.py b/src/mlia/target/cortex_a/__init__.py index f686bfc..87f268a 100644 --- a/src/mlia/target/cortex_a/__init__.py +++ b/src/mlia/target/cortex_a/__init__.py @@ -1,7 +1,17 @@ # SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Cortex-A target module.""" +from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor +from mlia.target.cortex_a.config import CortexAConfiguration from mlia.target.registry import registry from mlia.target.registry import TargetInfo -registry.register("cortex-a", TargetInfo(["ArmNNTFLiteDelegate"])) +registry.register( + "cortex-a", + TargetInfo( + supported_backends=["ArmNNTFLiteDelegate"], + default_backends=["ArmNNTFLiteDelegate"], + advisor_factory_func=configure_and_get_cortexa_advisor, + target_profile_cls=CortexAConfiguration, + ), +) diff --git a/src/mlia/target/cortex_a/advisor.py b/src/mlia/target/cortex_a/advisor.py index 518c9f1..a093784 100644 --- a/src/mlia/target/cortex_a/advisor.py +++ b/src/mlia/target/cortex_a/advisor.py @@ -5,6 +5,7 @@ from __future__ import annotations from pathlib import Path from typing import Any +from typing import cast from mlia.core.advice_generation import AdviceProducer from mlia.core.advisor import DefaultInferenceAdvisor @@ -21,6 +22,7 @@ from mlia.target.cortex_a.data_analysis import CortexADataAnalyzer from mlia.target.cortex_a.data_collection import CortexAOperatorCompatibility from mlia.target.cortex_a.events import CortexAAdvisorStartedEvent from mlia.target.cortex_a.handlers import CortexAEventHandler +from mlia.target.registry import profile class CortexAInferenceAdvisor(DefaultInferenceAdvisor): @@ -59,7 +61,7 @@ class CortexAInferenceAdvisor(DefaultInferenceAdvisor): return [ CortexAAdvisorStartedEvent( - model, CortexAConfiguration.load_profile(target_profile) + model, cast(CortexAConfiguration, profile(target_profile)) ), ] diff --git a/src/mlia/target/ethos_u/__init__.py b/src/mlia/target/ethos_u/__init__.py index d53be53..6b6777d 100644 --- a/src/mlia/target/ethos_u/__init__.py +++ b/src/mlia/target/ethos_u/__init__.py @@ -1,8 +1,23 @@ # SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Ethos-U target module.""" +from mlia.backend.corstone import CORSTONE_PRIORITY +from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor +from mlia.target.ethos_u.config import EthosUConfiguration +from mlia.target.ethos_u.config import get_default_ethos_u_backends from mlia.target.registry import registry from mlia.target.registry import TargetInfo -registry.register("ethos-u55", TargetInfo(["Vela", "Corstone-300", "Corstone-310"])) -registry.register("ethos-u65", TargetInfo(["Vela", "Corstone-300", "Corstone-310"])) +SUPPORTED_BACKENDS_PRIORITY = ["Vela", *CORSTONE_PRIORITY] + + +for ethos_u in ("ethos-u55", "ethos-u65"): + registry.register( + ethos_u, + TargetInfo( + supported_backends=SUPPORTED_BACKENDS_PRIORITY, + default_backends=get_default_ethos_u_backends(SUPPORTED_BACKENDS_PRIORITY), + advisor_factory_func=configure_and_get_ethosu_advisor, + target_profile_cls=EthosUConfiguration, + ), + ) diff --git a/src/mlia/target/ethos_u/advisor.py b/src/mlia/target/ethos_u/advisor.py index 225fd87..5f23fdd 100644 --- a/src/mlia/target/ethos_u/advisor.py +++ b/src/mlia/target/ethos_u/advisor.py @@ -5,6 +5,7 @@ from __future__ import annotations from pathlib import Path from typing import Any +from typing import cast from mlia.core.advice_generation import AdviceProducer from mlia.core.advisor import DefaultInferenceAdvisor @@ -25,6 +26,7 @@ from mlia.target.ethos_u.data_collection import EthosUOptimizationPerformance from mlia.target.ethos_u.data_collection import EthosUPerformance from mlia.target.ethos_u.events import EthosUAdvisorStartedEvent from mlia.target.ethos_u.handlers import EthosUEventHandler +from mlia.target.registry import profile from mlia.utils.types import is_list_of @@ -96,7 +98,7 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor): def _get_device_cfg(self, context: Context) -> EthosUConfiguration: """Get device configuration.""" target_profile = self.get_target_profile(context) - return EthosUConfiguration.load_profile(target_profile) + return cast(EthosUConfiguration, profile(target_profile)) def _get_optimization_settings(self, context: Context) -> list[list[dict]]: """Get optimization settings.""" diff --git a/src/mlia/target/ethos_u/config.py b/src/mlia/target/ethos_u/config.py index eb5691d..d1a2c7a 100644 --- a/src/mlia/target/ethos_u/config.py +++ b/src/mlia/target/ethos_u/config.py @@ -6,12 +6,13 @@ from __future__ import annotations import logging from typing import Any +from mlia.backend.corstone import is_corstone_backend +from mlia.backend.manager import get_available_backends from mlia.backend.vela.compiler import resolve_compiler_config from mlia.backend.vela.compiler import VelaCompilerOptions from mlia.target.config import TargetProfile from mlia.utils.filesystem import get_vela_config - logger = logging.getLogger(__name__) @@ -67,3 +68,22 @@ class EthosUConfiguration(TargetProfile): def __repr__(self) -> str: """Return string representation.""" return f"<Ethos-U configuration target={self.target}>" + + +def get_default_ethos_u_backends( + supported_backends_priority_order: list[str], +) -> list[str]: + """Return default backends for Ethos-U targets.""" + available_backends = get_available_backends() + + default_backends = [] + corstone_added = False + for backend in supported_backends_priority_order: + if backend not in available_backends: + continue + if is_corstone_backend(backend): + if corstone_added: + continue # only add one Corstone backend + corstone_added = True + default_backends.append(backend) + return default_backends diff --git a/src/mlia/target/registry.py b/src/mlia/target/registry.py index 4870fc8..9fccecb 100644 --- a/src/mlia/target/registry.py +++ b/src/mlia/target/registry.py @@ -3,17 +3,78 @@ """Target module.""" from __future__ import annotations +from functools import lru_cache +from pathlib import Path +from typing import cast + from mlia.backend.config import BackendType from mlia.backend.manager import get_installation_manager from mlia.backend.registry import registry as backend_registry from mlia.core.common import AdviceCategory from mlia.core.reporting import Column from mlia.core.reporting import Table +from mlia.target.config import BUILTIN_SUPPORTED_PROFILE_NAMES +from mlia.target.config import get_builtin_profile_path +from mlia.target.config import is_builtin_profile +from mlia.target.config import load_profile from mlia.target.config import TargetInfo +from mlia.target.config import TargetProfile from mlia.utils.registry import Registry + +class TargetRegistry(Registry[TargetInfo]): + """Registry for targets.""" + + def register(self, name: str, item: TargetInfo) -> bool: + """Register an item: returns `False` if already registered.""" + assert all( + backend in backend_registry.items for backend in item.supported_backends + ) + return super().register(name, item) + + # All supported targets are required to be registered here. -registry = Registry[TargetInfo]() +registry = TargetRegistry() + + +def builtin_profile_names() -> list[str]: + """Return a list of built-in profile names (not file paths).""" + return BUILTIN_SUPPORTED_PROFILE_NAMES + + +@lru_cache +def profile(target_profile: str | Path) -> TargetProfile: + """Get the target profile data (built-in or custom file).""" + if not target_profile: + raise ValueError("No valid target profile was provided.") + if is_builtin_profile(target_profile): + profile_file = get_builtin_profile_path(cast(str, target_profile)) + profile_ = create_target_profile(profile_file) + else: + profile_file = Path(target_profile) + if profile_file.is_file(): + profile_ = create_target_profile(profile_file) + else: + raise ValueError( + f"Profile '{target_profile}' is neither a valid built-in " + "target profile name or a valid file path." + ) + + return profile_ + + +def get_target(target_profile: str | Path) -> str: + """Return target for the provided target_profile.""" + return profile(target_profile).target + + +@lru_cache +def create_target_profile(path: Path) -> TargetProfile: + """Create a new instance of a TargetProfile from the file.""" + profile_data = load_profile(path) + target = profile_data["target"] + target_info = registry.items[target] + return target_info.target_profile_cls.load_json_data(profile_data) def supported_advice(target: str) -> list[AdviceCategory]: @@ -29,6 +90,11 @@ def supported_backends(target: str) -> list[str]: return registry.items[target].filter_supported_backends(check_system=False) +def default_backends(target: str) -> list[str]: + """Get a list of default backends for the given target.""" + return registry.items[target].default_backends + + def get_backend_to_supported_targets() -> dict[str, list]: """Get a dict that maps a list of supported targets given backend.""" targets = dict(registry.items) diff --git a/src/mlia/target/tosa/__init__.py b/src/mlia/target/tosa/__init__.py index 06bf1a9..3830ce5 100644 --- a/src/mlia/target/tosa/__init__.py +++ b/src/mlia/target/tosa/__init__.py @@ -3,5 +3,15 @@ """TOSA target module.""" from mlia.target.registry import registry from mlia.target.registry import TargetInfo +from mlia.target.tosa.advisor import configure_and_get_tosa_advisor +from mlia.target.tosa.config import TOSAConfiguration -registry.register("tosa", TargetInfo(["tosa-checker"])) +registry.register( + "tosa", + TargetInfo( + supported_backends=["tosa-checker"], + default_backends=["tosa-checker"], + advisor_factory_func=configure_and_get_tosa_advisor, + target_profile_cls=TOSAConfiguration, + ), +) diff --git a/src/mlia/target/tosa/advisor.py b/src/mlia/target/tosa/advisor.py index 5588d0f..5fb18ed 100644 --- a/src/mlia/target/tosa/advisor.py +++ b/src/mlia/target/tosa/advisor.py @@ -5,6 +5,7 @@ from __future__ import annotations from pathlib import Path from typing import Any +from typing import cast from mlia.core.advice_generation import AdviceCategory from mlia.core.advice_generation import AdviceProducer @@ -17,6 +18,7 @@ from mlia.core.data_collection import DataCollector from mlia.core.events import Event from mlia.core.metadata import MLIAMetadata from mlia.core.metadata import ModelMetadata +from mlia.target.registry import profile from mlia.target.tosa.advice_generation import TOSAAdviceProducer from mlia.target.tosa.config import TOSAConfiguration from mlia.target.tosa.data_analysis import TOSADataAnalyzer @@ -66,7 +68,7 @@ class TOSAInferenceAdvisor(DefaultInferenceAdvisor): return [ TOSAAdvisorStartedEvent( model, - TOSAConfiguration.load_profile(target_profile), + cast(TOSAConfiguration, profile(target_profile)), MetadataDisplay( TOSAMetadata("tosa-checker"), MLIAMetadata("mlia"), |