From 7a661257b6adad0c8f53e32b42ced56a1e7d952f Mon Sep 17 00:00:00 2001 From: Benjamin Klimczak Date: Thu, 2 Feb 2023 14:02:05 +0000 Subject: MLIA-769 Expand use of target/backend registries - Use the target/backend registries to avoid hard-coded names. - Cache target profiles to avoid re-loading them Change-Id: I474b7c9ef23894e1d8a3ea06d13a37652054c62e --- src/mlia/target/config.py | 71 ++++++++++++++++++------------------ src/mlia/target/cortex_a/__init__.py | 12 +++++- src/mlia/target/cortex_a/advisor.py | 4 +- src/mlia/target/ethos_u/__init__.py | 19 +++++++++- src/mlia/target/ethos_u/advisor.py | 4 +- src/mlia/target/ethos_u/config.py | 22 ++++++++++- src/mlia/target/registry.py | 68 +++++++++++++++++++++++++++++++++- src/mlia/target/tosa/__init__.py | 12 +++++- src/mlia/target/tosa/advisor.py | 4 +- 9 files changed, 171 insertions(+), 45 deletions(-) (limited to 'src/mlia/target') diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py index bf603dd..eb7ecff 100644 --- a/src/mlia/target/config.py +++ b/src/mlia/target/config.py @@ -6,9 +6,10 @@ from __future__ import annotations from abc import ABC from abc import abstractmethod from dataclasses import dataclass +from functools import lru_cache from pathlib import Path -from shutil import copy from typing import Any +from typing import Callable from typing import cast from typing import TypeVar @@ -19,23 +20,20 @@ except ModuleNotFoundError: from mlia.backend.registry import registry as backend_registry from mlia.core.common import AdviceCategory +from mlia.core.advisor import InferenceAdvisor from mlia.utils.filesystem import get_mlia_target_profiles_dir -def get_profile_file(target_profile: str | Path) -> Path: - """Get the target profile toml file.""" - if not target_profile: - raise Exception("Target profile is not provided.") +def get_builtin_profile_path(target_profile: str) -> Path: + """ + Construct the path to the built-in target profile file. - profile_file = Path(get_mlia_target_profiles_dir() / f"{target_profile}.toml") - if not profile_file.is_file(): - profile_file = Path(target_profile) - - if not profile_file.exists(): - raise Exception(f"File not found: {profile_file}.") - return profile_file + No checks are performed. + """ + return get_mlia_target_profiles_dir() / f"{target_profile}.toml" +@lru_cache def load_profile(path: str | Path) -> dict[str, Any]: """Get settings for the provided target profile.""" with open(path, "rb") as file: @@ -55,24 +53,12 @@ def get_builtin_supported_profile_names() -> list[str]: ) -def get_target(target_profile: str | Path) -> str: - """Return target for the provided target_profile.""" - profile_file = get_profile_file(target_profile) - profile = load_profile(profile_file) - return cast(str, profile["target"]) +BUILTIN_SUPPORTED_PROFILE_NAMES = get_builtin_supported_profile_names() -def copy_profile_file_to_output_dir( - target_profile: str | Path, output_dir: str | Path -) -> bool: - """Copy the target profile file to output directory.""" - profile_file_path = get_profile_file(target_profile) - output_file_path = f"{output_dir}/{profile_file_path.stem}.toml" - try: - copy(profile_file_path, output_file_path) - return True - except OSError as err: - raise RuntimeError("Failed to copy profile file:", err.strerror) from err +def is_builtin_profile(profile_name: str | Path) -> bool: + """Check if the given profile name belongs to a built-in profile.""" + return profile_name in BUILTIN_SUPPORTED_PROFILE_NAMES T = TypeVar("T", bound="TargetProfile") @@ -88,21 +74,29 @@ class TargetProfile(ABC): @classmethod def load(cls: type[T], path: str | Path) -> T: """Load and verify a target profile from file and return new instance.""" - profile = load_profile(path) + profile_data = load_profile(path) try: - new_instance = cls(**profile) + new_instance = cls.load_json_data(profile_data) except KeyError as ex: raise KeyError(f"Missing key in file {path}.") from ex - new_instance.verify() + return new_instance + @classmethod + def load_json_data(cls: type[T], profile_data: dict) -> T: + """Load a target profile from the JSON data.""" + new_instance = cls(**profile_data) + new_instance.verify() return new_instance @classmethod - def load_profile(cls: type[T], target_profile: str) -> T: - """Load a target profile by name.""" - profile_file = get_profile_file(target_profile) + def load_profile(cls: type[T], target_profile: str | Path) -> T: + """Load a target profile from built-in target profile name or file path.""" + if is_builtin_profile(target_profile): + profile_file = get_builtin_profile_path(cast(str, target_profile)) + else: + profile_file = Path(target_profile) return cls.load(profile_file) def save(self, path: str | Path) -> None: @@ -125,6 +119,9 @@ class TargetInfo: """Collect information about supported targets.""" supported_backends: list[str] + default_backends: list[str] + advisor_factory_func: Callable[..., InferenceAdvisor] + target_profile_cls: type[TargetProfile] def __str__(self) -> str: """List supported backends.""" @@ -135,7 +132,8 @@ class TargetInfo: ) -> bool: """Check if any of the supported backends support this kind of advice.""" return any( - backend_registry.items[name].is_supported(advice, check_system) + name in backend_registry.items + and backend_registry.items[name].is_supported(advice, check_system) for name in self.supported_backends ) @@ -146,5 +144,6 @@ class TargetInfo: return [ name for name in self.supported_backends - if backend_registry.items[name].is_supported(advice, check_system) + if name in backend_registry.items + and backend_registry.items[name].is_supported(advice, check_system) ] diff --git a/src/mlia/target/cortex_a/__init__.py b/src/mlia/target/cortex_a/__init__.py index f686bfc..87f268a 100644 --- a/src/mlia/target/cortex_a/__init__.py +++ b/src/mlia/target/cortex_a/__init__.py @@ -1,7 +1,17 @@ # SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Cortex-A target module.""" +from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor +from mlia.target.cortex_a.config import CortexAConfiguration from mlia.target.registry import registry from mlia.target.registry import TargetInfo -registry.register("cortex-a", TargetInfo(["ArmNNTFLiteDelegate"])) +registry.register( + "cortex-a", + TargetInfo( + supported_backends=["ArmNNTFLiteDelegate"], + default_backends=["ArmNNTFLiteDelegate"], + advisor_factory_func=configure_and_get_cortexa_advisor, + target_profile_cls=CortexAConfiguration, + ), +) diff --git a/src/mlia/target/cortex_a/advisor.py b/src/mlia/target/cortex_a/advisor.py index 518c9f1..a093784 100644 --- a/src/mlia/target/cortex_a/advisor.py +++ b/src/mlia/target/cortex_a/advisor.py @@ -5,6 +5,7 @@ from __future__ import annotations from pathlib import Path from typing import Any +from typing import cast from mlia.core.advice_generation import AdviceProducer from mlia.core.advisor import DefaultInferenceAdvisor @@ -21,6 +22,7 @@ from mlia.target.cortex_a.data_analysis import CortexADataAnalyzer from mlia.target.cortex_a.data_collection import CortexAOperatorCompatibility from mlia.target.cortex_a.events import CortexAAdvisorStartedEvent from mlia.target.cortex_a.handlers import CortexAEventHandler +from mlia.target.registry import profile class CortexAInferenceAdvisor(DefaultInferenceAdvisor): @@ -59,7 +61,7 @@ class CortexAInferenceAdvisor(DefaultInferenceAdvisor): return [ CortexAAdvisorStartedEvent( - model, CortexAConfiguration.load_profile(target_profile) + model, cast(CortexAConfiguration, profile(target_profile)) ), ] diff --git a/src/mlia/target/ethos_u/__init__.py b/src/mlia/target/ethos_u/__init__.py index d53be53..6b6777d 100644 --- a/src/mlia/target/ethos_u/__init__.py +++ b/src/mlia/target/ethos_u/__init__.py @@ -1,8 +1,23 @@ # SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Ethos-U target module.""" +from mlia.backend.corstone import CORSTONE_PRIORITY +from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor +from mlia.target.ethos_u.config import EthosUConfiguration +from mlia.target.ethos_u.config import get_default_ethos_u_backends from mlia.target.registry import registry from mlia.target.registry import TargetInfo -registry.register("ethos-u55", TargetInfo(["Vela", "Corstone-300", "Corstone-310"])) -registry.register("ethos-u65", TargetInfo(["Vela", "Corstone-300", "Corstone-310"])) +SUPPORTED_BACKENDS_PRIORITY = ["Vela", *CORSTONE_PRIORITY] + + +for ethos_u in ("ethos-u55", "ethos-u65"): + registry.register( + ethos_u, + TargetInfo( + supported_backends=SUPPORTED_BACKENDS_PRIORITY, + default_backends=get_default_ethos_u_backends(SUPPORTED_BACKENDS_PRIORITY), + advisor_factory_func=configure_and_get_ethosu_advisor, + target_profile_cls=EthosUConfiguration, + ), + ) diff --git a/src/mlia/target/ethos_u/advisor.py b/src/mlia/target/ethos_u/advisor.py index 225fd87..5f23fdd 100644 --- a/src/mlia/target/ethos_u/advisor.py +++ b/src/mlia/target/ethos_u/advisor.py @@ -5,6 +5,7 @@ from __future__ import annotations from pathlib import Path from typing import Any +from typing import cast from mlia.core.advice_generation import AdviceProducer from mlia.core.advisor import DefaultInferenceAdvisor @@ -25,6 +26,7 @@ from mlia.target.ethos_u.data_collection import EthosUOptimizationPerformance from mlia.target.ethos_u.data_collection import EthosUPerformance from mlia.target.ethos_u.events import EthosUAdvisorStartedEvent from mlia.target.ethos_u.handlers import EthosUEventHandler +from mlia.target.registry import profile from mlia.utils.types import is_list_of @@ -96,7 +98,7 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor): def _get_device_cfg(self, context: Context) -> EthosUConfiguration: """Get device configuration.""" target_profile = self.get_target_profile(context) - return EthosUConfiguration.load_profile(target_profile) + return cast(EthosUConfiguration, profile(target_profile)) def _get_optimization_settings(self, context: Context) -> list[list[dict]]: """Get optimization settings.""" diff --git a/src/mlia/target/ethos_u/config.py b/src/mlia/target/ethos_u/config.py index eb5691d..d1a2c7a 100644 --- a/src/mlia/target/ethos_u/config.py +++ b/src/mlia/target/ethos_u/config.py @@ -6,12 +6,13 @@ from __future__ import annotations import logging from typing import Any +from mlia.backend.corstone import is_corstone_backend +from mlia.backend.manager import get_available_backends from mlia.backend.vela.compiler import resolve_compiler_config from mlia.backend.vela.compiler import VelaCompilerOptions from mlia.target.config import TargetProfile from mlia.utils.filesystem import get_vela_config - logger = logging.getLogger(__name__) @@ -67,3 +68,22 @@ class EthosUConfiguration(TargetProfile): def __repr__(self) -> str: """Return string representation.""" return f"" + + +def get_default_ethos_u_backends( + supported_backends_priority_order: list[str], +) -> list[str]: + """Return default backends for Ethos-U targets.""" + available_backends = get_available_backends() + + default_backends = [] + corstone_added = False + for backend in supported_backends_priority_order: + if backend not in available_backends: + continue + if is_corstone_backend(backend): + if corstone_added: + continue # only add one Corstone backend + corstone_added = True + default_backends.append(backend) + return default_backends diff --git a/src/mlia/target/registry.py b/src/mlia/target/registry.py index 4870fc8..9fccecb 100644 --- a/src/mlia/target/registry.py +++ b/src/mlia/target/registry.py @@ -3,17 +3,78 @@ """Target module.""" from __future__ import annotations +from functools import lru_cache +from pathlib import Path +from typing import cast + from mlia.backend.config import BackendType from mlia.backend.manager import get_installation_manager from mlia.backend.registry import registry as backend_registry from mlia.core.common import AdviceCategory from mlia.core.reporting import Column from mlia.core.reporting import Table +from mlia.target.config import BUILTIN_SUPPORTED_PROFILE_NAMES +from mlia.target.config import get_builtin_profile_path +from mlia.target.config import is_builtin_profile +from mlia.target.config import load_profile from mlia.target.config import TargetInfo +from mlia.target.config import TargetProfile from mlia.utils.registry import Registry + +class TargetRegistry(Registry[TargetInfo]): + """Registry for targets.""" + + def register(self, name: str, item: TargetInfo) -> bool: + """Register an item: returns `False` if already registered.""" + assert all( + backend in backend_registry.items for backend in item.supported_backends + ) + return super().register(name, item) + + # All supported targets are required to be registered here. -registry = Registry[TargetInfo]() +registry = TargetRegistry() + + +def builtin_profile_names() -> list[str]: + """Return a list of built-in profile names (not file paths).""" + return BUILTIN_SUPPORTED_PROFILE_NAMES + + +@lru_cache +def profile(target_profile: str | Path) -> TargetProfile: + """Get the target profile data (built-in or custom file).""" + if not target_profile: + raise ValueError("No valid target profile was provided.") + if is_builtin_profile(target_profile): + profile_file = get_builtin_profile_path(cast(str, target_profile)) + profile_ = create_target_profile(profile_file) + else: + profile_file = Path(target_profile) + if profile_file.is_file(): + profile_ = create_target_profile(profile_file) + else: + raise ValueError( + f"Profile '{target_profile}' is neither a valid built-in " + "target profile name or a valid file path." + ) + + return profile_ + + +def get_target(target_profile: str | Path) -> str: + """Return target for the provided target_profile.""" + return profile(target_profile).target + + +@lru_cache +def create_target_profile(path: Path) -> TargetProfile: + """Create a new instance of a TargetProfile from the file.""" + profile_data = load_profile(path) + target = profile_data["target"] + target_info = registry.items[target] + return target_info.target_profile_cls.load_json_data(profile_data) def supported_advice(target: str) -> list[AdviceCategory]: @@ -29,6 +90,11 @@ def supported_backends(target: str) -> list[str]: return registry.items[target].filter_supported_backends(check_system=False) +def default_backends(target: str) -> list[str]: + """Get a list of default backends for the given target.""" + return registry.items[target].default_backends + + def get_backend_to_supported_targets() -> dict[str, list]: """Get a dict that maps a list of supported targets given backend.""" targets = dict(registry.items) diff --git a/src/mlia/target/tosa/__init__.py b/src/mlia/target/tosa/__init__.py index 06bf1a9..3830ce5 100644 --- a/src/mlia/target/tosa/__init__.py +++ b/src/mlia/target/tosa/__init__.py @@ -3,5 +3,15 @@ """TOSA target module.""" from mlia.target.registry import registry from mlia.target.registry import TargetInfo +from mlia.target.tosa.advisor import configure_and_get_tosa_advisor +from mlia.target.tosa.config import TOSAConfiguration -registry.register("tosa", TargetInfo(["tosa-checker"])) +registry.register( + "tosa", + TargetInfo( + supported_backends=["tosa-checker"], + default_backends=["tosa-checker"], + advisor_factory_func=configure_and_get_tosa_advisor, + target_profile_cls=TOSAConfiguration, + ), +) diff --git a/src/mlia/target/tosa/advisor.py b/src/mlia/target/tosa/advisor.py index 5588d0f..5fb18ed 100644 --- a/src/mlia/target/tosa/advisor.py +++ b/src/mlia/target/tosa/advisor.py @@ -5,6 +5,7 @@ from __future__ import annotations from pathlib import Path from typing import Any +from typing import cast from mlia.core.advice_generation import AdviceCategory from mlia.core.advice_generation import AdviceProducer @@ -17,6 +18,7 @@ from mlia.core.data_collection import DataCollector from mlia.core.events import Event from mlia.core.metadata import MLIAMetadata from mlia.core.metadata import ModelMetadata +from mlia.target.registry import profile from mlia.target.tosa.advice_generation import TOSAAdviceProducer from mlia.target.tosa.config import TOSAConfiguration from mlia.target.tosa.data_analysis import TOSADataAnalyzer @@ -66,7 +68,7 @@ class TOSAInferenceAdvisor(DefaultInferenceAdvisor): return [ TOSAAdvisorStartedEvent( model, - TOSAConfiguration.load_profile(target_profile), + cast(TOSAConfiguration, profile(target_profile)), MetadataDisplay( TOSAMetadata("tosa-checker"), MLIAMetadata("mlia"), -- cgit v1.2.1