aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2023-02-02 14:02:05 +0000
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-02-10 13:45:18 +0000
commit7a661257b6adad0c8f53e32b42ced56a1e7d952f (patch)
tree938ad8578c5b9edc0573e810ce64ce0a5bda3d8c /src
parent50271dee0a84bfc481ce798184f07b5b0b4bc64d (diff)
downloadmlia-7a661257b6adad0c8f53e32b42ced56a1e7d952f.tar.gz
MLIA-769 Expand use of target/backend registries
- Use the target/backend registries to avoid hard-coded names. - Cache target profiles to avoid re-loading them Change-Id: I474b7c9ef23894e1d8a3ea06d13a37652054c62e
Diffstat (limited to 'src')
-rw-r--r--src/mlia/api.py21
-rw-r--r--src/mlia/backend/corstone/__init__.py30
-rw-r--r--src/mlia/backend/manager.py13
-rw-r--r--src/mlia/cli/command_validators.py20
-rw-r--r--src/mlia/cli/commands.py2
-rw-r--r--src/mlia/cli/config.py69
-rw-r--r--src/mlia/cli/helpers.py24
-rw-r--r--src/mlia/cli/main.py15
-rw-r--r--src/mlia/cli/options.py43
-rw-r--r--src/mlia/target/config.py71
-rw-r--r--src/mlia/target/cortex_a/__init__.py12
-rw-r--r--src/mlia/target/cortex_a/advisor.py4
-rw-r--r--src/mlia/target/ethos_u/__init__.py19
-rw-r--r--src/mlia/target/ethos_u/advisor.py4
-rw-r--r--src/mlia/target/ethos_u/config.py22
-rw-r--r--src/mlia/target/registry.py68
-rw-r--r--src/mlia/target/tosa/__init__.py12
-rw-r--r--src/mlia/target/tosa/advisor.py4
18 files changed, 277 insertions, 176 deletions
diff --git a/src/mlia/api.py b/src/mlia/api.py
index fd5fc13..7adae48 100644
--- a/src/mlia/api.py
+++ b/src/mlia/api.py
@@ -10,10 +10,8 @@ from typing import Any
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext
-from mlia.target.config import get_target
-from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor
-from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor
-from mlia.target.tosa.advisor import configure_and_get_tosa_advisor
+from mlia.target.registry import profile
+from mlia.target.registry import registry as target_registry
logger = logging.getLogger(__name__)
@@ -84,19 +82,8 @@ def get_advisor(
**extra_args: Any,
) -> InferenceAdvisor:
"""Find appropriate advisor for the target."""
- target_factories = {
- "ethos-u55": configure_and_get_ethosu_advisor,
- "ethos-u65": configure_and_get_ethosu_advisor,
- "tosa": configure_and_get_tosa_advisor,
- "cortex-a": configure_and_get_cortexa_advisor,
- }
-
- try:
- target = get_target(target_profile)
- factory_function = target_factories[target]
- except KeyError as err:
- raise Exception(f"Unsupported profile {target_profile}") from err
-
+ target = profile(target_profile).target
+ factory_function = target_registry.items[target].advisor_factory_func
return factory_function(
context,
target_profile,
diff --git a/src/mlia/backend/corstone/__init__.py b/src/mlia/backend/corstone/__init__.py
index 36f74ee..b59ab65 100644
--- a/src/mlia/backend/corstone/__init__.py
+++ b/src/mlia/backend/corstone/__init__.py
@@ -7,24 +7,20 @@ from mlia.backend.config import System
from mlia.backend.registry import registry
from mlia.core.common import AdviceCategory
-registry.register(
- "Corstone-300",
- BackendConfiguration(
- supported_advice=[AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION],
- supported_systems=[System.LINUX_AMD64],
- backend_type=BackendType.CUSTOM,
- ),
-)
-registry.register(
- "Corstone-310",
- BackendConfiguration(
- supported_advice=[AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION],
- supported_systems=[System.LINUX_AMD64],
- backend_type=BackendType.CUSTOM,
- ),
-)
+# List of mutually exclusive Corstone backends ordered by priority
+CORSTONE_PRIORITY = ("Corstone-310", "Corstone-300")
+
+for corstone_name in CORSTONE_PRIORITY:
+ registry.register(
+ corstone_name,
+ BackendConfiguration(
+ supported_advice=[AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION],
+ supported_systems=[System.LINUX_AMD64],
+ backend_type=BackendType.CUSTOM,
+ ),
+ )
def is_corstone_backend(backend_name: str) -> bool:
"""Check if backend belongs to Corstone."""
- return backend_name in ["Corstone-300", "Corstone-310"]
+ return backend_name in CORSTONE_PRIORITY
diff --git a/src/mlia/backend/manager.py b/src/mlia/backend/manager.py
index b0fa919..d953b2d 100644
--- a/src/mlia/backend/manager.py
+++ b/src/mlia/backend/manager.py
@@ -9,11 +9,13 @@ from abc import abstractmethod
from pathlib import Path
from typing import Callable
+from mlia.backend.config import BackendType
from mlia.backend.corstone.install import get_corstone_installations
from mlia.backend.install import DownloadAndInstall
from mlia.backend.install import Installation
from mlia.backend.install import InstallationType
from mlia.backend.install import InstallFromPath
+from mlia.backend.registry import registry as backend_registry
from mlia.backend.tosa_checker.install import get_tosa_backend_installation
from mlia.core.errors import ConfigurationError
from mlia.core.errors import InternalError
@@ -279,3 +281,14 @@ def get_installation_manager(noninteractive: bool = False) -> InstallationManage
backends.append(get_tosa_backend_installation())
return DefaultInstallationManager(backends, noninteractive=noninteractive)
+
+
+def get_available_backends() -> list[str]:
+ """Return list of the available backends."""
+ manager = get_installation_manager()
+ available_backends = [
+ backend
+ for backend, cfg in backend_registry.items.items()
+ if cfg.type == BackendType.BUILTIN or manager.backend_installed(backend)
+ ]
+ return available_backends
diff --git a/src/mlia/cli/command_validators.py b/src/mlia/cli/command_validators.py
index 23101e0..a0f5433 100644
--- a/src/mlia/cli/command_validators.py
+++ b/src/mlia/cli/command_validators.py
@@ -7,8 +7,8 @@ import argparse
import logging
import sys
-from mlia.cli.config import get_default_backends_dict
-from mlia.target.config import get_target
+from mlia.target.registry import default_backends
+from mlia.target.registry import get_target
from mlia.target.registry import supported_backends
logger = logging.getLogger(__name__)
@@ -26,22 +26,18 @@ def validate_backend(
target = get_target(target_profile)
if not backend:
- return get_default_backends_dict()[target]
+ return default_backends(target)
- compatible_backends = supported_backends(target)
+ compatible_backends = list(map(normalize_string, supported_backends(target)))
+ backends = {normalize_string(b): b for b in backend}
- nor_backend = list(map(normalize_string, backend))
- nor_compat_backend = list(map(normalize_string, compatible_backends))
-
- incompatible_backends = [
- backend[i] for i, x in enumerate(nor_backend) if x not in nor_compat_backend
- ]
+ incompatible_backends = [b for b in backends if b not in compatible_backends]
# Throw an error if any unsupported backends are used
if incompatible_backends:
raise argparse.ArgumentError(
None,
- f"{', '.join(incompatible_backends)} backend not supported "
- f"with target-profile {target_profile}.",
+ f"Backend {', '.join(backends[b] for b in incompatible_backends)} "
+ f"not supported with target-profile {target_profile}.",
)
return backend
diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py
index c17d571..27f5b2b 100644
--- a/src/mlia/cli/commands.py
+++ b/src/mlia/cli/commands.py
@@ -23,9 +23,9 @@ from pathlib import Path
from mlia.api import ExecutionContext
from mlia.api import get_advice
+from mlia.backend.manager import get_installation_manager
from mlia.cli.command_validators import validate_backend
from mlia.cli.command_validators import validate_check_target_profile
-from mlia.cli.config import get_installation_manager
from mlia.cli.options import parse_optimization_parameters
from mlia.utils.console import create_section_header
diff --git a/src/mlia/cli/config.py b/src/mlia/cli/config.py
deleted file mode 100644
index 433300c..0000000
--- a/src/mlia/cli/config.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Environment configuration functions."""
-from __future__ import annotations
-
-import logging
-
-from mlia.backend.manager import get_installation_manager
-from mlia.target.registry import all_supported_backends
-
-logger = logging.getLogger(__name__)
-
-DEFAULT_PRUNING_TARGET = 0.5
-DEFAULT_CLUSTERING_TARGET = 32
-
-
-def get_available_backends() -> list[str]:
- """Return list of the available backends."""
- available_backends = ["Vela", "ArmNNTFLiteDelegate"]
-
- # Add backends using backend manager
- manager = get_installation_manager()
- available_backends.extend(
- backend
- for backend in all_supported_backends()
- if manager.backend_installed(backend)
- )
-
- return available_backends
-
-
-# List of mutually exclusive Corstone backends ordered by priority
-_CORSTONE_EXCLUSIVE_PRIORITY = ("Corstone-310", "Corstone-300")
-_NON_ETHOS_U_BACKENDS = ("tosa-checker", "ArmNNTFLiteDelegate")
-
-
-def get_ethos_u_default_backends(backends: list[str]) -> list[str]:
- """Get Ethos-U default backends for evaluation."""
- return [x for x in backends if x not in _NON_ETHOS_U_BACKENDS]
-
-
-def get_default_backends() -> list[str]:
- """Get default backends for evaluation."""
- backends = get_available_backends()
-
- # Filter backends to only include one Corstone backend
- for corstone in _CORSTONE_EXCLUSIVE_PRIORITY:
- if corstone in backends:
- backends = [
- backend
- for backend in backends
- if backend == corstone or backend not in _CORSTONE_EXCLUSIVE_PRIORITY
- ]
- break
-
- return backends
-
-
-def get_default_backends_dict() -> dict[str, list[str]]:
- """Return default backends for all targets."""
- default_backends = get_default_backends()
- ethos_u_defaults = get_ethos_u_default_backends(default_backends)
-
- return {
- "ethos-u55": ethos_u_defaults,
- "ethos-u65": ethos_u_defaults,
- "tosa": ["tosa-checker"],
- "cortex-a": ["ArmNNTFLiteDelegate"],
- }
diff --git a/src/mlia/cli/helpers.py b/src/mlia/cli/helpers.py
index ac64581..576670b 100644
--- a/src/mlia/cli/helpers.py
+++ b/src/mlia/cli/helpers.py
@@ -1,14 +1,19 @@
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
-"""Module for various helper classes."""
+"""Module for various helpers."""
from __future__ import annotations
+from pathlib import Path
+from shutil import copy
from typing import Any
+from typing import cast
from mlia.cli.options import get_target_profile_opts
from mlia.core.helpers import ActionResolver
from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
from mlia.nn.tensorflow.utils import is_keras_model
+from mlia.target.config import get_builtin_profile_path
+from mlia.target.config import is_builtin_profile
from mlia.utils.types import is_list_of
@@ -108,3 +113,20 @@ class CLIActionResolver(ActionResolver):
model_path = self.args.get("model")
return model_path, device_opts
+
+
+def copy_profile_file_to_output_dir(
+ target_profile: str | Path, output_dir: str | Path
+) -> bool:
+ """Copy the target profile file to the output directory."""
+ profile_file_path = (
+ get_builtin_profile_path(cast(str, target_profile))
+ if is_builtin_profile(target_profile)
+ else Path(target_profile)
+ )
+ output_file_path = f"{output_dir}/{profile_file_path.stem}.toml"
+ try:
+ copy(profile_file_path, output_file_path)
+ return True
+ except OSError as err:
+ raise RuntimeError("Failed to copy profile file:", err.strerror) from err
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py
index 793e155..b3a9d4c 100644
--- a/src/mlia/cli/main.py
+++ b/src/mlia/cli/main.py
@@ -18,6 +18,7 @@ from mlia.cli.commands import check
from mlia.cli.commands import optimize
from mlia.cli.common import CommandInfo
from mlia.cli.helpers import CLIActionResolver
+from mlia.cli.helpers import copy_profile_file_to_output_dir
from mlia.cli.options import add_backend_install_options
from mlia.cli.options import add_backend_options
from mlia.cli.options import add_backend_uninstall_options
@@ -30,11 +31,11 @@ from mlia.cli.options import add_output_directory
from mlia.cli.options import add_output_options
from mlia.cli.options import add_target_options
from mlia.cli.options import get_output_format
+from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext
from mlia.core.errors import ConfigurationError
from mlia.core.errors import InternalError
from mlia.core.logging import setup_logging
-from mlia.target.config import copy_profile_file_to_output_dir
from mlia.target.registry import table as target_table
@@ -59,7 +60,13 @@ def get_commands() -> list[CommandInfo]:
[
add_output_directory,
add_model_options,
- add_target_options,
+ partial(
+ add_target_options,
+ supported_advice=[
+ AdviceCategory.COMPATIBILITY,
+ AdviceCategory.PERFORMANCE,
+ ],
+ ),
add_backend_options,
add_check_category_options,
add_output_options,
@@ -72,7 +79,9 @@ def get_commands() -> list[CommandInfo]:
[
add_output_directory,
add_keras_model_options,
- partial(add_target_options, profiles_to_skip=["tosa", "cortex-a"]),
+ partial(
+ add_target_options, supported_advice=[AdviceCategory.OPTIMIZATION]
+ ),
partial(
add_backend_options,
backends_to_skip=["tosa-checker", "ArmNNTFLiteDelegate"],
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py
index 421533a..8cd2935 100644
--- a/src/mlia/cli/options.py
+++ b/src/mlia/cli/options.py
@@ -7,13 +7,17 @@ import argparse
from pathlib import Path
from typing import Any
from typing import Callable
+from typing import Sequence
from mlia.backend.corstone import is_corstone_backend
-from mlia.cli.config import DEFAULT_CLUSTERING_TARGET
-from mlia.cli.config import DEFAULT_PRUNING_TARGET
-from mlia.cli.config import get_available_backends
+from mlia.backend.manager import get_available_backends
+from mlia.core.common import AdviceCategory
from mlia.core.typing import OutputFormat
-from mlia.target.config import get_builtin_supported_profile_names
+from mlia.target.registry import builtin_profile_names
+from mlia.target.registry import registry as target_registry
+
+DEFAULT_PRUNING_TARGET = 0.5
+DEFAULT_CLUSTERING_TARGET = 32
def add_check_category_options(parser: argparse.ArgumentParser) -> None:
@@ -31,22 +35,39 @@ def add_check_category_options(parser: argparse.ArgumentParser) -> None:
def add_target_options(
parser: argparse.ArgumentParser,
- profiles_to_skip: list[str] | None = None,
+ supported_advice: Sequence[AdviceCategory] | None = None,
required: bool = True,
) -> None:
"""Add target specific options."""
- target_profiles = get_builtin_supported_profile_names()
- if profiles_to_skip:
- target_profiles = [tp for tp in target_profiles if tp not in profiles_to_skip]
-
- default_target_profile = "ethos-u55-256"
+ target_profiles = builtin_profile_names()
+
+ if supported_advice:
+
+ def is_advice_supported(profile: str, advice: Sequence[AdviceCategory]) -> bool:
+ """
+ Collect all target profiles that support the advice.
+
+ This means target profiles that...
+ - have the right target prefix, e.g. "ethos-u55..." to avoid loading
+ all target profiles
+ - support any of the required advice
+ """
+ for target, info in target_registry.items.items():
+ if profile.startswith(target):
+ return any(info.is_supported(adv) for adv in advice)
+ return False
+
+ target_profiles = [
+ profile
+ for profile in target_profiles
+ if is_advice_supported(profile, supported_advice)
+ ]
target_group = parser.add_argument_group("target options")
target_group.add_argument(
"-t",
"--target-profile",
required=required,
- default=default_target_profile,
help="Built-in target profile or path to the custom target profile. "
f"Built-in target profiles are {', '.join(target_profiles)}. "
"Target profile that will set the target options "
diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py
index bf603dd..eb7ecff 100644
--- a/src/mlia/target/config.py
+++ b/src/mlia/target/config.py
@@ -6,9 +6,10 @@ from __future__ import annotations
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
+from functools import lru_cache
from pathlib import Path
-from shutil import copy
from typing import Any
+from typing import Callable
from typing import cast
from typing import TypeVar
@@ -19,23 +20,20 @@ except ModuleNotFoundError:
from mlia.backend.registry import registry as backend_registry
from mlia.core.common import AdviceCategory
+from mlia.core.advisor import InferenceAdvisor
from mlia.utils.filesystem import get_mlia_target_profiles_dir
-def get_profile_file(target_profile: str | Path) -> Path:
- """Get the target profile toml file."""
- if not target_profile:
- raise Exception("Target profile is not provided.")
+def get_builtin_profile_path(target_profile: str) -> Path:
+ """
+ Construct the path to the built-in target profile file.
- profile_file = Path(get_mlia_target_profiles_dir() / f"{target_profile}.toml")
- if not profile_file.is_file():
- profile_file = Path(target_profile)
-
- if not profile_file.exists():
- raise Exception(f"File not found: {profile_file}.")
- return profile_file
+ No checks are performed.
+ """
+ return get_mlia_target_profiles_dir() / f"{target_profile}.toml"
+@lru_cache
def load_profile(path: str | Path) -> dict[str, Any]:
"""Get settings for the provided target profile."""
with open(path, "rb") as file:
@@ -55,24 +53,12 @@ def get_builtin_supported_profile_names() -> list[str]:
)
-def get_target(target_profile: str | Path) -> str:
- """Return target for the provided target_profile."""
- profile_file = get_profile_file(target_profile)
- profile = load_profile(profile_file)
- return cast(str, profile["target"])
+BUILTIN_SUPPORTED_PROFILE_NAMES = get_builtin_supported_profile_names()
-def copy_profile_file_to_output_dir(
- target_profile: str | Path, output_dir: str | Path
-) -> bool:
- """Copy the target profile file to output directory."""
- profile_file_path = get_profile_file(target_profile)
- output_file_path = f"{output_dir}/{profile_file_path.stem}.toml"
- try:
- copy(profile_file_path, output_file_path)
- return True
- except OSError as err:
- raise RuntimeError("Failed to copy profile file:", err.strerror) from err
+def is_builtin_profile(profile_name: str | Path) -> bool:
+ """Check if the given profile name belongs to a built-in profile."""
+ return profile_name in BUILTIN_SUPPORTED_PROFILE_NAMES
T = TypeVar("T", bound="TargetProfile")
@@ -88,21 +74,29 @@ class TargetProfile(ABC):
@classmethod
def load(cls: type[T], path: str | Path) -> T:
"""Load and verify a target profile from file and return new instance."""
- profile = load_profile(path)
+ profile_data = load_profile(path)
try:
- new_instance = cls(**profile)
+ new_instance = cls.load_json_data(profile_data)
except KeyError as ex:
raise KeyError(f"Missing key in file {path}.") from ex
- new_instance.verify()
+ return new_instance
+ @classmethod
+ def load_json_data(cls: type[T], profile_data: dict) -> T:
+ """Load a target profile from the JSON data."""
+ new_instance = cls(**profile_data)
+ new_instance.verify()
return new_instance
@classmethod
- def load_profile(cls: type[T], target_profile: str) -> T:
- """Load a target profile by name."""
- profile_file = get_profile_file(target_profile)
+ def load_profile(cls: type[T], target_profile: str | Path) -> T:
+ """Load a target profile from built-in target profile name or file path."""
+ if is_builtin_profile(target_profile):
+ profile_file = get_builtin_profile_path(cast(str, target_profile))
+ else:
+ profile_file = Path(target_profile)
return cls.load(profile_file)
def save(self, path: str | Path) -> None:
@@ -125,6 +119,9 @@ class TargetInfo:
"""Collect information about supported targets."""
supported_backends: list[str]
+ default_backends: list[str]
+ advisor_factory_func: Callable[..., InferenceAdvisor]
+ target_profile_cls: type[TargetProfile]
def __str__(self) -> str:
"""List supported backends."""
@@ -135,7 +132,8 @@ class TargetInfo:
) -> bool:
"""Check if any of the supported backends support this kind of advice."""
return any(
- backend_registry.items[name].is_supported(advice, check_system)
+ name in backend_registry.items
+ and backend_registry.items[name].is_supported(advice, check_system)
for name in self.supported_backends
)
@@ -146,5 +144,6 @@ class TargetInfo:
return [
name
for name in self.supported_backends
- if backend_registry.items[name].is_supported(advice, check_system)
+ if name in backend_registry.items
+ and backend_registry.items[name].is_supported(advice, check_system)
]
diff --git a/src/mlia/target/cortex_a/__init__.py b/src/mlia/target/cortex_a/__init__.py
index f686bfc..87f268a 100644
--- a/src/mlia/target/cortex_a/__init__.py
+++ b/src/mlia/target/cortex_a/__init__.py
@@ -1,7 +1,17 @@
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Cortex-A target module."""
+from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor
+from mlia.target.cortex_a.config import CortexAConfiguration
from mlia.target.registry import registry
from mlia.target.registry import TargetInfo
-registry.register("cortex-a", TargetInfo(["ArmNNTFLiteDelegate"]))
+registry.register(
+ "cortex-a",
+ TargetInfo(
+ supported_backends=["ArmNNTFLiteDelegate"],
+ default_backends=["ArmNNTFLiteDelegate"],
+ advisor_factory_func=configure_and_get_cortexa_advisor,
+ target_profile_cls=CortexAConfiguration,
+ ),
+)
diff --git a/src/mlia/target/cortex_a/advisor.py b/src/mlia/target/cortex_a/advisor.py
index 518c9f1..a093784 100644
--- a/src/mlia/target/cortex_a/advisor.py
+++ b/src/mlia/target/cortex_a/advisor.py
@@ -5,6 +5,7 @@ from __future__ import annotations
from pathlib import Path
from typing import Any
+from typing import cast
from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
@@ -21,6 +22,7 @@ from mlia.target.cortex_a.data_analysis import CortexADataAnalyzer
from mlia.target.cortex_a.data_collection import CortexAOperatorCompatibility
from mlia.target.cortex_a.events import CortexAAdvisorStartedEvent
from mlia.target.cortex_a.handlers import CortexAEventHandler
+from mlia.target.registry import profile
class CortexAInferenceAdvisor(DefaultInferenceAdvisor):
@@ -59,7 +61,7 @@ class CortexAInferenceAdvisor(DefaultInferenceAdvisor):
return [
CortexAAdvisorStartedEvent(
- model, CortexAConfiguration.load_profile(target_profile)
+ model, cast(CortexAConfiguration, profile(target_profile))
),
]
diff --git a/src/mlia/target/ethos_u/__init__.py b/src/mlia/target/ethos_u/__init__.py
index d53be53..6b6777d 100644
--- a/src/mlia/target/ethos_u/__init__.py
+++ b/src/mlia/target/ethos_u/__init__.py
@@ -1,8 +1,23 @@
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Ethos-U target module."""
+from mlia.backend.corstone import CORSTONE_PRIORITY
+from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor
+from mlia.target.ethos_u.config import EthosUConfiguration
+from mlia.target.ethos_u.config import get_default_ethos_u_backends
from mlia.target.registry import registry
from mlia.target.registry import TargetInfo
-registry.register("ethos-u55", TargetInfo(["Vela", "Corstone-300", "Corstone-310"]))
-registry.register("ethos-u65", TargetInfo(["Vela", "Corstone-300", "Corstone-310"]))
+SUPPORTED_BACKENDS_PRIORITY = ["Vela", *CORSTONE_PRIORITY]
+
+
+for ethos_u in ("ethos-u55", "ethos-u65"):
+ registry.register(
+ ethos_u,
+ TargetInfo(
+ supported_backends=SUPPORTED_BACKENDS_PRIORITY,
+ default_backends=get_default_ethos_u_backends(SUPPORTED_BACKENDS_PRIORITY),
+ advisor_factory_func=configure_and_get_ethosu_advisor,
+ target_profile_cls=EthosUConfiguration,
+ ),
+ )
diff --git a/src/mlia/target/ethos_u/advisor.py b/src/mlia/target/ethos_u/advisor.py
index 225fd87..5f23fdd 100644
--- a/src/mlia/target/ethos_u/advisor.py
+++ b/src/mlia/target/ethos_u/advisor.py
@@ -5,6 +5,7 @@ from __future__ import annotations
from pathlib import Path
from typing import Any
+from typing import cast
from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
@@ -25,6 +26,7 @@ from mlia.target.ethos_u.data_collection import EthosUOptimizationPerformance
from mlia.target.ethos_u.data_collection import EthosUPerformance
from mlia.target.ethos_u.events import EthosUAdvisorStartedEvent
from mlia.target.ethos_u.handlers import EthosUEventHandler
+from mlia.target.registry import profile
from mlia.utils.types import is_list_of
@@ -96,7 +98,7 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
def _get_device_cfg(self, context: Context) -> EthosUConfiguration:
"""Get device configuration."""
target_profile = self.get_target_profile(context)
- return EthosUConfiguration.load_profile(target_profile)
+ return cast(EthosUConfiguration, profile(target_profile))
def _get_optimization_settings(self, context: Context) -> list[list[dict]]:
"""Get optimization settings."""
diff --git a/src/mlia/target/ethos_u/config.py b/src/mlia/target/ethos_u/config.py
index eb5691d..d1a2c7a 100644
--- a/src/mlia/target/ethos_u/config.py
+++ b/src/mlia/target/ethos_u/config.py
@@ -6,12 +6,13 @@ from __future__ import annotations
import logging
from typing import Any
+from mlia.backend.corstone import is_corstone_backend
+from mlia.backend.manager import get_available_backends
from mlia.backend.vela.compiler import resolve_compiler_config
from mlia.backend.vela.compiler import VelaCompilerOptions
from mlia.target.config import TargetProfile
from mlia.utils.filesystem import get_vela_config
-
logger = logging.getLogger(__name__)
@@ -67,3 +68,22 @@ class EthosUConfiguration(TargetProfile):
def __repr__(self) -> str:
"""Return string representation."""
return f"<Ethos-U configuration target={self.target}>"
+
+
+def get_default_ethos_u_backends(
+ supported_backends_priority_order: list[str],
+) -> list[str]:
+ """Return default backends for Ethos-U targets."""
+ available_backends = get_available_backends()
+
+ default_backends = []
+ corstone_added = False
+ for backend in supported_backends_priority_order:
+ if backend not in available_backends:
+ continue
+ if is_corstone_backend(backend):
+ if corstone_added:
+ continue # only add one Corstone backend
+ corstone_added = True
+ default_backends.append(backend)
+ return default_backends
diff --git a/src/mlia/target/registry.py b/src/mlia/target/registry.py
index 4870fc8..9fccecb 100644
--- a/src/mlia/target/registry.py
+++ b/src/mlia/target/registry.py
@@ -3,17 +3,78 @@
"""Target module."""
from __future__ import annotations
+from functools import lru_cache
+from pathlib import Path
+from typing import cast
+
from mlia.backend.config import BackendType
from mlia.backend.manager import get_installation_manager
from mlia.backend.registry import registry as backend_registry
from mlia.core.common import AdviceCategory
from mlia.core.reporting import Column
from mlia.core.reporting import Table
+from mlia.target.config import BUILTIN_SUPPORTED_PROFILE_NAMES
+from mlia.target.config import get_builtin_profile_path
+from mlia.target.config import is_builtin_profile
+from mlia.target.config import load_profile
from mlia.target.config import TargetInfo
+from mlia.target.config import TargetProfile
from mlia.utils.registry import Registry
+
+class TargetRegistry(Registry[TargetInfo]):
+ """Registry for targets."""
+
+ def register(self, name: str, item: TargetInfo) -> bool:
+ """Register an item: returns `False` if already registered."""
+ assert all(
+ backend in backend_registry.items for backend in item.supported_backends
+ )
+ return super().register(name, item)
+
+
# All supported targets are required to be registered here.
-registry = Registry[TargetInfo]()
+registry = TargetRegistry()
+
+
+def builtin_profile_names() -> list[str]:
+ """Return a list of built-in profile names (not file paths)."""
+ return BUILTIN_SUPPORTED_PROFILE_NAMES
+
+
+@lru_cache
+def profile(target_profile: str | Path) -> TargetProfile:
+ """Get the target profile data (built-in or custom file)."""
+ if not target_profile:
+ raise ValueError("No valid target profile was provided.")
+ if is_builtin_profile(target_profile):
+ profile_file = get_builtin_profile_path(cast(str, target_profile))
+ profile_ = create_target_profile(profile_file)
+ else:
+ profile_file = Path(target_profile)
+ if profile_file.is_file():
+ profile_ = create_target_profile(profile_file)
+ else:
+ raise ValueError(
+ f"Profile '{target_profile}' is neither a valid built-in "
+ "target profile name or a valid file path."
+ )
+
+ return profile_
+
+
+def get_target(target_profile: str | Path) -> str:
+ """Return target for the provided target_profile."""
+ return profile(target_profile).target
+
+
+@lru_cache
+def create_target_profile(path: Path) -> TargetProfile:
+ """Create a new instance of a TargetProfile from the file."""
+ profile_data = load_profile(path)
+ target = profile_data["target"]
+ target_info = registry.items[target]
+ return target_info.target_profile_cls.load_json_data(profile_data)
def supported_advice(target: str) -> list[AdviceCategory]:
@@ -29,6 +90,11 @@ def supported_backends(target: str) -> list[str]:
return registry.items[target].filter_supported_backends(check_system=False)
+def default_backends(target: str) -> list[str]:
+ """Get a list of default backends for the given target."""
+ return registry.items[target].default_backends
+
+
def get_backend_to_supported_targets() -> dict[str, list]:
"""Get a dict that maps a list of supported targets given backend."""
targets = dict(registry.items)
diff --git a/src/mlia/target/tosa/__init__.py b/src/mlia/target/tosa/__init__.py
index 06bf1a9..3830ce5 100644
--- a/src/mlia/target/tosa/__init__.py
+++ b/src/mlia/target/tosa/__init__.py
@@ -3,5 +3,15 @@
"""TOSA target module."""
from mlia.target.registry import registry
from mlia.target.registry import TargetInfo
+from mlia.target.tosa.advisor import configure_and_get_tosa_advisor
+from mlia.target.tosa.config import TOSAConfiguration
-registry.register("tosa", TargetInfo(["tosa-checker"]))
+registry.register(
+ "tosa",
+ TargetInfo(
+ supported_backends=["tosa-checker"],
+ default_backends=["tosa-checker"],
+ advisor_factory_func=configure_and_get_tosa_advisor,
+ target_profile_cls=TOSAConfiguration,
+ ),
+)
diff --git a/src/mlia/target/tosa/advisor.py b/src/mlia/target/tosa/advisor.py
index 5588d0f..5fb18ed 100644
--- a/src/mlia/target/tosa/advisor.py
+++ b/src/mlia/target/tosa/advisor.py
@@ -5,6 +5,7 @@ from __future__ import annotations
from pathlib import Path
from typing import Any
+from typing import cast
from mlia.core.advice_generation import AdviceCategory
from mlia.core.advice_generation import AdviceProducer
@@ -17,6 +18,7 @@ from mlia.core.data_collection import DataCollector
from mlia.core.events import Event
from mlia.core.metadata import MLIAMetadata
from mlia.core.metadata import ModelMetadata
+from mlia.target.registry import profile
from mlia.target.tosa.advice_generation import TOSAAdviceProducer
from mlia.target.tosa.config import TOSAConfiguration
from mlia.target.tosa.data_analysis import TOSADataAnalyzer
@@ -66,7 +68,7 @@ class TOSAInferenceAdvisor(DefaultInferenceAdvisor):
return [
TOSAAdvisorStartedEvent(
model,
- TOSAConfiguration.load_profile(target_profile),
+ cast(TOSAConfiguration, profile(target_profile)),
MetadataDisplay(
TOSAMetadata("tosa-checker"),
MLIAMetadata("mlia"),