aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/cli/options.py
blob: dc5cb73bff2c9a0ad6ab3ecc0e4d455616e774c8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for the CLI options."""
import argparse
from pathlib import Path
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional

from mlia.cli.config import get_available_backends
from mlia.cli.config import get_default_backends
from mlia.cli.config import is_corstone_backend
from mlia.utils.filesystem import get_supported_profile_names
from mlia.utils.types import is_number


def add_target_options(parser: argparse.ArgumentParser) -> None:
    """Add target specific options.

    Registers ``--target-profile`` in a "target options" argument group.
    The allowed choices are the names returned by
    ``get_supported_profile_names()``; when at least one name is returned
    the first one becomes the default value.

    :param parser: parser that receives the new argument group
    """
    target_profiles = get_supported_profile_names()

    # Only advertise a default in the help text when one actually exists,
    # otherwise argparse would render "(default: None)".
    default_target_profile = None
    default_help = ""
    if target_profiles:
        default_target_profile = target_profiles[0]
        default_help = " (default: %(default)s)"

    target_group = parser.add_argument_group("target options")
    target_group.add_argument(
        "--target-profile",
        choices=target_profiles,
        default=default_target_profile,
        help="Target profile that will set the target options "
        "such as target, mac value, memory mode, etc. "
        "For the values associated with each target profile "
        f"please refer to the documentation{default_help}.",
    )


def add_multi_optimization_options(parser: argparse.ArgumentParser) -> None:
    """Register the options that select optimizations and their targets.

    Both options take a comma separated string; the values are stored as
    given and are not split or validated here.

    :param parser: parser that receives the "optimization options" group
    """
    group = parser.add_argument_group("optimization options")

    types_help = (
        "List of the optimization types separated by comma "
        "(default: %(default)s)"
    )
    # The continuation lines of this literal are part of the help text, so
    # their leading whitespace must stay exactly as it is.
    targets_help = """List of the optimization targets separated by comma,
             (for pruning this is sparsity between (0,1),
             for clustering this is the number of clusters (positive integer))
             (default: %(default)s)"""

    group.add_argument(
        "--optimization-type", default="pruning,clustering", help=types_help
    )
    group.add_argument(
        "--optimization-target", default="0.5,32", help=targets_help
    )


def add_optional_tflite_model_options(parser: argparse.ArgumentParser) -> None:
    """Register the TFLite model as a positional argument that may be omitted.

    :param parser: parser that receives the "TFLite model options" group
    """
    group = parser.add_argument_group("TFLite model options")
    # nargs="?" allows the positional to be left out; it then parses as None.
    group.add_argument(
        "model",
        help="TFLite model (optional)",
        nargs="?",
    )


def add_tflite_model_options(parser: argparse.ArgumentParser) -> None:
    """Register the TFLite model as a mandatory positional argument.

    :param parser: parser that receives the "TFLite model options" group
    """
    group = parser.add_argument_group("TFLite model options")
    group.add_argument(
        "model",
        help="TFLite model",
    )


def add_output_options(parser: argparse.ArgumentParser) -> None:
    """Register the ``--output`` option used to choose the report file.

    The value is validated while parsing: its file extension, compared case
    insensitively, must be one of the supported report formats, otherwise
    the parser reports an error.

    :param parser: parser that receives the "output options" group
    """
    valid_extensions = ["csv", "json"]

    def check_extension(filename: str) -> str:
        """Return ``filename`` unchanged if its extension is supported."""
        extension = Path(filename).suffix
        # Path.suffix keeps the leading dot; drop it before comparing.
        extension = extension[1:] if extension.startswith(".") else extension

        supported = extension.lower() in valid_extensions
        if not supported:
            parser.error(f"Unsupported format '{extension}'")

        return filename

    formats = ", ".join(valid_extensions)
    output_help = (
        "Name of the file where report will be saved. "
        "Report format is automatically detected based on the file extension. "
        f"Supported formats are: {formats}"
    )

    group = parser.add_argument_group("output options")
    group.add_argument("--output", type=check_extension, help=output_help)


def add_debug_options(parser: argparse.ArgumentParser) -> None:
    """Register the debugging related flags.

    :param parser: parser that receives the "debug options" group
    """
    group = parser.add_argument_group("debug options")
    group.add_argument(
        "--verbose",
        action="store_true",
        default=False,
        help="Produce verbose output",
    )


def add_keras_model_options(parser: argparse.ArgumentParser) -> None:
    """Register the Keras model as a mandatory positional argument.

    :param parser: parser that receives the "Keras model options" group
    """
    group = parser.add_argument_group("Keras model options")
    group.add_argument(
        "model",
        help="Keras model",
    )


def add_custom_supported_operators_options(parser: argparse.ArgumentParser) -> None:
    """Register the extra flag accepted by the 'operators' command.

    The flag is added directly to ``parser`` rather than to a named
    argument group.

    :param parser: parser that receives the ``--supported-ops-report`` flag
    """
    report_help = (
        "Generate the SUPPORTED_OPS.md file in the "
        "current working directory and exit"
    )
    parser.add_argument(
        "--supported-ops-report",
        default=False,
        action="store_true",
        help=report_help,
    )


def add_backend_options(parser: argparse.ArgumentParser) -> None:
    """Register the sub-commands used to manage backends.

    Two actions are added, and one of them must be given; the chosen name
    is stored as ``backend_action``:

    * ``install`` - requires exactly one of ``--path`` / ``--download`` and
      accepts an optional backend name;
    * ``status`` - takes no further arguments.

    :param parser: parser that receives the sub-commands
    """

    def valid_directory(param: str) -> Path:
        """Convert ``param`` to a Path, reporting an error if it is no directory."""
        candidate = Path(param)
        if not candidate.is_dir():
            parser.error(f"Invalid directory path {param}")

        return candidate

    actions = parser.add_subparsers(
        title="Backend actions", dest="backend_action", required=True
    )

    install = actions.add_parser(
        "install", help="Install backend", allow_abbrev=False
    )

    # The backend is either taken from a local directory or downloaded.
    source = install.add_mutually_exclusive_group(required=True)
    source.add_argument(
        "--path",
        type=valid_directory,
        help="Path to the installed backend",
    )
    source.add_argument(
        "--download",
        action="store_true",
        default=False,
        help="Download and install backend",
    )

    # Accepted on the command line but hidden from the help output.
    install.add_argument(
        "--i-agree-to-the-contained-eula",
        action="store_true",
        default=False,
        help=argparse.SUPPRESS,
    )
    install.add_argument(
        "--noninteractive",
        action="store_true",
        default=False,
        help="Non interactive mode with automatic confirmation of every action",
    )
    install.add_argument("name", nargs="?", help="Name of the backend to install")

    actions.add_parser("status", help="Show backends status")


def add_evaluation_options(parser: argparse.ArgumentParser) -> None:
    """Register the ``--evaluate-on`` option that selects evaluation backends.

    The allowed values come from ``get_available_backends()`` and the default
    from ``get_default_backends()``. Each supplied value is also passed
    through a checker that rejects a second Corstone backend.

    :param parser: parser that receives the "evaluation options" group
    """
    available_backends = get_available_backends()
    default_backends = get_default_backends()

    def only_one_corstone_checker() -> Callable:
        """
        Build the per-value converter used for ``--evaluate-on``.

        The returned callable hands each backend name back unchanged and
        raises an exception once more than one Corstone backend was seen.
        """
        # NOTE: the counter lives as long as the returned callable, so it is
        # not reset between separate parse_args() calls on the same parser.
        corstone_count = 0

        def check(backend: str) -> str:
            """Return ``backend``, counting it if it is a Corstone backend."""
            nonlocal corstone_count
            if not is_corstone_backend(backend):
                return backend

            corstone_count += 1
            if corstone_count > 1:
                raise argparse.ArgumentTypeError(
                    "There must be only one Corstone backend in the argument list."
                )
            return backend

        return check

    group = parser.add_argument_group("evaluation options")
    group.add_argument(
        "--evaluate-on",
        nargs="*",
        type=only_one_corstone_checker(),
        choices=available_backends,
        default=default_backends,
        help="Backends to use for evaluation (default: %(default)s)",
    )


def parse_optimization_parameters(
    optimization_type: str,
    optimization_target: str,
    sep: str = ",",
    layers_to_optimize: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
    """Turn the optimization type/target strings into a list of settings.

    The two strings are split on ``sep`` and paired position by position.

    :param optimization_type: optimization type names joined by ``sep``
    :param optimization_target: target values joined by ``sep``
    :param sep: separator used in both strings
    :param layers_to_optimize: value copied as is into every produced entry
    :return: one dictionary per pair with the keys "optimization_type"
             (stripped name), "optimization_target" (float) and
             "layers_to_optimize"
    :raises Exception: if either string is empty, if the number of types and
                       targets differs, or if a target is not numeric
    """
    if not optimization_type:
        raise Exception("Optimization type is not provided")
    if not optimization_target:
        raise Exception("Optimization target is not provided")

    type_names = optimization_type.split(sep)
    target_values = optimization_target.split(sep)

    if len(type_names) != len(target_values):
        raise Exception("Wrong number of optimization targets and types")

    # Every target is checked before any of them is converted.
    numeric_flags = [is_number(target_value) for target_value in target_values]
    if not all(numeric_flags):
        raise Exception("Non numeric value for the optimization target")

    optimizer_params: List[Dict[str, Any]] = []
    for type_name, target_value in zip(type_names, target_values):
        optimizer_params.append(
            {
                "optimization_type": type_name.strip(),
                "optimization_target": float(target_value),
                "layers_to_optimize": layers_to_optimize,
            }
        )

    return optimizer_params


def get_target_profile_opts(device_args: Optional[Dict]) -> List[str]:
    """Rebuild the command line options for non default target settings.

    A throwaway parser carrying only the target options is used to find out
    the default value and the option string of every target argument.

    :param device_args: mapping from argument name to its value; may be
                        None or empty
    :return: flat list of option strings and stringified values for the
             entries whose value differs from the parser default; entries
             the target options do not know are skipped
    """
    if not device_args:
        return []

    dummy_parser = argparse.ArgumentParser()
    add_target_options(dummy_parser)
    defaults = vars(dummy_parser.parse_args([]))

    # Map every destination name to an option string; if several option
    # strings share a destination, the last one registered is kept.
    option_for_dest: Dict[str, str] = {}
    registered = dummy_parser._option_string_actions  # pylint: disable=protected-access
    for option_string, action in registered.items():
        option_for_dest[action.dest] = option_string

    changed = [
        arg_name
        for arg_name, arg_value in device_args.items()
        if arg_name in defaults and defaults[arg_name] != arg_value
    ]

    opts: List[str] = []
    for arg_name in changed:
        option_string = option_for_dest[arg_name]
        value = device_args[arg_name]
        if isinstance(value, list):
            # A list value repeats the option once per element.
            for element in value:
                opts.extend([str(option_string), str(element)])
        else:
            opts.extend([option_string, str(value)])

    return opts