author     Benjamin Klimczak <benjamin.klimczak@arm.com>  2023-02-02 14:02:05 +0000
committer  Benjamin Klimczak <benjamin.klimczak@arm.com>  2023-02-10 13:45:18 +0000
commit     7a661257b6adad0c8f53e32b42ced56a1e7d952f (patch)
tree       938ad8578c5b9edc0573e810ce64ce0a5bda3d8c
parent     50271dee0a84bfc481ce798184f07b5b0b4bc64d (diff)
download   mlia-7a661257b6adad0c8f53e32b42ced56a1e7d952f.tar.gz
MLIA-769 Expand use of target/backend registries
- Use the target/backend registries to avoid hard-coded names.
- Cache target profiles to avoid re-loading them.

Change-Id: I474b7c9ef23894e1d8a3ea06d13a37652054c62e
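
A minimal usage sketch of the registry-based lookup this commit switches to, mirroring the updated get_advisor() hunk in src/mlia/api.py below (assumes the mlia package with this patch installed; the profile name is only an example):

    # Resolve the target and its advisor factory via the registries instead of
    # hard-coded target names; profile() is wrapped in functools.lru_cache, so
    # repeated look-ups reuse the already-parsed profile data.
    from mlia.target.registry import profile
    from mlia.target.registry import registry as target_registry

    target_profile = "ethos-u55-256"  # built-in profile name or path to a custom .toml
    target = profile(target_profile).target
    advisor_factory = target_registry.items[target].advisor_factory_func
    print(target, advisor_factory.__name__)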
-rw-r--r--  src/mlia/api.py | 21
-rw-r--r--  src/mlia/backend/corstone/__init__.py | 30
-rw-r--r--  src/mlia/backend/manager.py | 13
-rw-r--r--  src/mlia/cli/command_validators.py | 20
-rw-r--r--  src/mlia/cli/commands.py | 2
-rw-r--r--  src/mlia/cli/config.py | 69
-rw-r--r--  src/mlia/cli/helpers.py | 24
-rw-r--r--  src/mlia/cli/main.py | 15
-rw-r--r--  src/mlia/cli/options.py | 43
-rw-r--r--  src/mlia/target/config.py | 71
-rw-r--r--  src/mlia/target/cortex_a/__init__.py | 12
-rw-r--r--  src/mlia/target/cortex_a/advisor.py | 4
-rw-r--r--  src/mlia/target/ethos_u/__init__.py | 19
-rw-r--r--  src/mlia/target/ethos_u/advisor.py | 4
-rw-r--r--  src/mlia/target/ethos_u/config.py | 22
-rw-r--r--  src/mlia/target/registry.py | 68
-rw-r--r--  src/mlia/target/tosa/__init__.py | 12
-rw-r--r--  src/mlia/target/tosa/advisor.py | 4
-rw-r--r--  tests/conftest.py | 7
-rw-r--r--  tests/test_api.py | 8
-rw-r--r--  tests/test_cli_command_validators.py | 57
-rw-r--r--  tests/test_cli_commands.py | 8
-rw-r--r--  tests/test_cli_config.py | 48
-rw-r--r--  tests/test_cli_helpers.py | 11
-rw-r--r--  tests/test_target_config.py | 63
-rw-r--r--  tests/test_target_ethos_u_data_analysis.py | 9
-rw-r--r--  tests/test_target_ethos_u_reporters.py | 5
-rw-r--r--  tests/test_target_registry.py | 57
-rw-r--r--  tests_e2e/test_e2e.py | 2
29 files changed, 432 insertions, 296 deletions
diff --git a/src/mlia/api.py b/src/mlia/api.py
index fd5fc13..7adae48 100644
--- a/src/mlia/api.py
+++ b/src/mlia/api.py
@@ -10,10 +10,8 @@ from typing import Any
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext
-from mlia.target.config import get_target
-from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor
-from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor
-from mlia.target.tosa.advisor import configure_and_get_tosa_advisor
+from mlia.target.registry import profile
+from mlia.target.registry import registry as target_registry
logger = logging.getLogger(__name__)
@@ -84,19 +82,8 @@ def get_advisor(
**extra_args: Any,
) -> InferenceAdvisor:
"""Find appropriate advisor for the target."""
- target_factories = {
- "ethos-u55": configure_and_get_ethosu_advisor,
- "ethos-u65": configure_and_get_ethosu_advisor,
- "tosa": configure_and_get_tosa_advisor,
- "cortex-a": configure_and_get_cortexa_advisor,
- }
-
- try:
- target = get_target(target_profile)
- factory_function = target_factories[target]
- except KeyError as err:
- raise Exception(f"Unsupported profile {target_profile}") from err
-
+ target = profile(target_profile).target
+ factory_function = target_registry.items[target].advisor_factory_func
return factory_function(
context,
target_profile,
diff --git a/src/mlia/backend/corstone/__init__.py b/src/mlia/backend/corstone/__init__.py
index 36f74ee..b59ab65 100644
--- a/src/mlia/backend/corstone/__init__.py
+++ b/src/mlia/backend/corstone/__init__.py
@@ -7,24 +7,20 @@ from mlia.backend.config import System
from mlia.backend.registry import registry
from mlia.core.common import AdviceCategory
-registry.register(
- "Corstone-300",
- BackendConfiguration(
- supported_advice=[AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION],
- supported_systems=[System.LINUX_AMD64],
- backend_type=BackendType.CUSTOM,
- ),
-)
-registry.register(
- "Corstone-310",
- BackendConfiguration(
- supported_advice=[AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION],
- supported_systems=[System.LINUX_AMD64],
- backend_type=BackendType.CUSTOM,
- ),
-)
+# List of mutually exclusive Corstone backends ordered by priority
+CORSTONE_PRIORITY = ("Corstone-310", "Corstone-300")
+
+for corstone_name in CORSTONE_PRIORITY:
+ registry.register(
+ corstone_name,
+ BackendConfiguration(
+ supported_advice=[AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION],
+ supported_systems=[System.LINUX_AMD64],
+ backend_type=BackendType.CUSTOM,
+ ),
+ )
def is_corstone_backend(backend_name: str) -> bool:
"""Check if backend belongs to Corstone."""
- return backend_name in ["Corstone-300", "Corstone-310"]
+ return backend_name in CORSTONE_PRIORITY
diff --git a/src/mlia/backend/manager.py b/src/mlia/backend/manager.py
index b0fa919..d953b2d 100644
--- a/src/mlia/backend/manager.py
+++ b/src/mlia/backend/manager.py
@@ -9,11 +9,13 @@ from abc import abstractmethod
from pathlib import Path
from typing import Callable
+from mlia.backend.config import BackendType
from mlia.backend.corstone.install import get_corstone_installations
from mlia.backend.install import DownloadAndInstall
from mlia.backend.install import Installation
from mlia.backend.install import InstallationType
from mlia.backend.install import InstallFromPath
+from mlia.backend.registry import registry as backend_registry
from mlia.backend.tosa_checker.install import get_tosa_backend_installation
from mlia.core.errors import ConfigurationError
from mlia.core.errors import InternalError
@@ -279,3 +281,14 @@ def get_installation_manager(noninteractive: bool = False) -> InstallationManage
backends.append(get_tosa_backend_installation())
return DefaultInstallationManager(backends, noninteractive=noninteractive)
+
+
+def get_available_backends() -> list[str]:
+ """Return list of the available backends."""
+ manager = get_installation_manager()
+ available_backends = [
+ backend
+ for backend, cfg in backend_registry.items.items()
+ if cfg.type == BackendType.BUILTIN or manager.backend_installed(backend)
+ ]
+ return available_backends
diff --git a/src/mlia/cli/command_validators.py b/src/mlia/cli/command_validators.py
index 23101e0..a0f5433 100644
--- a/src/mlia/cli/command_validators.py
+++ b/src/mlia/cli/command_validators.py
@@ -7,8 +7,8 @@ import argparse
import logging
import sys
-from mlia.cli.config import get_default_backends_dict
-from mlia.target.config import get_target
+from mlia.target.registry import default_backends
+from mlia.target.registry import get_target
from mlia.target.registry import supported_backends
logger = logging.getLogger(__name__)
@@ -26,22 +26,18 @@ def validate_backend(
target = get_target(target_profile)
if not backend:
- return get_default_backends_dict()[target]
+ return default_backends(target)
- compatible_backends = supported_backends(target)
+ compatible_backends = list(map(normalize_string, supported_backends(target)))
+ backends = {normalize_string(b): b for b in backend}
- nor_backend = list(map(normalize_string, backend))
- nor_compat_backend = list(map(normalize_string, compatible_backends))
-
- incompatible_backends = [
- backend[i] for i, x in enumerate(nor_backend) if x not in nor_compat_backend
- ]
+ incompatible_backends = [b for b in backends if b not in compatible_backends]
# Throw an error if any unsupported backends are used
if incompatible_backends:
raise argparse.ArgumentError(
None,
- f"{', '.join(incompatible_backends)} backend not supported "
- f"with target-profile {target_profile}.",
+ f"Backend {', '.join(backends[b] for b in incompatible_backends)} "
+ f"not supported with target-profile {target_profile}.",
)
return backend
diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py
index c17d571..27f5b2b 100644
--- a/src/mlia/cli/commands.py
+++ b/src/mlia/cli/commands.py
@@ -23,9 +23,9 @@ from pathlib import Path
from mlia.api import ExecutionContext
from mlia.api import get_advice
+from mlia.backend.manager import get_installation_manager
from mlia.cli.command_validators import validate_backend
from mlia.cli.command_validators import validate_check_target_profile
-from mlia.cli.config import get_installation_manager
from mlia.cli.options import parse_optimization_parameters
from mlia.utils.console import create_section_header
diff --git a/src/mlia/cli/config.py b/src/mlia/cli/config.py
deleted file mode 100644
index 433300c..0000000
--- a/src/mlia/cli/config.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Environment configuration functions."""
-from __future__ import annotations
-
-import logging
-
-from mlia.backend.manager import get_installation_manager
-from mlia.target.registry import all_supported_backends
-
-logger = logging.getLogger(__name__)
-
-DEFAULT_PRUNING_TARGET = 0.5
-DEFAULT_CLUSTERING_TARGET = 32
-
-
-def get_available_backends() -> list[str]:
- """Return list of the available backends."""
- available_backends = ["Vela", "ArmNNTFLiteDelegate"]
-
- # Add backends using backend manager
- manager = get_installation_manager()
- available_backends.extend(
- backend
- for backend in all_supported_backends()
- if manager.backend_installed(backend)
- )
-
- return available_backends
-
-
-# List of mutually exclusive Corstone backends ordered by priority
-_CORSTONE_EXCLUSIVE_PRIORITY = ("Corstone-310", "Corstone-300")
-_NON_ETHOS_U_BACKENDS = ("tosa-checker", "ArmNNTFLiteDelegate")
-
-
-def get_ethos_u_default_backends(backends: list[str]) -> list[str]:
- """Get Ethos-U default backends for evaluation."""
- return [x for x in backends if x not in _NON_ETHOS_U_BACKENDS]
-
-
-def get_default_backends() -> list[str]:
- """Get default backends for evaluation."""
- backends = get_available_backends()
-
- # Filter backends to only include one Corstone backend
- for corstone in _CORSTONE_EXCLUSIVE_PRIORITY:
- if corstone in backends:
- backends = [
- backend
- for backend in backends
- if backend == corstone or backend not in _CORSTONE_EXCLUSIVE_PRIORITY
- ]
- break
-
- return backends
-
-
-def get_default_backends_dict() -> dict[str, list[str]]:
- """Return default backends for all targets."""
- default_backends = get_default_backends()
- ethos_u_defaults = get_ethos_u_default_backends(default_backends)
-
- return {
- "ethos-u55": ethos_u_defaults,
- "ethos-u65": ethos_u_defaults,
- "tosa": ["tosa-checker"],
- "cortex-a": ["ArmNNTFLiteDelegate"],
- }
diff --git a/src/mlia/cli/helpers.py b/src/mlia/cli/helpers.py
index ac64581..576670b 100644
--- a/src/mlia/cli/helpers.py
+++ b/src/mlia/cli/helpers.py
@@ -1,14 +1,19 @@
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
-"""Module for various helper classes."""
+"""Module for various helpers."""
from __future__ import annotations
+from pathlib import Path
+from shutil import copy
from typing import Any
+from typing import cast
from mlia.cli.options import get_target_profile_opts
from mlia.core.helpers import ActionResolver
from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
from mlia.nn.tensorflow.utils import is_keras_model
+from mlia.target.config import get_builtin_profile_path
+from mlia.target.config import is_builtin_profile
from mlia.utils.types import is_list_of
@@ -108,3 +113,20 @@ class CLIActionResolver(ActionResolver):
model_path = self.args.get("model")
return model_path, device_opts
+
+
+def copy_profile_file_to_output_dir(
+ target_profile: str | Path, output_dir: str | Path
+) -> bool:
+ """Copy the target profile file to the output directory."""
+ profile_file_path = (
+ get_builtin_profile_path(cast(str, target_profile))
+ if is_builtin_profile(target_profile)
+ else Path(target_profile)
+ )
+ output_file_path = f"{output_dir}/{profile_file_path.stem}.toml"
+ try:
+ copy(profile_file_path, output_file_path)
+ return True
+ except OSError as err:
+ raise RuntimeError("Failed to copy profile file:", err.strerror) from err
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py
index 793e155..b3a9d4c 100644
--- a/src/mlia/cli/main.py
+++ b/src/mlia/cli/main.py
@@ -18,6 +18,7 @@ from mlia.cli.commands import check
from mlia.cli.commands import optimize
from mlia.cli.common import CommandInfo
from mlia.cli.helpers import CLIActionResolver
+from mlia.cli.helpers import copy_profile_file_to_output_dir
from mlia.cli.options import add_backend_install_options
from mlia.cli.options import add_backend_options
from mlia.cli.options import add_backend_uninstall_options
@@ -30,11 +31,11 @@ from mlia.cli.options import add_output_directory
from mlia.cli.options import add_output_options
from mlia.cli.options import add_target_options
from mlia.cli.options import get_output_format
+from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext
from mlia.core.errors import ConfigurationError
from mlia.core.errors import InternalError
from mlia.core.logging import setup_logging
-from mlia.target.config import copy_profile_file_to_output_dir
from mlia.target.registry import table as target_table
@@ -59,7 +60,13 @@ def get_commands() -> list[CommandInfo]:
[
add_output_directory,
add_model_options,
- add_target_options,
+ partial(
+ add_target_options,
+ supported_advice=[
+ AdviceCategory.COMPATIBILITY,
+ AdviceCategory.PERFORMANCE,
+ ],
+ ),
add_backend_options,
add_check_category_options,
add_output_options,
@@ -72,7 +79,9 @@ def get_commands() -> list[CommandInfo]:
[
add_output_directory,
add_keras_model_options,
- partial(add_target_options, profiles_to_skip=["tosa", "cortex-a"]),
+ partial(
+ add_target_options, supported_advice=[AdviceCategory.OPTIMIZATION]
+ ),
partial(
add_backend_options,
backends_to_skip=["tosa-checker", "ArmNNTFLiteDelegate"],
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py
index 421533a..8cd2935 100644
--- a/src/mlia/cli/options.py
+++ b/src/mlia/cli/options.py
@@ -7,13 +7,17 @@ import argparse
from pathlib import Path
from typing import Any
from typing import Callable
+from typing import Sequence
from mlia.backend.corstone import is_corstone_backend
-from mlia.cli.config import DEFAULT_CLUSTERING_TARGET
-from mlia.cli.config import DEFAULT_PRUNING_TARGET
-from mlia.cli.config import get_available_backends
+from mlia.backend.manager import get_available_backends
+from mlia.core.common import AdviceCategory
from mlia.core.typing import OutputFormat
-from mlia.target.config import get_builtin_supported_profile_names
+from mlia.target.registry import builtin_profile_names
+from mlia.target.registry import registry as target_registry
+
+DEFAULT_PRUNING_TARGET = 0.5
+DEFAULT_CLUSTERING_TARGET = 32
def add_check_category_options(parser: argparse.ArgumentParser) -> None:
@@ -31,22 +35,39 @@ def add_check_category_options(parser: argparse.ArgumentParser) -> None:
def add_target_options(
parser: argparse.ArgumentParser,
- profiles_to_skip: list[str] | None = None,
+ supported_advice: Sequence[AdviceCategory] | None = None,
required: bool = True,
) -> None:
"""Add target specific options."""
- target_profiles = get_builtin_supported_profile_names()
- if profiles_to_skip:
- target_profiles = [tp for tp in target_profiles if tp not in profiles_to_skip]
-
- default_target_profile = "ethos-u55-256"
+ target_profiles = builtin_profile_names()
+
+ if supported_advice:
+
+ def is_advice_supported(profile: str, advice: Sequence[AdviceCategory]) -> bool:
+ """
+ Collect all target profiles that support the advice.
+
+ This means target profiles that...
+ - have the right target prefix, e.g. "ethos-u55..." to avoid loading
+ all target profiles
+ - support any of the required advice
+ """
+ for target, info in target_registry.items.items():
+ if profile.startswith(target):
+ return any(info.is_supported(adv) for adv in advice)
+ return False
+
+ target_profiles = [
+ profile
+ for profile in target_profiles
+ if is_advice_supported(profile, supported_advice)
+ ]
target_group = parser.add_argument_group("target options")
target_group.add_argument(
"-t",
"--target-profile",
required=required,
- default=default_target_profile,
help="Built-in target profile or path to the custom target profile. "
f"Built-in target profiles are {', '.join(target_profiles)}. "
"Target profile that will set the target options "
diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py
index bf603dd..eb7ecff 100644
--- a/src/mlia/target/config.py
+++ b/src/mlia/target/config.py
@@ -6,9 +6,10 @@ from __future__ import annotations
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
+from functools import lru_cache
from pathlib import Path
-from shutil import copy
from typing import Any
+from typing import Callable
from typing import cast
from typing import TypeVar
@@ -19,23 +20,20 @@ except ModuleNotFoundError:
from mlia.backend.registry import registry as backend_registry
from mlia.core.common import AdviceCategory
+from mlia.core.advisor import InferenceAdvisor
from mlia.utils.filesystem import get_mlia_target_profiles_dir
-def get_profile_file(target_profile: str | Path) -> Path:
- """Get the target profile toml file."""
- if not target_profile:
- raise Exception("Target profile is not provided.")
+def get_builtin_profile_path(target_profile: str) -> Path:
+ """
+ Construct the path to the built-in target profile file.
- profile_file = Path(get_mlia_target_profiles_dir() / f"{target_profile}.toml")
- if not profile_file.is_file():
- profile_file = Path(target_profile)
-
- if not profile_file.exists():
- raise Exception(f"File not found: {profile_file}.")
- return profile_file
+ No checks are performed.
+ """
+ return get_mlia_target_profiles_dir() / f"{target_profile}.toml"
+@lru_cache
def load_profile(path: str | Path) -> dict[str, Any]:
"""Get settings for the provided target profile."""
with open(path, "rb") as file:
@@ -55,24 +53,12 @@ def get_builtin_supported_profile_names() -> list[str]:
)
-def get_target(target_profile: str | Path) -> str:
- """Return target for the provided target_profile."""
- profile_file = get_profile_file(target_profile)
- profile = load_profile(profile_file)
- return cast(str, profile["target"])
+BUILTIN_SUPPORTED_PROFILE_NAMES = get_builtin_supported_profile_names()
-def copy_profile_file_to_output_dir(
- target_profile: str | Path, output_dir: str | Path
-) -> bool:
- """Copy the target profile file to output directory."""
- profile_file_path = get_profile_file(target_profile)
- output_file_path = f"{output_dir}/{profile_file_path.stem}.toml"
- try:
- copy(profile_file_path, output_file_path)
- return True
- except OSError as err:
- raise RuntimeError("Failed to copy profile file:", err.strerror) from err
+def is_builtin_profile(profile_name: str | Path) -> bool:
+ """Check if the given profile name belongs to a built-in profile."""
+ return profile_name in BUILTIN_SUPPORTED_PROFILE_NAMES
T = TypeVar("T", bound="TargetProfile")
@@ -88,21 +74,29 @@ class TargetProfile(ABC):
@classmethod
def load(cls: type[T], path: str | Path) -> T:
"""Load and verify a target profile from file and return new instance."""
- profile = load_profile(path)
+ profile_data = load_profile(path)
try:
- new_instance = cls(**profile)
+ new_instance = cls.load_json_data(profile_data)
except KeyError as ex:
raise KeyError(f"Missing key in file {path}.") from ex
- new_instance.verify()
+ return new_instance
+ @classmethod
+ def load_json_data(cls: type[T], profile_data: dict) -> T:
+ """Load a target profile from the JSON data."""
+ new_instance = cls(**profile_data)
+ new_instance.verify()
return new_instance
@classmethod
- def load_profile(cls: type[T], target_profile: str) -> T:
- """Load a target profile by name."""
- profile_file = get_profile_file(target_profile)
+ def load_profile(cls: type[T], target_profile: str | Path) -> T:
+ """Load a target profile from built-in target profile name or file path."""
+ if is_builtin_profile(target_profile):
+ profile_file = get_builtin_profile_path(cast(str, target_profile))
+ else:
+ profile_file = Path(target_profile)
return cls.load(profile_file)
def save(self, path: str | Path) -> None:
@@ -125,6 +119,9 @@ class TargetInfo:
"""Collect information about supported targets."""
supported_backends: list[str]
+ default_backends: list[str]
+ advisor_factory_func: Callable[..., InferenceAdvisor]
+ target_profile_cls: type[TargetProfile]
def __str__(self) -> str:
"""List supported backends."""
@@ -135,7 +132,8 @@ class TargetInfo:
) -> bool:
"""Check if any of the supported backends support this kind of advice."""
return any(
- backend_registry.items[name].is_supported(advice, check_system)
+ name in backend_registry.items
+ and backend_registry.items[name].is_supported(advice, check_system)
for name in self.supported_backends
)
@@ -146,5 +144,6 @@ class TargetInfo:
return [
name
for name in self.supported_backends
- if backend_registry.items[name].is_supported(advice, check_system)
+ if name in backend_registry.items
+ and backend_registry.items[name].is_supported(advice, check_system)
]
diff --git a/src/mlia/target/cortex_a/__init__.py b/src/mlia/target/cortex_a/__init__.py
index f686bfc..87f268a 100644
--- a/src/mlia/target/cortex_a/__init__.py
+++ b/src/mlia/target/cortex_a/__init__.py
@@ -1,7 +1,17 @@
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Cortex-A target module."""
+from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor
+from mlia.target.cortex_a.config import CortexAConfiguration
from mlia.target.registry import registry
from mlia.target.registry import TargetInfo
-registry.register("cortex-a", TargetInfo(["ArmNNTFLiteDelegate"]))
+registry.register(
+ "cortex-a",
+ TargetInfo(
+ supported_backends=["ArmNNTFLiteDelegate"],
+ default_backends=["ArmNNTFLiteDelegate"],
+ advisor_factory_func=configure_and_get_cortexa_advisor,
+ target_profile_cls=CortexAConfiguration,
+ ),
+)
diff --git a/src/mlia/target/cortex_a/advisor.py b/src/mlia/target/cortex_a/advisor.py
index 518c9f1..a093784 100644
--- a/src/mlia/target/cortex_a/advisor.py
+++ b/src/mlia/target/cortex_a/advisor.py
@@ -5,6 +5,7 @@ from __future__ import annotations
from pathlib import Path
from typing import Any
+from typing import cast
from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
@@ -21,6 +22,7 @@ from mlia.target.cortex_a.data_analysis import CortexADataAnalyzer
from mlia.target.cortex_a.data_collection import CortexAOperatorCompatibility
from mlia.target.cortex_a.events import CortexAAdvisorStartedEvent
from mlia.target.cortex_a.handlers import CortexAEventHandler
+from mlia.target.registry import profile
class CortexAInferenceAdvisor(DefaultInferenceAdvisor):
@@ -59,7 +61,7 @@ class CortexAInferenceAdvisor(DefaultInferenceAdvisor):
return [
CortexAAdvisorStartedEvent(
- model, CortexAConfiguration.load_profile(target_profile)
+ model, cast(CortexAConfiguration, profile(target_profile))
),
]
diff --git a/src/mlia/target/ethos_u/__init__.py b/src/mlia/target/ethos_u/__init__.py
index d53be53..6b6777d 100644
--- a/src/mlia/target/ethos_u/__init__.py
+++ b/src/mlia/target/ethos_u/__init__.py
@@ -1,8 +1,23 @@
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Ethos-U target module."""
+from mlia.backend.corstone import CORSTONE_PRIORITY
+from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor
+from mlia.target.ethos_u.config import EthosUConfiguration
+from mlia.target.ethos_u.config import get_default_ethos_u_backends
from mlia.target.registry import registry
from mlia.target.registry import TargetInfo
-registry.register("ethos-u55", TargetInfo(["Vela", "Corstone-300", "Corstone-310"]))
-registry.register("ethos-u65", TargetInfo(["Vela", "Corstone-300", "Corstone-310"]))
+SUPPORTED_BACKENDS_PRIORITY = ["Vela", *CORSTONE_PRIORITY]
+
+
+for ethos_u in ("ethos-u55", "ethos-u65"):
+ registry.register(
+ ethos_u,
+ TargetInfo(
+ supported_backends=SUPPORTED_BACKENDS_PRIORITY,
+ default_backends=get_default_ethos_u_backends(SUPPORTED_BACKENDS_PRIORITY),
+ advisor_factory_func=configure_and_get_ethosu_advisor,
+ target_profile_cls=EthosUConfiguration,
+ ),
+ )
diff --git a/src/mlia/target/ethos_u/advisor.py b/src/mlia/target/ethos_u/advisor.py
index 225fd87..5f23fdd 100644
--- a/src/mlia/target/ethos_u/advisor.py
+++ b/src/mlia/target/ethos_u/advisor.py
@@ -5,6 +5,7 @@ from __future__ import annotations
from pathlib import Path
from typing import Any
+from typing import cast
from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
@@ -25,6 +26,7 @@ from mlia.target.ethos_u.data_collection import EthosUOptimizationPerformance
from mlia.target.ethos_u.data_collection import EthosUPerformance
from mlia.target.ethos_u.events import EthosUAdvisorStartedEvent
from mlia.target.ethos_u.handlers import EthosUEventHandler
+from mlia.target.registry import profile
from mlia.utils.types import is_list_of
@@ -96,7 +98,7 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
def _get_device_cfg(self, context: Context) -> EthosUConfiguration:
"""Get device configuration."""
target_profile = self.get_target_profile(context)
- return EthosUConfiguration.load_profile(target_profile)
+ return cast(EthosUConfiguration, profile(target_profile))
def _get_optimization_settings(self, context: Context) -> list[list[dict]]:
"""Get optimization settings."""
diff --git a/src/mlia/target/ethos_u/config.py b/src/mlia/target/ethos_u/config.py
index eb5691d..d1a2c7a 100644
--- a/src/mlia/target/ethos_u/config.py
+++ b/src/mlia/target/ethos_u/config.py
@@ -6,12 +6,13 @@ from __future__ import annotations
import logging
from typing import Any
+from mlia.backend.corstone import is_corstone_backend
+from mlia.backend.manager import get_available_backends
from mlia.backend.vela.compiler import resolve_compiler_config
from mlia.backend.vela.compiler import VelaCompilerOptions
from mlia.target.config import TargetProfile
from mlia.utils.filesystem import get_vela_config
-
logger = logging.getLogger(__name__)
@@ -67,3 +68,22 @@ class EthosUConfiguration(TargetProfile):
def __repr__(self) -> str:
"""Return string representation."""
return f"<Ethos-U configuration target={self.target}>"
+
+
+def get_default_ethos_u_backends(
+ supported_backends_priority_order: list[str],
+) -> list[str]:
+ """Return default backends for Ethos-U targets."""
+ available_backends = get_available_backends()
+
+ default_backends = []
+ corstone_added = False
+ for backend in supported_backends_priority_order:
+ if backend not in available_backends:
+ continue
+ if is_corstone_backend(backend):
+ if corstone_added:
+ continue # only add one Corstone backend
+ corstone_added = True
+ default_backends.append(backend)
+ return default_backends
diff --git a/src/mlia/target/registry.py b/src/mlia/target/registry.py
index 4870fc8..9fccecb 100644
--- a/src/mlia/target/registry.py
+++ b/src/mlia/target/registry.py
@@ -3,17 +3,78 @@
"""Target module."""
from __future__ import annotations
+from functools import lru_cache
+from pathlib import Path
+from typing import cast
+
from mlia.backend.config import BackendType
from mlia.backend.manager import get_installation_manager
from mlia.backend.registry import registry as backend_registry
from mlia.core.common import AdviceCategory
from mlia.core.reporting import Column
from mlia.core.reporting import Table
+from mlia.target.config import BUILTIN_SUPPORTED_PROFILE_NAMES
+from mlia.target.config import get_builtin_profile_path
+from mlia.target.config import is_builtin_profile
+from mlia.target.config import load_profile
from mlia.target.config import TargetInfo
+from mlia.target.config import TargetProfile
from mlia.utils.registry import Registry
+
+class TargetRegistry(Registry[TargetInfo]):
+ """Registry for targets."""
+
+ def register(self, name: str, item: TargetInfo) -> bool:
+ """Register an item: returns `False` if already registered."""
+ assert all(
+ backend in backend_registry.items for backend in item.supported_backends
+ )
+ return super().register(name, item)
+
+
# All supported targets are required to be registered here.
-registry = Registry[TargetInfo]()
+registry = TargetRegistry()
+
+
+def builtin_profile_names() -> list[str]:
+ """Return a list of built-in profile names (not file paths)."""
+ return BUILTIN_SUPPORTED_PROFILE_NAMES
+
+
+@lru_cache
+def profile(target_profile: str | Path) -> TargetProfile:
+ """Get the target profile data (built-in or custom file)."""
+ if not target_profile:
+ raise ValueError("No valid target profile was provided.")
+ if is_builtin_profile(target_profile):
+ profile_file = get_builtin_profile_path(cast(str, target_profile))
+ profile_ = create_target_profile(profile_file)
+ else:
+ profile_file = Path(target_profile)
+ if profile_file.is_file():
+ profile_ = create_target_profile(profile_file)
+ else:
+ raise ValueError(
+ f"Profile '{target_profile}' is neither a valid built-in "
+ "target profile name or a valid file path."
+ )
+
+ return profile_
+
+
+def get_target(target_profile: str | Path) -> str:
+ """Return target for the provided target_profile."""
+ return profile(target_profile).target
+
+
+@lru_cache
+def create_target_profile(path: Path) -> TargetProfile:
+ """Create a new instance of a TargetProfile from the file."""
+ profile_data = load_profile(path)
+ target = profile_data["target"]
+ target_info = registry.items[target]
+ return target_info.target_profile_cls.load_json_data(profile_data)
def supported_advice(target: str) -> list[AdviceCategory]:
@@ -29,6 +90,11 @@ def supported_backends(target: str) -> list[str]:
return registry.items[target].filter_supported_backends(check_system=False)
+def default_backends(target: str) -> list[str]:
+ """Get a list of default backends for the given target."""
+ return registry.items[target].default_backends
+
+
def get_backend_to_supported_targets() -> dict[str, list]:
"""Get a dict that maps a list of supported targets given backend."""
targets = dict(registry.items)
diff --git a/src/mlia/target/tosa/__init__.py b/src/mlia/target/tosa/__init__.py
index 06bf1a9..3830ce5 100644
--- a/src/mlia/target/tosa/__init__.py
+++ b/src/mlia/target/tosa/__init__.py
@@ -3,5 +3,15 @@
"""TOSA target module."""
from mlia.target.registry import registry
from mlia.target.registry import TargetInfo
+from mlia.target.tosa.advisor import configure_and_get_tosa_advisor
+from mlia.target.tosa.config import TOSAConfiguration
-registry.register("tosa", TargetInfo(["tosa-checker"]))
+registry.register(
+ "tosa",
+ TargetInfo(
+ supported_backends=["tosa-checker"],
+ default_backends=["tosa-checker"],
+ advisor_factory_func=configure_and_get_tosa_advisor,
+ target_profile_cls=TOSAConfiguration,
+ ),
+)
diff --git a/src/mlia/target/tosa/advisor.py b/src/mlia/target/tosa/advisor.py
index 5588d0f..5fb18ed 100644
--- a/src/mlia/target/tosa/advisor.py
+++ b/src/mlia/target/tosa/advisor.py
@@ -5,6 +5,7 @@ from __future__ import annotations
from pathlib import Path
from typing import Any
+from typing import cast
from mlia.core.advice_generation import AdviceCategory
from mlia.core.advice_generation import AdviceProducer
@@ -17,6 +18,7 @@ from mlia.core.data_collection import DataCollector
from mlia.core.events import Event
from mlia.core.metadata import MLIAMetadata
from mlia.core.metadata import ModelMetadata
+from mlia.target.registry import profile
from mlia.target.tosa.advice_generation import TOSAAdviceProducer
from mlia.target.tosa.config import TOSAConfiguration
from mlia.target.tosa.data_analysis import TOSADataAnalyzer
@@ -66,7 +68,7 @@ class TOSAInferenceAdvisor(DefaultInferenceAdvisor):
return [
TOSAAdvisorStartedEvent(
model,
- TOSAConfiguration.load_profile(target_profile),
+ cast(TOSAConfiguration, profile(target_profile)),
MetadataDisplay(
TOSAMetadata("tosa-checker"),
MLIAMetadata("mlia"),
diff --git a/tests/conftest.py b/tests/conftest.py
index d797869..55b296f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -83,8 +83,11 @@ def fixture_test_models_path(
save_tflite_model(tflite_model, tflite_model_path)
tflite_vela_model = tmp_path / "test_model_vela.tflite"
- device = EthosUConfiguration.load_profile("ethos-u55-256")
- optimize_model(tflite_model_path, device.compiler_options, tflite_vela_model)
+
+ target_profile = EthosUConfiguration.load_profile("ethos-u55-256")
+ optimize_model(
+ tflite_model_path, target_profile.compiler_options, tflite_vela_model
+ )
tf.saved_model.save(keras_model, str(tmp_path / "tf_model_test_model"))
diff --git a/tests/test_api.py b/tests/test_api.py
index 251d5ac..b40c55b 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -19,12 +19,8 @@ from mlia.target.tosa.advisor import TOSAInferenceAdvisor
def test_get_advice_no_target_provided(test_keras_model: Path) -> None:
"""Test getting advice when no target provided."""
- with pytest.raises(Exception, match="Target profile is not provided"):
- get_advice(
- None, # type:ignore
- test_keras_model,
- {"compatibility"},
- )
+ with pytest.raises(Exception, match="No valid target profile was provided."):
+ get_advice(None, test_keras_model, {"compatibility"}) # type: ignore
def test_get_advice_wrong_category(test_keras_model: Path) -> None:
diff --git a/tests/test_cli_command_validators.py b/tests/test_cli_command_validators.py
index cd048ee..29813f4 100644
--- a/tests/test_cli_command_validators.py
+++ b/tests/test_cli_command_validators.py
@@ -5,7 +5,6 @@ from __future__ import annotations
import argparse
from contextlib import ExitStack
-from unittest.mock import MagicMock
import pytest
@@ -93,72 +92,66 @@ def test_validate_check_target_profile(
@pytest.mark.parametrize(
- "input_target_profile, input_backends, throws_exception,"
- "exception_message, output_backends",
+ (
+ "input_target_profile",
+ "input_backends",
+ "throws_exception",
+ "exception_message",
+ "output_backends",
+ ),
[
[
"tosa",
- ["Vela"],
- True,
- "Vela backend not supported with target-profile tosa.",
+ ["tosa-checker"],
+ False,
None,
+ ["tosa-checker"],
],
[
"tosa",
- ["Corstone-300, Vela"],
+ ["Corstone-310"],
True,
- "Corstone-300, Vela backend not supported with target-profile tosa.",
+ "Backend Corstone-310 not supported with target-profile tosa.",
None,
],
[
"cortex-a",
- ["Corstone-310", "tosa-checker"],
- True,
- "Corstone-310, tosa-checker backend not supported "
- "with target-profile cortex-a.",
+ ["ArmNNTFLiteDelegate"],
+ False,
None,
+ ["ArmNNTFLiteDelegate"],
],
[
- "ethos-u55-256",
- ["tosa-checker", "Corstone-310"],
+ "cortex-a",
+ ["tosa-checker"],
True,
- "tosa-checker backend not supported with target-profile ethos-u55-256.",
+ "Backend tosa-checker not supported with target-profile cortex-a.",
None,
],
- ["tosa", None, False, None, ["tosa-checker"]],
- ["cortex-a", None, False, None, ["ArmNNTFLiteDelegate"]],
- ["tosa", ["tosa-checker"], False, None, ["tosa-checker"]],
- ["cortex-a", ["ArmNNTFLiteDelegate"], False, None, ["ArmNNTFLiteDelegate"]],
[
"ethos-u55-256",
- ["Vela", "Corstone-300"],
+ ["Vela", "Corstone-310"],
False,
None,
- ["Vela", "Corstone-300"],
+ ["Vela", "Corstone-310"],
],
[
- "ethos-u55-256",
- None,
- False,
+ "ethos-u65-256",
+ ["Vela", "Corstone-310", "tosa-checker"],
+ True,
+ "Backend tosa-checker not supported with target-profile ethos-u65-256.",
None,
- ["Vela", "Corstone-300"],
],
],
)
def test_validate_backend(
- monkeypatch: pytest.MonkeyPatch,
input_target_profile: str,
- input_backends: list[str] | None,
+ input_backends: list[str],
throws_exception: bool,
exception_message: str,
output_backends: list[str] | None,
) -> None:
"""Test backend validation with target-profiles and backends."""
- monkeypatch.setattr(
- "mlia.cli.config.get_available_backends",
- MagicMock(return_value=["Vela", "Corstone-300"]),
- )
-
exit_stack = ExitStack()
if throws_exception:
exit_stack.enter_context(
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py
index 61cc5a6..f3213c4 100644
--- a/tests/test_cli_commands.py
+++ b/tests/test_cli_commands.py
@@ -33,7 +33,13 @@ def test_performance_unknown_target(
sample_context: ExecutionContext, test_tflite_model: Path
) -> None:
"""Test that command should fail if unknown target passed."""
- with pytest.raises(Exception, match=r"File not found:*"):
+ with pytest.raises(
+ Exception,
+ match=(
+ r"Profile 'unknown' is neither a valid built-in target profile "
+ r"name or a valid file path."
+ ),
+ ):
check(
sample_context,
model=str(test_tflite_model),
diff --git a/tests/test_cli_config.py b/tests/test_cli_config.py
deleted file mode 100644
index 8494d73..0000000
--- a/tests/test_cli_config.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Tests for cli.config module."""
-from __future__ import annotations
-
-from unittest.mock import MagicMock
-
-import pytest
-
-from mlia.cli.config import get_default_backends
-
-
-@pytest.mark.parametrize(
- "available_backends, expected_default_backends",
- [
- [["Vela"], ["Vela"]],
- [["Corstone-300"], ["Corstone-300"]],
- [["Corstone-310"], ["Corstone-310"]],
- [["Corstone-300", "Corstone-310"], ["Corstone-310"]],
- [["Vela", "Corstone-300", "Corstone-310"], ["Vela", "Corstone-310"]],
- [
- ["Vela", "Corstone-300", "Corstone-310", "New backend"],
- ["Vela", "Corstone-310", "New backend"],
- ],
- [
- ["Vela", "Corstone-300", "New backend"],
- ["Vela", "Corstone-300", "New backend"],
- ],
- [["ArmNNTFLiteDelegate"], ["ArmNNTFLiteDelegate"]],
- [["tosa-checker"], ["tosa-checker"]],
- [
- ["ArmNNTFLiteDelegate", "Corstone-300"],
- ["ArmNNTFLiteDelegate", "Corstone-300"],
- ],
- ],
-)
-def test_get_default_backends(
- monkeypatch: pytest.MonkeyPatch,
- available_backends: list[str],
- expected_default_backends: list[str],
-) -> None:
- """Test function get_default backends."""
- monkeypatch.setattr(
- "mlia.cli.config.get_available_backends",
- MagicMock(return_value=available_backends),
- )
-
- assert get_default_backends() == expected_default_backends
diff --git a/tests/test_cli_helpers.py b/tests/test_cli_helpers.py
index 8f7e4b0..6d19207 100644
--- a/tests/test_cli_helpers.py
+++ b/tests/test_cli_helpers.py
@@ -3,11 +3,13 @@
"""Tests for the helper classes."""
from __future__ import annotations
+from pathlib import Path
from typing import Any
import pytest
from mlia.cli.helpers import CLIActionResolver
+from mlia.cli.helpers import copy_profile_file_to_output_dir
from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
@@ -139,3 +141,12 @@ class TestCliActionResolver:
"""Test checking operator compatibility info."""
resolver = CLIActionResolver(args)
assert resolver.check_operator_compatibility() == expected_result
+
+
+def test_copy_profile_file_to_output_dir(tmp_path: Path) -> None:
+ """Test if the profile file is copied into the output directory."""
+ test_target_profile_name = "ethos-u55-128"
+ test_file_path = Path(f"{tmp_path}/{test_target_profile_name}.toml")
+
+ copy_profile_file_to_output_dir(test_target_profile_name, tmp_path)
+ assert Path.is_file(test_file_path)
diff --git a/tests/test_target_config.py b/tests/test_target_config.py
index c6235a5..368d394 100644
--- a/tests/test_target_config.py
+++ b/tests/test_target_config.py
@@ -3,35 +3,28 @@
"""Tests for the backend config module."""
from __future__ import annotations
-from pathlib import Path
-
import pytest
from mlia.backend.config import BackendConfiguration
from mlia.backend.config import BackendType
from mlia.backend.config import System
from mlia.core.common import AdviceCategory
-from mlia.target.config import copy_profile_file_to_output_dir
+from mlia.target.config import BUILTIN_SUPPORTED_PROFILE_NAMES
+from mlia.target.config import get_builtin_profile_path
from mlia.target.config import get_builtin_supported_profile_names
-from mlia.target.config import get_profile_file
+from mlia.target.config import is_builtin_profile
from mlia.target.config import load_profile
from mlia.target.config import TargetInfo
from mlia.target.config import TargetProfile
+from mlia.target.cortex_a.advisor import CortexAInferenceAdvisor
+from mlia.target.cortex_a.config import CortexAConfiguration
from mlia.utils.registry import Registry
-def test_copy_profile_file_to_output_dir(tmp_path: Path) -> None:
- """Test if the profile file is copied into the output directory."""
- test_target_profile_name = "ethos-u55-128"
- test_file_path = Path(f"{tmp_path}/{test_target_profile_name}.toml")
-
- copy_profile_file_to_output_dir(test_target_profile_name, tmp_path)
- assert Path.is_file(test_file_path)
-
-
-def test_get_builtin_supported_profile_names() -> None:
- """Test profile names getter."""
- assert get_builtin_supported_profile_names() == [
+def test_builtin_supported_profile_names() -> None:
+ """Test built-in profile names."""
+ assert BUILTIN_SUPPORTED_PROFILE_NAMES == get_builtin_supported_profile_names()
+ assert BUILTIN_SUPPORTED_PROFILE_NAMES == [
"cortex-a",
"ethos-u55-128",
"ethos-u55-256",
@@ -39,23 +32,24 @@ def test_get_builtin_supported_profile_names() -> None:
"ethos-u65-512",
"tosa",
]
+ for profile_name in BUILTIN_SUPPORTED_PROFILE_NAMES:
+ assert is_builtin_profile(profile_name)
+ profile_file = get_builtin_profile_path(profile_name)
+ assert profile_file.is_file()
-def test_get_profile_file() -> None:
- """Test function 'get_profile_file'."""
- profile_file = get_profile_file("cortex-a")
+def test_builtin_profile_files() -> None:
+ """Test function 'get_bulitin_profile_file'."""
+ profile_file = get_builtin_profile_path("cortex-a")
assert profile_file.is_file()
- assert profile_file == get_profile_file(profile_file)
- with pytest.raises(Exception):
- get_profile_file("UNKNOWN")
- with pytest.raises(Exception):
- get_profile_file("")
+ profile_file = get_builtin_profile_path("UNKNOWN_FILE_THAT_DOES_NOT_EXIST")
+ assert not profile_file.exists()
def test_load_profile() -> None:
"""Test getting profile data."""
- profile_file = get_profile_file("ethos-u55-256")
+ profile_file = get_builtin_profile_path("ethos-u55-256")
assert load_profile(profile_file) == {
"target": "ethos-u55",
"mac": 256,
@@ -80,6 +74,9 @@ def test_target_profile() -> None:
profile = MyTargetProfile("AnyTarget")
assert profile.target == "AnyTarget"
+ profile = MyTargetProfile.load_json_data({"target": "MySuperTarget"})
+ assert profile.target == "MySuperTarget"
+
profile = MyTargetProfile("")
with pytest.raises(ValueError):
profile.verify()
@@ -101,7 +98,12 @@ def test_target_info(
supported: bool,
) -> None:
"""Test the class 'TargetInfo'."""
- info = TargetInfo(["backend"])
+ info = TargetInfo(
+ ["backend"],
+ ["backend"],
+ CortexAInferenceAdvisor,
+ CortexAConfiguration,
+ )
backend_registry = Registry[BackendConfiguration]()
backend_registry.register(
@@ -116,3 +118,12 @@ def test_target_info(
assert info.is_supported(advice, check_system) == supported
assert bool(info.filter_supported_backends(advice, check_system)) == supported
+
+ info = TargetInfo(
+ ["unknown_backend"],
+ ["unknown_backend"],
+ CortexAInferenceAdvisor,
+ CortexAConfiguration,
+ )
+ assert not info.is_supported(advice, check_system)
+ assert not info.filter_supported_backends(advice, check_system)
diff --git a/tests/test_target_ethos_u_data_analysis.py b/tests/test_target_ethos_u_data_analysis.py
index e919f5d..8e63946 100644
--- a/tests/test_target_ethos_u_data_analysis.py
+++ b/tests/test_target_ethos_u_data_analysis.py
@@ -3,6 +3,8 @@
"""Tests for Ethos-U data analysis module."""
from __future__ import annotations
+from typing import cast
+
import pytest
from mlia.backend.vela.compat import NpuSupported
@@ -23,6 +25,7 @@ from mlia.target.ethos_u.performance import MemoryUsage
from mlia.target.ethos_u.performance import NPUCycles
from mlia.target.ethos_u.performance import OptimizationPerformanceMetrics
from mlia.target.ethos_u.performance import PerformanceMetrics
+from mlia.target.registry import profile
def test_perf_metrics_diff() -> None:
@@ -84,7 +87,7 @@ def test_perf_metrics_diff() -> None:
[
OptimizationPerformanceMetrics(
PerformanceMetrics(
- EthosUConfiguration.load_profile("ethos-u55-256"),
+ cast(EthosUConfiguration, profile("ethos-u55-256")),
NPUCycles(1, 2, 3, 4, 5, 6),
# memory metrics are in kilobytes
MemoryUsage(*[i * 1024 for i in range(1, 6)]), # type: ignore
@@ -95,7 +98,7 @@ def test_perf_metrics_diff() -> None:
OptimizationSettings("pruning", 0.5, None),
],
PerformanceMetrics(
- EthosUConfiguration.load_profile("ethos-u55-256"),
+ cast(EthosUConfiguration, profile("ethos-u55-256")),
NPUCycles(1, 2, 3, 4, 5, 6),
# memory metrics are in kilobytes
MemoryUsage(
@@ -127,7 +130,7 @@ def test_perf_metrics_diff() -> None:
[
OptimizationPerformanceMetrics(
PerformanceMetrics(
- EthosUConfiguration.load_profile("ethos-u55-256"),
+ cast(EthosUConfiguration, profile("ethos-u55-256")),
NPUCycles(1, 2, 3, 4, 5, 6),
# memory metrics are in kilobytes
MemoryUsage(*[i * 1024 for i in range(1, 6)]), # type: ignore
diff --git a/tests/test_target_ethos_u_reporters.py b/tests/test_target_ethos_u_reporters.py
index 0c5764e..9707dff 100644
--- a/tests/test_target_ethos_u_reporters.py
+++ b/tests/test_target_ethos_u_reporters.py
@@ -3,6 +3,8 @@
"""Tests for reports module."""
from __future__ import annotations
+from typing import cast
+
import pytest
from mlia.backend.vela.compat import NpuSupported
@@ -12,6 +14,7 @@ from mlia.core.reporting import Table
from mlia.target.ethos_u.config import EthosUConfiguration
from mlia.target.ethos_u.reporters import report_device_details
from mlia.target.ethos_u.reporters import report_operators
+from mlia.target.registry import profile
from mlia.utils.console import remove_ascii_codes
@@ -118,7 +121,7 @@ def test_report_operators(
"device, expected_plain_text, expected_json_dict",
[
[
- EthosUConfiguration.load_profile("ethos-u55-256"),
+ cast(EthosUConfiguration, profile("ethos-u55-256")),
"""Device information:
Target ethos-u55
MAC 256
diff --git a/tests/test_target_registry.py b/tests/test_target_registry.py
index 5012148..2cbd97d 100644
--- a/tests/test_target_registry.py
+++ b/tests/test_target_registry.py
@@ -6,7 +6,11 @@ from __future__ import annotations
import pytest
from mlia.core.common import AdviceCategory
+from mlia.target.config import get_builtin_profile_path
from mlia.target.registry import all_supported_backends
+from mlia.target.registry import default_backends
+from mlia.target.registry import is_supported
+from mlia.target.registry import profile
from mlia.target.registry import registry
from mlia.target.registry import supported_advice
from mlia.target.registry import supported_backends
@@ -57,6 +61,23 @@ def test_supported_advice(
@pytest.mark.parametrize(
+ ("backend", "target", "expected_result"),
+ (
+ ("ArmNNTFLiteDelegate", None, True),
+ ("ArmNNTFLiteDelegate", "cortex-a", True),
+ ("ArmNNTFLiteDelegate", "tosa", False),
+ ("Corstone-310", None, True),
+ ("Corstone-310", "ethos-u55", True),
+ ("Corstone-310", "ethos-u65", True),
+ ("Corstone-310", "cortex-a", False),
+ ),
+)
+def test_is_supported(backend: str, target: str | None, expected_result: bool) -> None:
+ """Test function is_supported()."""
+ assert is_supported(backend, target) == expected_result
+
+
+@pytest.mark.parametrize(
("target_name", "expected_backends"),
(
("cortex-a", ["ArmNNTFLiteDelegate"]),
@@ -92,3 +113,39 @@ def test_all_supported_backends() -> None:
"Corstone-310",
"Corstone-300",
}
+
+
+@pytest.mark.parametrize(
+ ("target", "expected_default_backends", "is_subset_only"),
+ [
+ ["cortex-a", ["ArmNNTFLiteDelegate"], False],
+ ["tosa", ["tosa-checker"], False],
+ ["ethos-u55", ["Vela"], True],
+ ["ethos-u65", ["Vela"], True],
+ ],
+)
+def test_default_backends(
+ target: str,
+ expected_default_backends: list[str],
+ is_subset_only: bool,
+) -> None:
+ """Test function default_backends()."""
+ if is_subset_only:
+ assert set(expected_default_backends).issubset(default_backends(target))
+ else:
+ assert default_backends(target) == expected_default_backends
+
+
+@pytest.mark.parametrize(
+ "target_profile", ("cortex-a", "tosa", "ethos-u55-128", "ethos-u65-256")
+)
+def test_profile(target_profile: str) -> None:
+ """Test function profile()."""
+ # Test loading by built-in profile name
+ cfg = profile(target_profile)
+ assert target_profile.startswith(cfg.target)
+
+ # Test loading the file directly
+ profile_file = get_builtin_profile_path(target_profile)
+ cfg = profile(profile_file)
+ assert target_profile.startswith(cfg.target)
diff --git a/tests_e2e/test_e2e.py b/tests_e2e/test_e2e.py
index d09d0ab..37a1833 100644
--- a/tests_e2e/test_e2e.py
+++ b/tests_e2e/test_e2e.py
@@ -20,7 +20,7 @@ from typing import Sequence
import pytest
-from mlia.cli.config import get_available_backends
+from mlia.backend.manager import get_available_backends
from mlia.cli.main import get_commands
from mlia.cli.main import get_possible_command_names
from mlia.cli.main import init_parser