aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/target
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2023-02-02 14:02:05 +0000
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-02-10 13:45:18 +0000
commit7a661257b6adad0c8f53e32b42ced56a1e7d952f (patch)
tree938ad8578c5b9edc0573e810ce64ce0a5bda3d8c /src/mlia/target
parent50271dee0a84bfc481ce798184f07b5b0b4bc64d (diff)
downloadmlia-7a661257b6adad0c8f53e32b42ced56a1e7d952f.tar.gz
MLIA-769 Expand use of target/backend registries
- Use the target/backend registries to avoid hard-coded names. - Cache target profiles to avoid re-loading them Change-Id: I474b7c9ef23894e1d8a3ea06d13a37652054c62e
Diffstat (limited to 'src/mlia/target')
-rw-r--r--src/mlia/target/config.py71
-rw-r--r--src/mlia/target/cortex_a/__init__.py12
-rw-r--r--src/mlia/target/cortex_a/advisor.py4
-rw-r--r--src/mlia/target/ethos_u/__init__.py19
-rw-r--r--src/mlia/target/ethos_u/advisor.py4
-rw-r--r--src/mlia/target/ethos_u/config.py22
-rw-r--r--src/mlia/target/registry.py68
-rw-r--r--src/mlia/target/tosa/__init__.py12
-rw-r--r--src/mlia/target/tosa/advisor.py4
9 files changed, 171 insertions, 45 deletions
diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py
index bf603dd..eb7ecff 100644
--- a/src/mlia/target/config.py
+++ b/src/mlia/target/config.py
@@ -6,9 +6,10 @@ from __future__ import annotations
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
+from functools import lru_cache
from pathlib import Path
-from shutil import copy
from typing import Any
+from typing import Callable
from typing import cast
from typing import TypeVar
@@ -19,23 +20,20 @@ except ModuleNotFoundError:
from mlia.backend.registry import registry as backend_registry
from mlia.core.common import AdviceCategory
+from mlia.core.advisor import InferenceAdvisor
from mlia.utils.filesystem import get_mlia_target_profiles_dir
-def get_profile_file(target_profile: str | Path) -> Path:
- """Get the target profile toml file."""
- if not target_profile:
- raise Exception("Target profile is not provided.")
+def get_builtin_profile_path(target_profile: str) -> Path:
+ """
+ Construct the path to the built-in target profile file.
- profile_file = Path(get_mlia_target_profiles_dir() / f"{target_profile}.toml")
- if not profile_file.is_file():
- profile_file = Path(target_profile)
-
- if not profile_file.exists():
- raise Exception(f"File not found: {profile_file}.")
- return profile_file
+ No checks are performed.
+ """
+ return get_mlia_target_profiles_dir() / f"{target_profile}.toml"
+@lru_cache
def load_profile(path: str | Path) -> dict[str, Any]:
"""Get settings for the provided target profile."""
with open(path, "rb") as file:
@@ -55,24 +53,12 @@ def get_builtin_supported_profile_names() -> list[str]:
)
-def get_target(target_profile: str | Path) -> str:
- """Return target for the provided target_profile."""
- profile_file = get_profile_file(target_profile)
- profile = load_profile(profile_file)
- return cast(str, profile["target"])
+BUILTIN_SUPPORTED_PROFILE_NAMES = get_builtin_supported_profile_names()
-def copy_profile_file_to_output_dir(
- target_profile: str | Path, output_dir: str | Path
-) -> bool:
- """Copy the target profile file to output directory."""
- profile_file_path = get_profile_file(target_profile)
- output_file_path = f"{output_dir}/{profile_file_path.stem}.toml"
- try:
- copy(profile_file_path, output_file_path)
- return True
- except OSError as err:
- raise RuntimeError("Failed to copy profile file:", err.strerror) from err
+def is_builtin_profile(profile_name: str | Path) -> bool:
+ """Check if the given profile name belongs to a built-in profile."""
+ return profile_name in BUILTIN_SUPPORTED_PROFILE_NAMES
T = TypeVar("T", bound="TargetProfile")
@@ -88,21 +74,29 @@ class TargetProfile(ABC):
@classmethod
def load(cls: type[T], path: str | Path) -> T:
"""Load and verify a target profile from file and return new instance."""
- profile = load_profile(path)
+ profile_data = load_profile(path)
try:
- new_instance = cls(**profile)
+ new_instance = cls.load_json_data(profile_data)
except KeyError as ex:
raise KeyError(f"Missing key in file {path}.") from ex
- new_instance.verify()
+ return new_instance
+ @classmethod
+ def load_json_data(cls: type[T], profile_data: dict) -> T:
+ """Load a target profile from the JSON data."""
+ new_instance = cls(**profile_data)
+ new_instance.verify()
return new_instance
@classmethod
- def load_profile(cls: type[T], target_profile: str) -> T:
- """Load a target profile by name."""
- profile_file = get_profile_file(target_profile)
+ def load_profile(cls: type[T], target_profile: str | Path) -> T:
+ """Load a target profile from built-in target profile name or file path."""
+ if is_builtin_profile(target_profile):
+ profile_file = get_builtin_profile_path(cast(str, target_profile))
+ else:
+ profile_file = Path(target_profile)
return cls.load(profile_file)
def save(self, path: str | Path) -> None:
@@ -125,6 +119,9 @@ class TargetInfo:
"""Collect information about supported targets."""
supported_backends: list[str]
+ default_backends: list[str]
+ advisor_factory_func: Callable[..., InferenceAdvisor]
+ target_profile_cls: type[TargetProfile]
def __str__(self) -> str:
"""List supported backends."""
@@ -135,7 +132,8 @@ class TargetInfo:
) -> bool:
"""Check if any of the supported backends support this kind of advice."""
return any(
- backend_registry.items[name].is_supported(advice, check_system)
+ name in backend_registry.items
+ and backend_registry.items[name].is_supported(advice, check_system)
for name in self.supported_backends
)
@@ -146,5 +144,6 @@ class TargetInfo:
return [
name
for name in self.supported_backends
- if backend_registry.items[name].is_supported(advice, check_system)
+ if name in backend_registry.items
+ and backend_registry.items[name].is_supported(advice, check_system)
]
diff --git a/src/mlia/target/cortex_a/__init__.py b/src/mlia/target/cortex_a/__init__.py
index f686bfc..87f268a 100644
--- a/src/mlia/target/cortex_a/__init__.py
+++ b/src/mlia/target/cortex_a/__init__.py
@@ -1,7 +1,17 @@
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Cortex-A target module."""
+from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor
+from mlia.target.cortex_a.config import CortexAConfiguration
from mlia.target.registry import registry
from mlia.target.registry import TargetInfo
-registry.register("cortex-a", TargetInfo(["ArmNNTFLiteDelegate"]))
+registry.register(
+ "cortex-a",
+ TargetInfo(
+ supported_backends=["ArmNNTFLiteDelegate"],
+ default_backends=["ArmNNTFLiteDelegate"],
+ advisor_factory_func=configure_and_get_cortexa_advisor,
+ target_profile_cls=CortexAConfiguration,
+ ),
+)
diff --git a/src/mlia/target/cortex_a/advisor.py b/src/mlia/target/cortex_a/advisor.py
index 518c9f1..a093784 100644
--- a/src/mlia/target/cortex_a/advisor.py
+++ b/src/mlia/target/cortex_a/advisor.py
@@ -5,6 +5,7 @@ from __future__ import annotations
from pathlib import Path
from typing import Any
+from typing import cast
from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
@@ -21,6 +22,7 @@ from mlia.target.cortex_a.data_analysis import CortexADataAnalyzer
from mlia.target.cortex_a.data_collection import CortexAOperatorCompatibility
from mlia.target.cortex_a.events import CortexAAdvisorStartedEvent
from mlia.target.cortex_a.handlers import CortexAEventHandler
+from mlia.target.registry import profile
class CortexAInferenceAdvisor(DefaultInferenceAdvisor):
@@ -59,7 +61,7 @@ class CortexAInferenceAdvisor(DefaultInferenceAdvisor):
return [
CortexAAdvisorStartedEvent(
- model, CortexAConfiguration.load_profile(target_profile)
+ model, cast(CortexAConfiguration, profile(target_profile))
),
]
diff --git a/src/mlia/target/ethos_u/__init__.py b/src/mlia/target/ethos_u/__init__.py
index d53be53..6b6777d 100644
--- a/src/mlia/target/ethos_u/__init__.py
+++ b/src/mlia/target/ethos_u/__init__.py
@@ -1,8 +1,23 @@
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Ethos-U target module."""
+from mlia.backend.corstone import CORSTONE_PRIORITY
+from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor
+from mlia.target.ethos_u.config import EthosUConfiguration
+from mlia.target.ethos_u.config import get_default_ethos_u_backends
from mlia.target.registry import registry
from mlia.target.registry import TargetInfo
-registry.register("ethos-u55", TargetInfo(["Vela", "Corstone-300", "Corstone-310"]))
-registry.register("ethos-u65", TargetInfo(["Vela", "Corstone-300", "Corstone-310"]))
+SUPPORTED_BACKENDS_PRIORITY = ["Vela", *CORSTONE_PRIORITY]
+
+
+for ethos_u in ("ethos-u55", "ethos-u65"):
+ registry.register(
+ ethos_u,
+ TargetInfo(
+ supported_backends=SUPPORTED_BACKENDS_PRIORITY,
+ default_backends=get_default_ethos_u_backends(SUPPORTED_BACKENDS_PRIORITY),
+ advisor_factory_func=configure_and_get_ethosu_advisor,
+ target_profile_cls=EthosUConfiguration,
+ ),
+ )
diff --git a/src/mlia/target/ethos_u/advisor.py b/src/mlia/target/ethos_u/advisor.py
index 225fd87..5f23fdd 100644
--- a/src/mlia/target/ethos_u/advisor.py
+++ b/src/mlia/target/ethos_u/advisor.py
@@ -5,6 +5,7 @@ from __future__ import annotations
from pathlib import Path
from typing import Any
+from typing import cast
from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
@@ -25,6 +26,7 @@ from mlia.target.ethos_u.data_collection import EthosUOptimizationPerformance
from mlia.target.ethos_u.data_collection import EthosUPerformance
from mlia.target.ethos_u.events import EthosUAdvisorStartedEvent
from mlia.target.ethos_u.handlers import EthosUEventHandler
+from mlia.target.registry import profile
from mlia.utils.types import is_list_of
@@ -96,7 +98,7 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
def _get_device_cfg(self, context: Context) -> EthosUConfiguration:
"""Get device configuration."""
target_profile = self.get_target_profile(context)
- return EthosUConfiguration.load_profile(target_profile)
+ return cast(EthosUConfiguration, profile(target_profile))
def _get_optimization_settings(self, context: Context) -> list[list[dict]]:
"""Get optimization settings."""
diff --git a/src/mlia/target/ethos_u/config.py b/src/mlia/target/ethos_u/config.py
index eb5691d..d1a2c7a 100644
--- a/src/mlia/target/ethos_u/config.py
+++ b/src/mlia/target/ethos_u/config.py
@@ -6,12 +6,13 @@ from __future__ import annotations
import logging
from typing import Any
+from mlia.backend.corstone import is_corstone_backend
+from mlia.backend.manager import get_available_backends
from mlia.backend.vela.compiler import resolve_compiler_config
from mlia.backend.vela.compiler import VelaCompilerOptions
from mlia.target.config import TargetProfile
from mlia.utils.filesystem import get_vela_config
-
logger = logging.getLogger(__name__)
@@ -67,3 +68,22 @@ class EthosUConfiguration(TargetProfile):
def __repr__(self) -> str:
"""Return string representation."""
return f"<Ethos-U configuration target={self.target}>"
+
+
+def get_default_ethos_u_backends(
+ supported_backends_priority_order: list[str],
+) -> list[str]:
+ """Return default backends for Ethos-U targets."""
+ available_backends = get_available_backends()
+
+ default_backends = []
+ corstone_added = False
+ for backend in supported_backends_priority_order:
+ if backend not in available_backends:
+ continue
+ if is_corstone_backend(backend):
+ if corstone_added:
+ continue # only add one Corstone backend
+ corstone_added = True
+ default_backends.append(backend)
+ return default_backends
diff --git a/src/mlia/target/registry.py b/src/mlia/target/registry.py
index 4870fc8..9fccecb 100644
--- a/src/mlia/target/registry.py
+++ b/src/mlia/target/registry.py
@@ -3,17 +3,78 @@
"""Target module."""
from __future__ import annotations
+from functools import lru_cache
+from pathlib import Path
+from typing import cast
+
from mlia.backend.config import BackendType
from mlia.backend.manager import get_installation_manager
from mlia.backend.registry import registry as backend_registry
from mlia.core.common import AdviceCategory
from mlia.core.reporting import Column
from mlia.core.reporting import Table
+from mlia.target.config import BUILTIN_SUPPORTED_PROFILE_NAMES
+from mlia.target.config import get_builtin_profile_path
+from mlia.target.config import is_builtin_profile
+from mlia.target.config import load_profile
from mlia.target.config import TargetInfo
+from mlia.target.config import TargetProfile
from mlia.utils.registry import Registry
+
+class TargetRegistry(Registry[TargetInfo]):
+ """Registry for targets."""
+
+ def register(self, name: str, item: TargetInfo) -> bool:
+ """Register an item: returns `False` if already registered."""
+ assert all(
+ backend in backend_registry.items for backend in item.supported_backends
+ )
+ return super().register(name, item)
+
+
# All supported targets are required to be registered here.
-registry = Registry[TargetInfo]()
+registry = TargetRegistry()
+
+
+def builtin_profile_names() -> list[str]:
+ """Return a list of built-in profile names (not file paths)."""
+ return BUILTIN_SUPPORTED_PROFILE_NAMES
+
+
+@lru_cache
+def profile(target_profile: str | Path) -> TargetProfile:
+ """Get the target profile data (built-in or custom file)."""
+ if not target_profile:
+ raise ValueError("No valid target profile was provided.")
+ if is_builtin_profile(target_profile):
+ profile_file = get_builtin_profile_path(cast(str, target_profile))
+ profile_ = create_target_profile(profile_file)
+ else:
+ profile_file = Path(target_profile)
+ if profile_file.is_file():
+ profile_ = create_target_profile(profile_file)
+ else:
+ raise ValueError(
+ f"Profile '{target_profile}' is neither a valid built-in "
+ "target profile name or a valid file path."
+ )
+
+ return profile_
+
+
+def get_target(target_profile: str | Path) -> str:
+ """Return target for the provided target_profile."""
+ return profile(target_profile).target
+
+
+@lru_cache
+def create_target_profile(path: Path) -> TargetProfile:
+ """Create a new instance of a TargetProfile from the file."""
+ profile_data = load_profile(path)
+ target = profile_data["target"]
+ target_info = registry.items[target]
+ return target_info.target_profile_cls.load_json_data(profile_data)
def supported_advice(target: str) -> list[AdviceCategory]:
@@ -29,6 +90,11 @@ def supported_backends(target: str) -> list[str]:
return registry.items[target].filter_supported_backends(check_system=False)
+def default_backends(target: str) -> list[str]:
+ """Get a list of default backends for the given target."""
+ return registry.items[target].default_backends
+
+
def get_backend_to_supported_targets() -> dict[str, list]:
"""Get a dict that maps a list of supported targets given backend."""
targets = dict(registry.items)
diff --git a/src/mlia/target/tosa/__init__.py b/src/mlia/target/tosa/__init__.py
index 06bf1a9..3830ce5 100644
--- a/src/mlia/target/tosa/__init__.py
+++ b/src/mlia/target/tosa/__init__.py
@@ -3,5 +3,15 @@
"""TOSA target module."""
from mlia.target.registry import registry
from mlia.target.registry import TargetInfo
+from mlia.target.tosa.advisor import configure_and_get_tosa_advisor
+from mlia.target.tosa.config import TOSAConfiguration
-registry.register("tosa", TargetInfo(["tosa-checker"]))
+registry.register(
+ "tosa",
+ TargetInfo(
+ supported_backends=["tosa-checker"],
+ default_backends=["tosa-checker"],
+ advisor_factory_func=configure_and_get_tosa_advisor,
+ target_profile_cls=TOSAConfiguration,
+ ),
+)
diff --git a/src/mlia/target/tosa/advisor.py b/src/mlia/target/tosa/advisor.py
index 5588d0f..5fb18ed 100644
--- a/src/mlia/target/tosa/advisor.py
+++ b/src/mlia/target/tosa/advisor.py
@@ -5,6 +5,7 @@ from __future__ import annotations
from pathlib import Path
from typing import Any
+from typing import cast
from mlia.core.advice_generation import AdviceCategory
from mlia.core.advice_generation import AdviceProducer
@@ -17,6 +18,7 @@ from mlia.core.data_collection import DataCollector
from mlia.core.events import Event
from mlia.core.metadata import MLIAMetadata
from mlia.core.metadata import ModelMetadata
+from mlia.target.registry import profile
from mlia.target.tosa.advice_generation import TOSAAdviceProducer
from mlia.target.tosa.config import TOSAConfiguration
from mlia.target.tosa.data_analysis import TOSADataAnalyzer
@@ -66,7 +68,7 @@ class TOSAInferenceAdvisor(DefaultInferenceAdvisor):
return [
TOSAAdvisorStartedEvent(
model,
- TOSAConfiguration.load_profile(target_profile),
+ cast(TOSAConfiguration, profile(target_profile)),
MetadataDisplay(
TOSAMetadata("tosa-checker"),
MLIAMetadata("mlia"),