diff options
Diffstat (limited to 'src/mlia/cli/options.py')
-rw-r--r-- | src/mlia/cli/options.py | 199 |
1 files changed, 109 insertions, 90 deletions
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py index 8ea4250..bae6219 100644 --- a/src/mlia/cli/options.py +++ b/src/mlia/cli/options.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for the CLI options.""" from __future__ import annotations @@ -8,37 +8,48 @@ from pathlib import Path from typing import Any from typing import Callable +from mlia.cli.config import DEFAULT_CLUSTERING_TARGET +from mlia.cli.config import DEFAULT_PRUNING_TARGET from mlia.cli.config import get_available_backends -from mlia.cli.config import get_default_backends from mlia.cli.config import is_corstone_backend -from mlia.core.reporting import OUTPUT_FORMATS +from mlia.core.common import FormattedFilePath from mlia.utils.filesystem import get_supported_profile_names -from mlia.utils.types import is_number + + +def add_check_category_options(parser: argparse.ArgumentParser) -> None: + """Add check category type options.""" + parser.add_argument( + "--performance", action="store_true", help="Perform performance checks." + ) + + parser.add_argument( + "--compatibility", + action="store_true", + help="Perform compatibility checks. (default)", + ) def add_target_options( - parser: argparse.ArgumentParser, profiles_to_skip: list[str] | None = None + parser: argparse.ArgumentParser, + profiles_to_skip: list[str] | None = None, + required: bool = True, ) -> None: """Add target specific options.""" target_profiles = get_supported_profile_names() if profiles_to_skip: target_profiles = [tp for tp in target_profiles if tp not in profiles_to_skip] - default_target_profile = None - default_help = "" - if target_profiles: - default_target_profile = target_profiles[0] - default_help = " (default: %(default)s)" - target_group = parser.add_argument_group("target options") target_group.add_argument( + "-t", "--target-profile", choices=target_profiles, - default=default_target_profile, + required=required, + default="", help="Target profile that will set the target options " "such as target, mac value, memory mode, etc. " - f"For the values associated with each target profile " - f" please refer to the documentation {default_help}.", + "For the values associated with each target profile " + "please refer to the documentation.", ) @@ -47,59 +58,47 @@ def add_multi_optimization_options(parser: argparse.ArgumentParser) -> None: multi_optimization_group = parser.add_argument_group("optimization options") multi_optimization_group.add_argument( - "--optimization-type", - default="pruning,clustering", - help="List of the optimization types separated by comma (default: %(default)s)", + "--pruning", action="store_true", help="Apply pruning optimization." ) + multi_optimization_group.add_argument( - "--optimization-target", - default="0.5,32", - help="""List of the optimization targets separated by comma, - (for pruning this is sparsity between (0,1), - for clustering this is the number of clusters (positive integer)) - (default: %(default)s)""", + "--clustering", action="store_true", help="Apply clustering optimization." ) + multi_optimization_group.add_argument( + "--pruning-target", + type=float, + help="Sparsity to be reached during optimization " + f"(default: {DEFAULT_PRUNING_TARGET})", + ) -def add_optional_tflite_model_options(parser: argparse.ArgumentParser) -> None: - """Add optional model specific options.""" - model_group = parser.add_argument_group("TensorFlow Lite model options") - # make model parameter optional - model_group.add_argument( - "model", nargs="?", help="TensorFlow Lite model (optional)" + multi_optimization_group.add_argument( + "--clustering-target", + type=int, + help="Number of clusters to reach during optimization " + f"(default: {DEFAULT_CLUSTERING_TARGET})", ) -def add_tflite_model_options(parser: argparse.ArgumentParser) -> None: +def add_model_options(parser: argparse.ArgumentParser) -> None: """Add model specific options.""" - model_group = parser.add_argument_group("TensorFlow Lite model options") - model_group.add_argument("model", help="TensorFlow Lite model") + parser.add_argument("model", help="TensorFlow Lite model or Keras model") def add_output_options(parser: argparse.ArgumentParser) -> None: """Add output specific options.""" - valid_extensions = OUTPUT_FORMATS - - def check_extension(filename: str) -> str: - """Check extension of the provided file.""" - suffix = Path(filename).suffix - if suffix.startswith("."): - suffix = suffix[1:] - - if suffix.lower() not in valid_extensions: - parser.error(f"Unsupported format '{suffix}'") - - return filename - output_group = parser.add_argument_group("output options") output_group.add_argument( + "-o", "--output", - type=check_extension, - help=( - "Name of the file where report will be saved. " - "Report format is automatically detected based on the file extension. " - f"Supported formats are: {', '.join(valid_extensions)}" - ), + type=Path, + help=("Name of the file where the report will be saved."), + ) + + output_group.add_argument( + "--json", + action="store_true", + help=("Format to use for the output (requires --output argument to be set)."), ) @@ -107,7 +106,11 @@ def add_debug_options(parser: argparse.ArgumentParser) -> None: """Add debug options.""" debug_group = parser.add_argument_group("debug options") debug_group.add_argument( - "--verbose", default=False, action="store_true", help="Produce verbose output" + "-d", + "--debug", + default=False, + action="store_true", + help="Produce verbose output", ) @@ -117,20 +120,6 @@ def add_keras_model_options(parser: argparse.ArgumentParser) -> None: model_group.add_argument("model", help="Keras model") -def add_custom_supported_operators_options(parser: argparse.ArgumentParser) -> None: - """Add custom options for the command 'operators'.""" - parser.add_argument( - "--supported-ops-report", - action="store_true", - default=False, - help=( - "Generate the SUPPORTED_OPS.md file in the " - "current working directory and exit " - "(Ethos-U target profiles only)" - ), - ) - - def add_backend_install_options(parser: argparse.ArgumentParser) -> None: """Add options for the backends configuration.""" @@ -176,10 +165,11 @@ def add_backend_uninstall_options(parser: argparse.ArgumentParser) -> None: ) -def add_evaluation_options(parser: argparse.ArgumentParser) -> None: +def add_backend_options( + parser: argparse.ArgumentParser, backends_to_skip: list[str] | None = None +) -> None: """Add evaluation options.""" available_backends = get_available_backends() - default_backends = get_default_backends() def only_one_corstone_checker() -> Callable: """ @@ -202,41 +192,70 @@ def add_evaluation_options(parser: argparse.ArgumentParser) -> None: return check - evaluation_group = parser.add_argument_group("evaluation options") + # Remove backends to skip + if backends_to_skip: + available_backends = [ + x for x in available_backends if x not in backends_to_skip + ] + + evaluation_group = parser.add_argument_group("backend options") evaluation_group.add_argument( - "--evaluate-on", - help="Backends to use for evaluation (default: %(default)s)", - nargs="*", + "-b", + "--backend", + help="Backends to use for evaluation.", + nargs="+", choices=available_backends, - default=default_backends, type=only_one_corstone_checker(), ) +def parse_output_parameters(path: Path | None, json: bool) -> FormattedFilePath | None: + """Parse and return path and file format as FormattedFilePath.""" + if not path and json: + raise argparse.ArgumentError( + None, + "To enable JSON output you need to specify the output path. " + "(e.g. --output out.json --json)", + ) + if not path: + return None + if json: + return FormattedFilePath(path, "json") + + return FormattedFilePath(path, "plain_text") + + def parse_optimization_parameters( - optimization_type: str, - optimization_target: str, - sep: str = ",", + pruning: bool = False, + clustering: bool = False, + pruning_target: float | None = None, + clustering_target: int | None = None, layers_to_optimize: list[str] | None = None, ) -> list[dict[str, Any]]: """Parse provided optimization parameters.""" - if not optimization_type: - raise Exception("Optimization type is not provided") + opt_types = [] + opt_targets = [] - if not optimization_target: - raise Exception("Optimization target is not provided") + if clustering_target and not clustering: + raise argparse.ArgumentError( + None, + "To enable clustering optimization you need to include the " + "`--clustering` flag in your command.", + ) - opt_types = optimization_type.split(sep) - opt_targets = optimization_target.split(sep) + if not pruning_target: + pruning_target = DEFAULT_PRUNING_TARGET - if len(opt_types) != len(opt_targets): - raise Exception("Wrong number of optimization targets and types") + if not clustering_target: + clustering_target = DEFAULT_CLUSTERING_TARGET - non_numeric_targets = [ - opt_target for opt_target in opt_targets if not is_number(opt_target) - ] - if len(non_numeric_targets) > 0: - raise Exception("Non numeric value for the optimization target") + if (pruning is False and clustering is False) or pruning: + opt_types.append("pruning") + opt_targets.append(pruning_target) + + if clustering: + opt_types.append("clustering") + opt_targets.append(clustering_target) optimizer_params = [ { @@ -256,7 +275,7 @@ def get_target_profile_opts(device_args: dict | None) -> list[str]: return [] parser = argparse.ArgumentParser() - add_target_options(parser) + add_target_options(parser, required=False) args = parser.parse_args([]) params_name = { |