aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/cli/command_validators.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/cli/command_validators.py')
-rw-r--r--src/mlia/cli/command_validators.py20
1 files changed, 8 insertions, 12 deletions
diff --git a/src/mlia/cli/command_validators.py b/src/mlia/cli/command_validators.py
index 23101e0..a0f5433 100644
--- a/src/mlia/cli/command_validators.py
+++ b/src/mlia/cli/command_validators.py
@@ -7,8 +7,8 @@ import argparse
import logging
import sys
-from mlia.cli.config import get_default_backends_dict
-from mlia.target.config import get_target
+from mlia.target.registry import default_backends
+from mlia.target.registry import get_target
from mlia.target.registry import supported_backends
logger = logging.getLogger(__name__)
@@ -26,22 +26,18 @@ def validate_backend(
target = get_target(target_profile)
if not backend:
- return get_default_backends_dict()[target]
+ return default_backends(target)
- compatible_backends = supported_backends(target)
+ compatible_backends = list(map(normalize_string, supported_backends(target)))
+ backends = {normalize_string(b): b for b in backend}
- nor_backend = list(map(normalize_string, backend))
- nor_compat_backend = list(map(normalize_string, compatible_backends))
-
- incompatible_backends = [
- backend[i] for i, x in enumerate(nor_backend) if x not in nor_compat_backend
- ]
+ incompatible_backends = [b for b in backends if b not in compatible_backends]
# Throw an error if any unsupported backends are used
if incompatible_backends:
raise argparse.ArgumentError(
None,
- f"{', '.join(incompatible_backends)} backend not supported "
- f"with target-profile {target_profile}.",
+ f"Backend {', '.join(backends[b] for b in incompatible_backends)} "
+ f"not supported with target-profile {target_profile}.",
)
return backend