aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/cli/options.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/cli/options.py')
-rw-r--r--src/mlia/cli/options.py43
1 files changed, 32 insertions, 11 deletions
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py
index 421533a..8cd2935 100644
--- a/src/mlia/cli/options.py
+++ b/src/mlia/cli/options.py
@@ -7,13 +7,17 @@ import argparse
from pathlib import Path
from typing import Any
from typing import Callable
+from typing import Sequence
from mlia.backend.corstone import is_corstone_backend
-from mlia.cli.config import DEFAULT_CLUSTERING_TARGET
-from mlia.cli.config import DEFAULT_PRUNING_TARGET
-from mlia.cli.config import get_available_backends
+from mlia.backend.manager import get_available_backends
+from mlia.core.common import AdviceCategory
from mlia.core.typing import OutputFormat
-from mlia.target.config import get_builtin_supported_profile_names
+from mlia.target.registry import builtin_profile_names
+from mlia.target.registry import registry as target_registry
+
+DEFAULT_PRUNING_TARGET = 0.5
+DEFAULT_CLUSTERING_TARGET = 32
def add_check_category_options(parser: argparse.ArgumentParser) -> None:
@@ -31,22 +35,39 @@ def add_check_category_options(parser: argparse.ArgumentParser) -> None:
def add_target_options(
parser: argparse.ArgumentParser,
- profiles_to_skip: list[str] | None = None,
+ supported_advice: Sequence[AdviceCategory] | None = None,
required: bool = True,
) -> None:
"""Add target specific options."""
- target_profiles = get_builtin_supported_profile_names()
- if profiles_to_skip:
- target_profiles = [tp for tp in target_profiles if tp not in profiles_to_skip]
-
- default_target_profile = "ethos-u55-256"
+ target_profiles = builtin_profile_names()
+
+ if supported_advice:
+
+ def is_advice_supported(profile: str, advice: Sequence[AdviceCategory]) -> bool:
+ """
+ Collect all target profiles that support the advice.
+
+ This means target profiles that...
+ - have the right target prefix, e.g. "ethos-u55..." to avoid loading
+ all target profiles
+ - support any of the required advice
+ """
+ for target, info in target_registry.items.items():
+ if profile.startswith(target):
+ return any(info.is_supported(adv) for adv in advice)
+ return False
+
+ target_profiles = [
+ profile
+ for profile in target_profiles
+ if is_advice_supported(profile, supported_advice)
+ ]
target_group = parser.add_argument_group("target options")
target_group.add_argument(
"-t",
"--target-profile",
required=required,
- default=default_target_profile,
help="Built-in target profile or path to the custom target profile. "
f"Built-in target profiles are {', '.join(target_profiles)}. "
"Target profile that will set the target options "