aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/cli/command_validators.py
blob: a0f5433d3523b71dcd14c6be38ae83583509b28a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""CLI command validators module."""
from __future__ import annotations

import argparse
import logging
import sys

from mlia.target.registry import default_backends
from mlia.target.registry import get_target
from mlia.target.registry import supported_backends

# Module logger; validate_check_target_profile reports skipped checks through it.
logger = logging.getLogger(__name__)


def validate_backend(
    target_profile: str, backend: list[str] | None
) -> list[str] | None:
    """Validate backend with given target profile.

    This validator checks whether the given target-profile and backend are
    compatible with each other.
    It assumes that prior checks were made on the validity of the target-profile.

    Args:
        target_profile: Name of the target profile; resolved via ``get_target``.
        backend: Backend names requested by the user, or ``None``/empty to
            fall back to the target's default backends.

    Returns:
        ``default_backends(target)`` when no backend was requested, otherwise
        the ``backend`` list exactly as given.

    Raises:
        argparse.ArgumentError: If any requested backend is not among
            ``supported_backends(target)`` (compared after normalization).
    """
    target = get_target(target_profile)

    if not backend:
        return default_backends(target)

    # Compare case- and dash-insensitively; a set gives O(1) membership tests.
    compatible_backends = {normalize_string(b) for b in supported_backends(target)}
    # Map normalized name -> user's spelling so the error echoes the user's input.
    backends = {normalize_string(b): b for b in backend}

    incompatible_backends = [b for b in backends if b not in compatible_backends]
    # Throw an error if any unsupported backends are used
    if incompatible_backends:
        raise argparse.ArgumentError(
            None,
            f"Backend {', '.join(backends[b] for b in incompatible_backends)} "
            f"not supported with target-profile {target_profile}.",
        )
    return backend


def validate_check_target_profile(target_profile: str, category: set[str]) -> None:
    """Validate whether advice category is compatible with the provided target_profile.

    This validator logs a warning if any desired advice category is not
    compatible with the selected target profile. If no operation can be
    performed as a result of the validation, MLIA exits with error code 0.

    Args:
        target_profile: Name of the selected target profile.
        category: Requested check categories; only "performance" and
            "compatibility" are inspected here.

    Raises:
        SystemExit: With code 0 when every requested category is
            incompatible with the target profile, so nothing is left to do.
    """
    # Target profiles on which each check category cannot be run.
    incompatible_targets_performance: list[str] = ["tosa", "cortex-a"]
    incompatible_targets_compatibility: list[str] = []

    # Check which check operation should be performed
    try_performance = "performance" in category
    try_compatibility = "compatibility" in category

    # Cross check which of the desired operations can be performed on given
    # target-profile
    do_performance = (
        try_performance and target_profile not in incompatible_targets_performance
    )
    do_compatibility = (
        try_compatibility and target_profile not in incompatible_targets_compatibility
    )

    # Collect one sentence per requested-but-skipped operation.
    skipped_messages: list[str] = []
    if try_performance and not do_performance:
        skipped_messages.append(
            "Performance checks skipped as they cannot be "
            f"performed with target profile {target_profile}."
        )
    if try_compatibility and not do_compatibility:
        skipped_messages.append(
            "Compatibility checks skipped as they cannot be "
            f"performed with target profile {target_profile}."
        )

    # Case: desired operations can be performed with given target profile
    if not skipped_messages:
        return

    # Join with a space so that, when both operations are skipped, the two
    # sentences are not run together ("...profile X.Compatibility checks...").
    warning_message = "\nWARNING: " + " ".join(skipped_messages)

    # Case: at least one operation will be performed
    if do_compatibility or do_performance:
        logger.warning(warning_message)
        return

    # Case: no operation will be performed
    logger.warning(warning_message + " No operation was performed.")
    sys.exit(0)


def normalize_string(value: str) -> str:
    """Return *value* lower-cased with every hyphen removed.

    This lets names be compared regardless of letter case or dashes,
    e.g. "ToSa-cHecker" -> "tosachecker".
    """
    without_dashes = value.replace("-", "")
    return without_dashes.lower()