aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/cli
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/cli')
-rw-r--r--src/mlia/cli/command_validators.py20
-rw-r--r--src/mlia/cli/commands.py2
-rw-r--r--src/mlia/cli/config.py69
-rw-r--r--src/mlia/cli/helpers.py24
-rw-r--r--src/mlia/cli/main.py15
-rw-r--r--src/mlia/cli/options.py43
6 files changed, 76 insertions, 97 deletions
diff --git a/src/mlia/cli/command_validators.py b/src/mlia/cli/command_validators.py
index 23101e0..a0f5433 100644
--- a/src/mlia/cli/command_validators.py
+++ b/src/mlia/cli/command_validators.py
@@ -7,8 +7,8 @@ import argparse
import logging
import sys
-from mlia.cli.config import get_default_backends_dict
-from mlia.target.config import get_target
+from mlia.target.registry import default_backends
+from mlia.target.registry import get_target
from mlia.target.registry import supported_backends
logger = logging.getLogger(__name__)
@@ -26,22 +26,18 @@ def validate_backend(
target = get_target(target_profile)
if not backend:
- return get_default_backends_dict()[target]
+ return default_backends(target)
- compatible_backends = supported_backends(target)
+ compatible_backends = list(map(normalize_string, supported_backends(target)))
+ backends = {normalize_string(b): b for b in backend}
- nor_backend = list(map(normalize_string, backend))
- nor_compat_backend = list(map(normalize_string, compatible_backends))
-
- incompatible_backends = [
- backend[i] for i, x in enumerate(nor_backend) if x not in nor_compat_backend
- ]
+ incompatible_backends = [b for b in backends if b not in compatible_backends]
# Throw an error if any unsupported backends are used
if incompatible_backends:
raise argparse.ArgumentError(
None,
- f"{', '.join(incompatible_backends)} backend not supported "
- f"with target-profile {target_profile}.",
+ f"Backend {', '.join(backends[b] for b in incompatible_backends)} "
+ f"not supported with target-profile {target_profile}.",
)
return backend
diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py
index c17d571..27f5b2b 100644
--- a/src/mlia/cli/commands.py
+++ b/src/mlia/cli/commands.py
@@ -23,9 +23,9 @@ from pathlib import Path
from mlia.api import ExecutionContext
from mlia.api import get_advice
+from mlia.backend.manager import get_installation_manager
from mlia.cli.command_validators import validate_backend
from mlia.cli.command_validators import validate_check_target_profile
-from mlia.cli.config import get_installation_manager
from mlia.cli.options import parse_optimization_parameters
from mlia.utils.console import create_section_header
diff --git a/src/mlia/cli/config.py b/src/mlia/cli/config.py
deleted file mode 100644
index 433300c..0000000
--- a/src/mlia/cli/config.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Environment configuration functions."""
-from __future__ import annotations
-
-import logging
-
-from mlia.backend.manager import get_installation_manager
-from mlia.target.registry import all_supported_backends
-
-logger = logging.getLogger(__name__)
-
-DEFAULT_PRUNING_TARGET = 0.5
-DEFAULT_CLUSTERING_TARGET = 32
-
-
-def get_available_backends() -> list[str]:
- """Return list of the available backends."""
- available_backends = ["Vela", "ArmNNTFLiteDelegate"]
-
- # Add backends using backend manager
- manager = get_installation_manager()
- available_backends.extend(
- backend
- for backend in all_supported_backends()
- if manager.backend_installed(backend)
- )
-
- return available_backends
-
-
-# List of mutually exclusive Corstone backends ordered by priority
-_CORSTONE_EXCLUSIVE_PRIORITY = ("Corstone-310", "Corstone-300")
-_NON_ETHOS_U_BACKENDS = ("tosa-checker", "ArmNNTFLiteDelegate")
-
-
-def get_ethos_u_default_backends(backends: list[str]) -> list[str]:
- """Get Ethos-U default backends for evaluation."""
- return [x for x in backends if x not in _NON_ETHOS_U_BACKENDS]
-
-
-def get_default_backends() -> list[str]:
- """Get default backends for evaluation."""
- backends = get_available_backends()
-
- # Filter backends to only include one Corstone backend
- for corstone in _CORSTONE_EXCLUSIVE_PRIORITY:
- if corstone in backends:
- backends = [
- backend
- for backend in backends
- if backend == corstone or backend not in _CORSTONE_EXCLUSIVE_PRIORITY
- ]
- break
-
- return backends
-
-
-def get_default_backends_dict() -> dict[str, list[str]]:
- """Return default backends for all targets."""
- default_backends = get_default_backends()
- ethos_u_defaults = get_ethos_u_default_backends(default_backends)
-
- return {
- "ethos-u55": ethos_u_defaults,
- "ethos-u65": ethos_u_defaults,
- "tosa": ["tosa-checker"],
- "cortex-a": ["ArmNNTFLiteDelegate"],
- }
diff --git a/src/mlia/cli/helpers.py b/src/mlia/cli/helpers.py
index ac64581..576670b 100644
--- a/src/mlia/cli/helpers.py
+++ b/src/mlia/cli/helpers.py
@@ -1,14 +1,19 @@
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
-"""Module for various helper classes."""
+"""Module for various helpers."""
from __future__ import annotations
+from pathlib import Path
+from shutil import copy
from typing import Any
+from typing import cast
from mlia.cli.options import get_target_profile_opts
from mlia.core.helpers import ActionResolver
from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
from mlia.nn.tensorflow.utils import is_keras_model
+from mlia.target.config import get_builtin_profile_path
+from mlia.target.config import is_builtin_profile
from mlia.utils.types import is_list_of
@@ -108,3 +113,20 @@ class CLIActionResolver(ActionResolver):
model_path = self.args.get("model")
return model_path, device_opts
+
+
+def copy_profile_file_to_output_dir(
+ target_profile: str | Path, output_dir: str | Path
+) -> bool:
+ """Copy the target profile file to the output directory."""
+ profile_file_path = (
+ get_builtin_profile_path(cast(str, target_profile))
+ if is_builtin_profile(target_profile)
+ else Path(target_profile)
+ )
+ output_file_path = f"{output_dir}/{profile_file_path.stem}.toml"
+ try:
+ copy(profile_file_path, output_file_path)
+ return True
+ except OSError as err:
+ raise RuntimeError("Failed to copy profile file:", err.strerror) from err
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py
index 793e155..b3a9d4c 100644
--- a/src/mlia/cli/main.py
+++ b/src/mlia/cli/main.py
@@ -18,6 +18,7 @@ from mlia.cli.commands import check
from mlia.cli.commands import optimize
from mlia.cli.common import CommandInfo
from mlia.cli.helpers import CLIActionResolver
+from mlia.cli.helpers import copy_profile_file_to_output_dir
from mlia.cli.options import add_backend_install_options
from mlia.cli.options import add_backend_options
from mlia.cli.options import add_backend_uninstall_options
@@ -30,11 +31,11 @@ from mlia.cli.options import add_output_directory
from mlia.cli.options import add_output_options
from mlia.cli.options import add_target_options
from mlia.cli.options import get_output_format
+from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext
from mlia.core.errors import ConfigurationError
from mlia.core.errors import InternalError
from mlia.core.logging import setup_logging
-from mlia.target.config import copy_profile_file_to_output_dir
from mlia.target.registry import table as target_table
@@ -59,7 +60,13 @@ def get_commands() -> list[CommandInfo]:
[
add_output_directory,
add_model_options,
- add_target_options,
+ partial(
+ add_target_options,
+ supported_advice=[
+ AdviceCategory.COMPATIBILITY,
+ AdviceCategory.PERFORMANCE,
+ ],
+ ),
add_backend_options,
add_check_category_options,
add_output_options,
@@ -72,7 +79,9 @@ def get_commands() -> list[CommandInfo]:
[
add_output_directory,
add_keras_model_options,
- partial(add_target_options, profiles_to_skip=["tosa", "cortex-a"]),
+ partial(
+ add_target_options, supported_advice=[AdviceCategory.OPTIMIZATION]
+ ),
partial(
add_backend_options,
backends_to_skip=["tosa-checker", "ArmNNTFLiteDelegate"],
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py
index 421533a..8cd2935 100644
--- a/src/mlia/cli/options.py
+++ b/src/mlia/cli/options.py
@@ -7,13 +7,17 @@ import argparse
from pathlib import Path
from typing import Any
from typing import Callable
+from typing import Sequence
from mlia.backend.corstone import is_corstone_backend
-from mlia.cli.config import DEFAULT_CLUSTERING_TARGET
-from mlia.cli.config import DEFAULT_PRUNING_TARGET
-from mlia.cli.config import get_available_backends
+from mlia.backend.manager import get_available_backends
+from mlia.core.common import AdviceCategory
from mlia.core.typing import OutputFormat
-from mlia.target.config import get_builtin_supported_profile_names
+from mlia.target.registry import builtin_profile_names
+from mlia.target.registry import registry as target_registry
+
+DEFAULT_PRUNING_TARGET = 0.5
+DEFAULT_CLUSTERING_TARGET = 32
def add_check_category_options(parser: argparse.ArgumentParser) -> None:
@@ -31,22 +35,39 @@ def add_check_category_options(parser: argparse.ArgumentParser) -> None:
def add_target_options(
parser: argparse.ArgumentParser,
- profiles_to_skip: list[str] | None = None,
+ supported_advice: Sequence[AdviceCategory] | None = None,
required: bool = True,
) -> None:
"""Add target specific options."""
- target_profiles = get_builtin_supported_profile_names()
- if profiles_to_skip:
- target_profiles = [tp for tp in target_profiles if tp not in profiles_to_skip]
-
- default_target_profile = "ethos-u55-256"
+ target_profiles = builtin_profile_names()
+
+ if supported_advice:
+
+ def is_advice_supported(profile: str, advice: Sequence[AdviceCategory]) -> bool:
+ """
+ Collect all target profiles that support the advice.
+
+ This means target profiles that...
+ - have the right target prefix, e.g. "ethos-u55..." to avoid loading
+ all target profiles
+ - support any of the required advice
+ """
+ for target, info in target_registry.items.items():
+ if profile.startswith(target):
+ return any(info.is_supported(adv) for adv in advice)
+ return False
+
+ target_profiles = [
+ profile
+ for profile in target_profiles
+ if is_advice_supported(profile, supported_advice)
+ ]
target_group = parser.add_argument_group("target options")
target_group.add_argument(
"-t",
"--target-profile",
required=required,
- default=default_target_profile,
help="Built-in target profile or path to the custom target profile. "
f"Built-in target profiles are {', '.join(target_profiles)}. "
"Target profile that will set the target options "