aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/cli
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/cli')
-rw-r--r--src/mlia/cli/__init__.py3
-rw-r--r--src/mlia/cli/commands.py276
-rw-r--r--src/mlia/cli/common.py38
-rw-r--r--src/mlia/cli/config.py64
-rw-r--r--src/mlia/cli/helpers.py116
-rw-r--r--src/mlia/cli/logging.py117
-rw-r--r--src/mlia/cli/main.py280
-rw-r--r--src/mlia/cli/options.py280
8 files changed, 1174 insertions, 0 deletions
diff --git a/src/mlia/cli/__init__.py b/src/mlia/cli/__init__.py
new file mode 100644
index 0000000..f50778e
--- /dev/null
+++ b/src/mlia/cli/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""CLI module."""
diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py
new file mode 100644
index 0000000..45c7c32
--- /dev/null
+++ b/src/mlia/cli/commands.py
@@ -0,0 +1,276 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""CLI commands module.
+
+This module contains functions which implement main app
+functionality.
+
+Before running them from scripts 'logging' module should
+be configured. Function 'setup_logging' from module
+'mlia.cli.logging' could be used for that, e.g.
+
+>>> from mlia.api import ExecutionContext
+>>> from mlia.cli.logging import setup_logging
+>>> setup_logging(verbose=True)
+>>> import mlia.cli.commands as mlia
+>>> mlia.all_tests(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256",
+ "path/to/model")
+"""
+import logging
+from pathlib import Path
+from typing import cast
+from typing import List
+from typing import Optional
+
+from mlia.api import ExecutionContext
+from mlia.api import get_advice
+from mlia.api import PathOrFileLike
+from mlia.cli.config import get_installation_manager
+from mlia.cli.options import parse_optimization_parameters
+from mlia.devices.ethosu.operators import generate_supported_operators_report
+from mlia.utils.console import create_section_header
+from mlia.utils.types import only_one_selected
+
# Module-level logger for CLI command implementations.
logger = logging.getLogger(__name__)

# Section header printed before backend configuration output.
CONFIG = create_section_header("ML Inference Advisor configuration")
+
+
def all_tests(
    ctx: ExecutionContext,
    target_profile: str,
    model: str,
    optimization_type: str = "pruning,clustering",
    optimization_target: str = "0.5,32",
    output: Optional[PathOrFileLike] = None,
    evaluate_on: Optional[List[str]] = None,
) -> None:
    """Generate a full report on the input model.

    Runs the complete advice workflow:

    - converts the input Keras model into TFLite format
    - checks the model for operator compatibility on the specified device
    - applies optimizations to the model and estimates the resulting
      performance on both the original and the optimized models
    - generates a final report on the steps above
    - provides advice on how to (possibly) improve the inference performance

    :param ctx: execution context
    :param target_profile: target profile identifier. Will load appropriate
        parameters from the profile.json file based on this argument.
    :param model: path to the Keras model
    :param optimization_type: comma-separated optimization techniques,
        e.g. 'pruning,clustering'
    :param optimization_target: comma-separated targets matching the
        optimization techniques, e.g. '0.5,32'
    :param output: path to the file where the report will be saved
    :param evaluate_on: list of the backends to use for evaluation

    Example:
        Run command for the target profile ethos-u55-256 with two model
        optimizations and save report in json format locally in the file
        report.json

        >>> from mlia.api import ExecutionContext
        >>> from mlia.cli.logging import setup_logging
        >>> setup_logging()
        >>> from mlia.cli.commands import all_tests
        >>> all_tests(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256",
                      "model.h5", "pruning,clustering", "0.5,32",
                      output="report.json")
    """
    # Translate the comma-separated CLI strings into structured settings.
    optimization_settings = parse_optimization_parameters(
        optimization_type, optimization_target
    )

    advice_kwargs = {
        "optimization_targets": optimization_settings,
        "output": output,
        "context": ctx,
        "backends": evaluate_on,
    }
    get_advice(target_profile, model, "all", **advice_kwargs)
+
+
def operators(
    ctx: ExecutionContext,
    target_profile: str,
    model: Optional[str] = None,
    output: Optional[PathOrFileLike] = None,
    supported_ops_report: bool = False,
) -> None:
    """Print the model's operator list.

    This command checks the operator compatibility of the input model with
    the specific target profile. Generates a report of the operator placement
    (NPU or CPU fallback) and advice on how to improve it (if necessary).

    :param ctx: execution context
    :param target_profile: target profile identifier. Will load appropriate parameters
        from the profile.json file based on this argument.
    :param model: path to the model, which can be TFLite or Keras
    :param output: path to the file where the report will be saved
    :param supported_ops_report: if True then generates supported operators
        report in current directory and exits
    :raises ValueError: if no model is provided (and no supported operators
        report was requested)

    Example:
        Run command for the target profile ethos-u55-256 and the provided
        TFLite model and print report on the standard output

        >>> from mlia.api import ExecutionContext
        >>> from mlia.cli.logging import setup_logging
        >>> setup_logging()
        >>> from mlia.cli.commands import operators
        >>> operators(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256",
                      "model.tflite")
    """
    if supported_ops_report:
        generate_supported_operators_report()
        logger.info("Report saved into SUPPORTED_OPS.md")
        return

    if not model:
        # Raise a specific, catchable exception type instead of bare Exception;
        # the CLI boundary in main.run_command still handles it the same way.
        raise ValueError("Model is not provided")

    get_advice(
        target_profile,
        model,
        "operators",
        output=output,
        context=ctx,
    )
+
+
def performance(
    ctx: ExecutionContext,
    target_profile: str,
    model: str,
    output: Optional[PathOrFileLike] = None,
    evaluate_on: Optional[List[str]] = None,
) -> None:
    """Print the model's performance stats.

    Estimates the inference performance of the input model on the specified
    target profile and generates a report with advice on how to improve it.

    :param ctx: execution context
    :param target_profile: target profile identifier. Will load appropriate
        parameters from the profile.json file based on this argument.
    :param model: path to the model, which can be TFLite or Keras
    :param output: path to the file where the report will be saved
    :param evaluate_on: list of the backends to use for evaluation

    Example:
        Run command for the target profile ethos-u55-256 and
        the provided TFLite model and print report on the standard output

        >>> from mlia.api import ExecutionContext
        >>> from mlia.cli.logging import setup_logging
        >>> setup_logging()
        >>> from mlia.cli.commands import performance
        >>> performance(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256",
                        "model.tflite")
    """
    advice_kwargs = {
        "output": output,
        "context": ctx,
        "backends": evaluate_on,
    }
    get_advice(target_profile, model, "performance", **advice_kwargs)
+
+
def optimization(
    ctx: ExecutionContext,
    target_profile: str,
    model: str,
    optimization_type: str,
    optimization_target: str,
    layers_to_optimize: Optional[List[str]] = None,
    output: Optional[PathOrFileLike] = None,
    evaluate_on: Optional[List[str]] = None,
) -> None:
    """Show the performance improvements (if any) after applying the optimizations.

    This command applies the selected optimization techniques (up to the
    indicated targets) and generates a report with advice on how to improve
    the inference performance (if possible).

    :param ctx: execution context
    :param target_profile: target profile identifier. Will load appropriate
        parameters from the profile.json file based on this argument.
    :param model: path to the TFLite model
    :param optimization_type: list of the optimization techniques separated
        by comma, e.g. 'pruning,clustering'
    :param optimization_target: list of the corresponding targets for
        the provided optimization techniques, e.g. '0.5,32'
    :param layers_to_optimize: list of the layers of the model which should be
        optimized, if None then all layers are used
    :param output: path to the file where the report will be saved
    :param evaluate_on: list of the backends to use for evaluation

    Example:
        Run command for the target profile ethos-u55-256 and
        the provided TFLite model and print report on the standard output

        >>> from mlia.api import ExecutionContext
        >>> from mlia.cli.logging import setup_logging
        >>> setup_logging()
        >>> from mlia.cli.commands import optimization
        >>> optimization(ExecutionContext(working_dir="mlia_output"),
                         "ethos-u55-256", "model.tflite", "pruning", "0.5")
    """
    opt_params = parse_optimization_parameters(
        optimization_type,
        optimization_target,
        layers_to_optimize=layers_to_optimize,
    )

    get_advice(
        target_profile,
        model,
        "optimization",
        optimization_targets=opt_params,
        output=output,
        context=ctx,
        backends=evaluate_on,
    )
+
+
def backend(
    backend_action: str,
    path: Optional[Path] = None,
    download: bool = False,
    name: Optional[str] = None,
    i_agree_to_the_contained_eula: bool = False,
    noninteractive: bool = False,
) -> None:
    """Backends configuration.

    :param backend_action: action to perform, "status" or "install"
    :param path: path to a local backend installation to install from
    :param download: if True then download and install the backend
    :param name: name of the backend to install
    :param i_agree_to_the_contained_eula: if True, skip the interactive
        EULA confirmation during download
    :param noninteractive: if True, run without prompting the user
    :raises ValueError: if both (or neither) of the install actions
        "download" and "install from path" are selected
    """
    logger.info(CONFIG)

    manager = get_installation_manager(noninteractive)

    if backend_action == "status":
        manager.show_env_details()

    if backend_action == "install":
        install_from_path = path is not None

        # Exactly one install source must be chosen.
        if not only_one_selected(install_from_path, download):
            # Specific exception type instead of broad Exception; still
            # handled by the CLI error boundary in main.run_command.
            raise ValueError(
                "Please select only one action: download or "
                "provide path to the backend installation"
            )

        if install_from_path:
            manager.install_from(cast(Path, path), name)

        if download:
            # NOTE(review): inversion assumed to mean "EULA still needs to be
            # accepted interactively" when the flag was not given — confirm
            # against the installation manager's API.
            eula_agreement = not i_agree_to_the_contained_eula
            manager.download_and_install(name, eula_agreement)
diff --git a/src/mlia/cli/common.py b/src/mlia/cli/common.py
new file mode 100644
index 0000000..54bd457
--- /dev/null
+++ b/src/mlia/cli/common.py
@@ -0,0 +1,38 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""CLI common module."""
+import argparse
+from dataclasses import dataclass
+from typing import Callable
+from typing import List
+
+
@dataclass
class CommandInfo:
    """Associate a command function with its CLI metadata."""

    # Function implementing the command; its name becomes the command name.
    func: Callable
    # Alternative names the command can be invoked with.
    aliases: List[str]
    # Callables that register option groups on the command's subparser.
    opt_groups: List[Callable[[argparse.ArgumentParser], None]]
    # Whether this command runs when no command is given on the CLI.
    is_default: bool = False

    @property
    def command_name(self) -> str:
        """Name of the command, taken from the wrapped function."""
        return self.func.__name__

    @property
    def command_name_and_aliases(self) -> List[str]:
        """Command name followed by all of its aliases."""
        names = [self.command_name]
        names.extend(self.aliases)
        return names

    @property
    def command_help(self) -> str:
        """Short help message derived from the function's docstring."""
        doc = self.func.__doc__
        assert doc, "Command function does not have a docstring"
        summary = doc.splitlines()[0].rstrip(".")
        return f"{summary} [default]" if self.is_default else summary
diff --git a/src/mlia/cli/config.py b/src/mlia/cli/config.py
new file mode 100644
index 0000000..838b051
--- /dev/null
+++ b/src/mlia/cli/config.py
@@ -0,0 +1,64 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Environment configuration functions."""
+import logging
+from functools import lru_cache
+from typing import List
+
+import mlia.tools.aiet_wrapper as aiet
+from mlia.tools.metadata.common import DefaultInstallationManager
+from mlia.tools.metadata.common import InstallationManager
+from mlia.tools.metadata.corstone import get_corstone_installations
+
+logger = logging.getLogger(__name__)
+
+
def get_installation_manager(noninteractive: bool = False) -> InstallationManager:
    """Create the default installation manager for the Corstone backends."""
    corstone_backends = get_corstone_installations()
    return DefaultInstallationManager(corstone_backends, noninteractive=noninteractive)
+
+
@lru_cache
def get_available_backends() -> List[str]:
    """Return list of the available backends."""
    manager = get_installation_manager()

    # Vela is always available; AIET-managed backends count only
    # when they are actually installed.
    installed_backends = [
        backend
        for backend in aiet.supported_backends()
        if manager.backend_installed(backend)
    ]

    return ["Vela"] + installed_backends
+
+
# List of mutually exclusive Corstone backends ordered by priority
_CORSTONE_EXCLUSIVE_PRIORITY = ("Corstone-310", "Corstone-300")


def get_default_backends() -> List[str]:
    """Get default backends for evaluation.

    Keeps at most one Corstone backend (the highest-priority one that is
    available); every non-Corstone backend is kept as-is.
    """
    backends = get_available_backends()

    preferred_corstone = next(
        (item for item in _CORSTONE_EXCLUSIVE_PRIORITY if item in backends), None
    )
    if preferred_corstone is None:
        return backends

    return [
        backend
        for backend in backends
        if backend == preferred_corstone or not is_corstone_backend(backend)
    ]


def is_corstone_backend(backend: str) -> bool:
    """Check if the given backend is a Corstone backend."""
    return backend in _CORSTONE_EXCLUSIVE_PRIORITY
diff --git a/src/mlia/cli/helpers.py b/src/mlia/cli/helpers.py
new file mode 100644
index 0000000..81d5a15
--- /dev/null
+++ b/src/mlia/cli/helpers.py
@@ -0,0 +1,116 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for various helper classes."""
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+from mlia.cli.options import get_target_profile_opts
+from mlia.core.helpers import ActionResolver
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.nn.tensorflow.utils import is_keras_model
+from mlia.utils.types import is_list_of
+
+
class CLIActionResolver(ActionResolver):
    """Helper class for generating cli commands."""

    def __init__(self, args: Dict[str, Any]) -> None:
        """Init action resolver."""
        self.args = args

    @staticmethod
    def _general_optimization_command(model_path: Optional[str]) -> List[str]:
        """Return general optimization command description."""
        lines: List[str] = []

        # Fall back to a placeholder path (plus a note) when no usable
        # Keras model is available.
        if model_path is None or not is_keras_model(model_path):
            model_path = "/path/to/keras_model"
            lines.append("Note: you will need a Keras model for that.")

        lines.append(
            "For example: mlia optimization --optimization-type "
            f"pruning,clustering --optimization-target 0.5,32 {model_path}"
        )
        lines.append("For more info: mlia optimization --help")
        return lines

    @staticmethod
    def _specific_optimization_command(
        model_path: str,
        device_opts: str,
        opt_settings: List[OptimizationSettings],
    ) -> List[str]:
        """Return specific optimization command description."""
        types_csv = ",".join(item.optimization_type for item in opt_settings)
        targets_csv = ",".join(
            str(item.optimization_target) for item in opt_settings
        )

        command = (
            "Optimization command: "
            f"mlia optimization --optimization-type {types_csv} "
            f"--optimization-target {targets_csv}{device_opts} {model_path}"
        )
        return ["For more info: mlia optimization --help", command]

    def apply_optimizations(self, **kwargs: Any) -> List[str]:
        """Return command details for applying optimizations."""
        model_path, device_opts = self._get_model_and_device_opts()

        opt_settings = kwargs.pop("opt_settings", None)
        if opt_settings is None:
            return self._general_optimization_command(model_path)

        if is_list_of(opt_settings, OptimizationSettings) and model_path:
            return self._specific_optimization_command(
                model_path, device_opts, opt_settings
            )

        return []

    def supported_operators_info(self) -> List[str]:
        """Return command details for generating supported ops report."""
        return [
            "For guidance on supported operators, run: mlia operators "
            "--supported-ops-report",
        ]

    def check_performance(self) -> List[str]:
        """Return command details for checking performance."""
        model_path, device_opts = self._get_model_and_device_opts()
        if not model_path:
            return []

        return [
            "Check the estimated performance by running the following command: ",
            f"mlia performance{device_opts} {model_path}",
        ]

    def check_operator_compatibility(self) -> List[str]:
        """Return command details for op compatibility."""
        model_path, device_opts = self._get_model_and_device_opts()
        if not model_path:
            return []

        return [
            "Try running the following command to verify that:",
            f"mlia operators{device_opts} {model_path}",
        ]

    def operator_compatibility_details(self) -> List[str]:
        """Return command details for op compatibility."""
        return ["For more details, run: mlia operators --help"]

    def optimization_details(self) -> List[str]:
        """Return command details for optimization."""
        return ["For more info, see: mlia optimization --help"]

    def _get_model_and_device_opts(
        self, separate_device_opts: bool = True
    ) -> Tuple[Optional[str], str]:
        """Get model and device options."""
        device_opts = " ".join(get_target_profile_opts(self.args))

        # Prefix with a space so the options can be appended directly
        # after a command name in the generated strings.
        if separate_device_opts and device_opts:
            device_opts = f" {device_opts}"

        return self.args.get("model"), device_opts
diff --git a/src/mlia/cli/logging.py b/src/mlia/cli/logging.py
new file mode 100644
index 0000000..c5fc7bd
--- /dev/null
+++ b/src/mlia/cli/logging.py
@@ -0,0 +1,117 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""CLI logging configuration."""
+import logging
+import sys
+from pathlib import Path
+from typing import List
+from typing import Optional
+from typing import Union
+
+from mlia.utils.logging import attach_handlers
+from mlia.utils.logging import create_log_handler
+from mlia.utils.logging import LogFilter
+
+
# Format for verbose console output: logger name plus message, no timestamps.
_CONSOLE_DEBUG_FORMAT = "%(name)s - %(message)s"
# Format for log files: fully timestamped records with level.
_FILE_DEBUG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+
+
def setup_logging(
    logs_dir: Optional[Union[str, Path]] = None,
    verbose: bool = False,
    log_filename: str = "mlia.log",
) -> None:
    """Set up logging.

    MLIA uses module 'logging' when it needs to produce output.

    :param logs_dir: path to the directory where application will save logs
        with debug information. If the path is not provided then no log
        files will be created during execution
    :param verbose: enable extended logging for the tools loggers
    :param log_filename: name of the log file in the logs directory
    """
    loggers = [
        logging.getLogger(name) for name in ("mlia", "tensorflow", "py.warnings")
    ]
    mlia_logger = loggers[0]
    tools_loggers = loggers[1:]

    # Emit everything at DEBUG on the logger itself; the attached
    # handlers decide what is actually shown or written.
    mlia_logger.setLevel(logging.DEBUG)

    attach_handlers(
        _get_mlia_handlers(logs_dir, log_filename, verbose), [mlia_logger]
    )
    attach_handlers(
        _get_tools_handlers(logs_dir, log_filename, verbose), tools_loggers
    )
+
+
def _get_mlia_handlers(
    logs_dir: Optional[Union[str, Path]],
    log_filename: str,
    verbose: bool,
) -> List[logging.Handler]:
    """Get handlers for the MLIA loggers."""
    # INFO-level console output is always on.
    handlers = [create_log_handler(stream=sys.stdout, log_level=logging.INFO)]

    if verbose:
        # Extra console handler restricted to DEBUG records only, so
        # INFO messages are not printed twice.
        handlers.append(
            create_log_handler(
                stream=sys.stdout,
                log_level=logging.DEBUG,
                log_format=_CONSOLE_DEBUG_FORMAT,
                log_filter=LogFilter.equals(logging.DEBUG),
            )
        )

    if logs_dir:
        # File handler records everything except INFO (already on console);
        # delay=True postpones file creation until the first record.
        handlers.append(
            create_log_handler(
                file_path=_get_log_file(logs_dir, log_filename),
                log_level=logging.DEBUG,
                log_format=_FILE_DEBUG_FORMAT,
                log_filter=LogFilter.skip(logging.INFO),
                delay=True,
            )
        )

    return handlers
+
+
def _get_tools_handlers(
    logs_dir: Optional[Union[str, Path]],
    log_filename: str,
    verbose: bool,
) -> List[logging.Handler]:
    """Get handler for the tools loggers."""
    handlers: List[logging.Handler] = []

    if verbose:
        # Mirror tool debug output on the console when verbose is enabled.
        handlers.append(
            create_log_handler(
                stream=sys.stdout,
                log_level=logging.DEBUG,
                log_format=_CONSOLE_DEBUG_FORMAT,
            )
        )

    if logs_dir:
        # delay=True postpones file creation until the first record.
        handlers.append(
            create_log_handler(
                file_path=_get_log_file(logs_dir, log_filename),
                log_level=logging.DEBUG,
                log_format=_FILE_DEBUG_FORMAT,
                delay=True,
            )
        )

    return handlers
+
+
+def _get_log_file(logs_dir: Union[str, Path], log_filename: str) -> Path:
+ """Get the log file path."""
+ logs_dir_path = Path(logs_dir)
+ logs_dir_path.mkdir(exist_ok=True)
+ return logs_dir_path / log_filename
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py
new file mode 100644
index 0000000..33fcdeb
--- /dev/null
+++ b/src/mlia/cli/main.py
@@ -0,0 +1,280 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""CLI main entry point."""
+import argparse
+import logging
+import sys
+from inspect import signature
+from pathlib import Path
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+from mlia import __version__
+from mlia.cli.commands import all_tests
+from mlia.cli.commands import backend
+from mlia.cli.commands import operators
+from mlia.cli.commands import optimization
+from mlia.cli.commands import performance
+from mlia.cli.common import CommandInfo
+from mlia.cli.helpers import CLIActionResolver
+from mlia.cli.logging import setup_logging
+from mlia.cli.options import add_backend_options
+from mlia.cli.options import add_custom_supported_operators_options
+from mlia.cli.options import add_debug_options
+from mlia.cli.options import add_evaluation_options
+from mlia.cli.options import add_keras_model_options
+from mlia.cli.options import add_multi_optimization_options
+from mlia.cli.options import add_optional_tflite_model_options
+from mlia.cli.options import add_output_options
+from mlia.cli.options import add_target_options
+from mlia.cli.options import add_tflite_model_options
+from mlia.core.context import ExecutionContext
+
+
# Module-level logger for the CLI entry point.
logger = logging.getLogger(__name__)

# Banner printed at the start of every command run and used as the parser
# description; built once at import time from the package version.
INFO_MESSAGE = f"""
ML Inference Advisor {__version__}

Help the design and optimization of neural network models for efficient inference on a target CPU, GPU and NPU

Supported targets:

 - Ethos-U55 <op compatibility, perf estimation, model opt>
 - Ethos-U65 <op compatibility, perf estimation, model opt>

""".strip()
+
+
def get_commands() -> List[CommandInfo]:
    """Return commands configuration.

    Each CommandInfo couples a command implementation with the CLI option
    groups that must be registered on its subparser.
    """
    return [
        # "all" runs the complete workflow (compatibility, performance,
        # optimization) and is the default command when none is given.
        CommandInfo(
            all_tests,
            ["all"],
            [
                add_target_options,
                add_keras_model_options,
                add_multi_optimization_options,
                add_output_options,
                add_debug_options,
                add_evaluation_options,
            ],
            True,
        ),
        # Operator-compatibility check; the model is optional because
        # --supported-ops-report works without one.
        CommandInfo(
            operators,
            ["ops"],
            [
                add_target_options,
                add_optional_tflite_model_options,
                add_output_options,
                add_custom_supported_operators_options,
                add_debug_options,
            ],
        ),
        # Performance estimation on the selected backends.
        CommandInfo(
            performance,
            ["perf"],
            [
                add_target_options,
                add_tflite_model_options,
                add_output_options,
                add_debug_options,
                add_evaluation_options,
            ],
        ),
        # Model optimization plus before/after performance comparison.
        CommandInfo(
            optimization,
            ["opt"],
            [
                add_target_options,
                add_keras_model_options,
                add_multi_optimization_options,
                add_output_options,
                add_debug_options,
                add_evaluation_options,
            ],
        ),
        # Backend installation/status management.
        CommandInfo(
            backend,
            [],
            [
                add_backend_options,
                add_debug_options,
            ],
        ),
    ]
+
+
def get_default_command() -> Optional[str]:
    """Get name of the default command.

    :return: name of the command marked as default, or None when no
        command is marked
    :raises RuntimeError: if more than one command is marked as default
    """
    defaults = [cmd.command_name for cmd in get_commands() if cmd.is_default]

    # More than one default is a programming error in get_commands(); raise
    # explicitly instead of relying on assert, which "python -O" strips.
    if len(defaults) > 1:
        raise RuntimeError("Only one command could be marked as default")

    return next(iter(defaults), None)
+
+
def get_possible_command_names() -> List[str]:
    """Get all possible command names including aliases."""
    names: List[str] = []
    for cmd in get_commands():
        names.extend(cmd.command_name_and_aliases)
    return names
+
+
def init_commands(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Init cli subcommands.

    Registers one subparser per command, wires its implementation via
    set_defaults(func=...) and applies the command's option groups.
    """
    subparsers = parser.add_subparsers(title="Commands", dest="command")
    subparsers.required = True

    for cmd in get_commands():
        cmd_parser = subparsers.add_parser(
            cmd.command_name,
            aliases=cmd.aliases,
            help=cmd.command_help,
            allow_abbrev=False,
        )
        cmd_parser.set_defaults(func=cmd.func)

        for add_opt_group in cmd.opt_groups:
            add_opt_group(cmd_parser)

    return parser
+
+
def setup_context(
    args: argparse.Namespace, context_var_name: str = "ctx"
) -> Tuple[ExecutionContext, Dict]:
    """Set up context and resolve function parameters.

    :param args: parsed command line arguments
    :param context_var_name: name of the command-function parameter that
        receives the execution context
    :return: tuple of (execution context, keyword arguments to pass to the
        command function)
    """
    ctx = ExecutionContext(
        working_dir=args.working_dir,
        # "verbose" only exists when the command registered debug options
        verbose="verbose" in args and args.verbose,
        action_resolver=CLIActionResolver(vars(args)),
    )

    # these parameters should not be passed into command function
    skipped_params = ["func", "command", "working_dir", "verbose"]

    # pass these parameters only if command expects them
    expected_params = [context_var_name]
    func_params = signature(args.func).parameters

    params = {context_var_name: ctx, **vars(args)}

    # Keep an argument unless it is explicitly skipped; drop the context
    # argument when the command function does not declare it.
    func_args = {
        param_name: param_value
        for param_name, param_value in params.items()
        if param_name not in skipped_params
        and (param_name not in expected_params or param_name in func_params)
    }

    return (ctx, func_args)
+
+
def run_command(args: argparse.Namespace) -> int:
    """Run command.

    Top-level error boundary for the CLI: every exception is logged and
    converted into a non-zero exit code.

    :param args: parsed command line arguments
    :return: 0 on success, 1 on failure
    """
    ctx, func_args = setup_context(args)
    setup_logging(ctx.logs_path, ctx.verbose)

    logger.debug(
        "*** This is the beginning of the command '%s' execution ***", args.command
    )

    try:
        logger.info(INFO_MESSAGE)
        args.func(**func_args)
    except KeyboardInterrupt:
        logger.error("Execution has been interrupted")
    except Exception as err:  # pylint: disable=broad-except
        logger.error(
            "\nExecution finished with error: %s",
            err,
            exc_info=err if ctx.verbose else None,
        )

        err_advice_message = (
            f"Please check the log files in the {ctx.logs_path} for more details"
        )
        if not ctx.verbose:
            err_advice_message += ", or enable verbose mode"

        logger.error(err_advice_message)
    else:
        return 0

    return 1
+
+
def init_common_parser() -> argparse.ArgumentParser:
    """Init common parser.

    Holds options shared by every command (currently only --working-dir).
    """
    parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)

    default_working_dir = str(Path.cwd() / "mlia_output")
    parser.add_argument(
        "--working-dir",
        default=default_working_dir,
        help="Path to the directory where MLIA will store logs, "
        "models, etc. (default: %(default)s)",
    )

    return parser
+
+
def init_subcommand_parser(parent: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Init subcommand parser.

    Builds the top-level parser that subcommands are attached to. Help is
    added manually (add_help=False) so the shared parent's options and the
    custom help text coexist.
    """
    parser = argparse.ArgumentParser(
        description=INFO_MESSAGE,
        formatter_class=argparse.RawDescriptionHelpFormatter,
        parents=[parent],
        add_help=False,
        allow_abbrev=False,
    )

    parser.add_argument(
        "-h",
        "--help",
        action="help",
        default=argparse.SUPPRESS,
        help="Show this help message and exit",
    )
    parser.add_argument(
        "-v",
        "--version",
        action="version",
        version=f"%(prog)s {__version__}",
        help="Show program's version number and exit",
    )

    return parser
+
+
def add_default_command_if_needed(args: List[str]) -> None:
    """Add default command to the list of the arguments if needed.

    Mutates ``args`` in place: when the first argument is neither a known
    command/alias nor a help/version flag, the default command is inserted
    at the front.
    """
    default_command = get_default_command()
    if not default_command or not args:
        return

    known_names = get_possible_command_names()
    help_or_version = ["-h", "--help", "-v", "--version"]

    first_arg = args[0]
    if first_arg not in known_names and first_arg not in help_or_version:
        args.insert(0, default_command)
+
+
def main(argv: Optional[List[str]] = None) -> int:
    """Entry point of the application.

    :param argv: command line arguments (defaults to sys.argv[1:])
    :return: process exit code
    """
    common_parser = init_common_parser()
    parser = init_subcommand_parser(common_parser)
    init_commands(parser)

    # Split off the shared options first, then inject the default command
    # in front of the remaining arguments when none was given.
    common_args, remaining_args = common_parser.parse_known_args(argv)
    add_default_command_if_needed(remaining_args)

    args = parser.parse_args(remaining_args, common_args)
    return run_command(args)
+
+
# Allow running the CLI directly, e.g. "python -m mlia.cli.main".
if __name__ == "__main__":
    sys.exit(main())
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py
new file mode 100644
index 0000000..dc5cb73
--- /dev/null
+++ b/src/mlia/cli/options.py
@@ -0,0 +1,280 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for the CLI options."""
+import argparse
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+
+from mlia.cli.config import get_available_backends
+from mlia.cli.config import get_default_backends
+from mlia.cli.config import is_corstone_backend
+from mlia.utils.filesystem import get_supported_profile_names
+from mlia.utils.types import is_number
+
+
def add_target_options(parser: argparse.ArgumentParser) -> None:
    """Add target specific options.

    Registers the ``--target-profile`` option in a "target options" group.
    When at least one profile is available, the first one becomes the default
    and the help text advertises it.

    :param parser: parser to add the option to
    """
    target_profiles = get_supported_profile_names()

    # Only advertise a default when there is at least one profile to pick.
    default_target_profile = None
    default_help = ""
    if target_profiles:
        default_target_profile = target_profiles[0]
        default_help = " (default: %(default)s)"

    target_group = parser.add_argument_group("target options")
    target_group.add_argument(
        "--target-profile",
        choices=target_profiles,
        default=default_target_profile,
        # Fixed typo ("documenation") and a doubled space in the user-visible
        # help text.
        help="Target profile that will set the target options "
        "such as target, mac value, memory mode, etc. "
        "For the values associated with each target profile "
        f"please refer to the documentation{default_help}.",
    )
+
+
def add_multi_optimization_options(parser: argparse.ArgumentParser) -> None:
    """Add optimization specific options.

    Registers the comma separated ``--optimization-type`` and
    ``--optimization-target`` options in an "optimization options" group;
    both lists are interpreted pairwise.

    :param parser: parser to add the options to
    """
    group = parser.add_argument_group("optimization options")

    group.add_argument(
        "--optimization-type",
        default="pruning,clustering",
        help="List of the optimization types separated by comma (default: %(default)s)",
    )
    group.add_argument(
        "--optimization-target",
        default="0.5,32",
        help="""List of the optimization targets separated by comma,
        (for pruning this is sparsity between (0,1),
        for clustering this is the number of clusters (positive integer))
        (default: %(default)s)""",
    )
+
+
def add_optional_tflite_model_options(parser: argparse.ArgumentParser) -> None:
    """Add optional model specific options.

    Same group as :func:`add_tflite_model_options` but the positional model
    argument may be omitted (``nargs="?"``), in which case it parses to None.

    :param parser: parser to add the option to
    """
    group = parser.add_argument_group("TFLite model options")
    group.add_argument("model", nargs="?", help="TFLite model (optional)")
+
+
def add_tflite_model_options(parser: argparse.ArgumentParser) -> None:
    """Add model specific options.

    Registers the required positional ``model`` argument in a
    "TFLite model options" group.

    :param parser: parser to add the option to
    """
    group = parser.add_argument_group("TFLite model options")
    group.add_argument("model", help="TFLite model")
+
+
def add_output_options(parser: argparse.ArgumentParser) -> None:
    """Add output specific options.

    Registers ``--output``; the report format is derived from the file
    extension, which must be one of the supported ones (case-insensitive).

    :param parser: parser to add the option to
    """
    valid_extensions = ["csv", "json"]

    def check_extension(filename: str) -> str:
        """Fail argument parsing when the file extension is unsupported."""
        suffix = Path(filename).suffix
        # Drop the leading dot that pathlib keeps in the suffix.
        suffix = suffix[1:] if suffix.startswith(".") else suffix

        if suffix.lower() not in valid_extensions:
            parser.error(f"Unsupported format '{suffix}'")

        return filename

    output_group = parser.add_argument_group("output options")
    output_group.add_argument(
        "--output",
        type=check_extension,
        help=(
            "Name of the file where report will be saved. "
            "Report format is automatically detected based on the file extension. "
            f"Supported formats are: {', '.join(valid_extensions)}"
        ),
    )
+
+
def add_debug_options(parser: argparse.ArgumentParser) -> None:
    """Add debug options.

    Registers the ``--verbose`` flag (off by default) in a
    "debug options" group.

    :param parser: parser to add the option to
    """
    group = parser.add_argument_group("debug options")
    group.add_argument(
        "--verbose",
        action="store_true",
        default=False,
        help="Produce verbose output",
    )
+
+
def add_keras_model_options(parser: argparse.ArgumentParser) -> None:
    """Add model specific options.

    Registers the required positional ``model`` argument in a
    "Keras model options" group.

    :param parser: parser to add the option to
    """
    group = parser.add_argument_group("Keras model options")
    group.add_argument("model", help="Keras model")
+
+
def add_custom_supported_operators_options(parser: argparse.ArgumentParser) -> None:
    """Add custom options for the command 'operators'.

    Registers the ``--supported-ops-report`` flag which makes the command
    only generate the report file and exit.

    :param parser: parser to add the option to
    """
    parser.add_argument(
        "--supported-ops-report",
        action="store_true",
        default=False,
        help=(
            "Generate the SUPPORTED_OPS.md file in the "
            "current working directory and exit"
        ),
    )
+
+
def add_backend_options(parser: argparse.ArgumentParser) -> None:
    """Add options for the backends configuration.

    Registers the required ``install``/``status`` backend actions as
    subcommands (dest ``backend_action``).

    :param parser: parser to add the subcommands to
    """

    def valid_directory(param: str) -> Path:
        """Convert ``param`` to a Path, failing parsing if not a directory."""
        dir_path = Path(param)
        if not dir_path.is_dir():
            parser.error(f"Invalid directory path {param}")
        return dir_path

    subparsers = parser.add_subparsers(title="Backend actions", dest="backend_action")
    subparsers.required = True

    install_subparser = subparsers.add_parser(
        "install", help="Install backend", allow_abbrev=False
    )

    # Exactly one of --path/--download has to be supplied.
    install_type_group = install_subparser.add_mutually_exclusive_group()
    install_type_group.required = True
    install_type_group.add_argument(
        "--path", type=valid_directory, help="Path to the installed backend"
    )
    install_type_group.add_argument(
        "--download",
        action="store_true",
        default=False,
        help="Download and install backend",
    )

    # Hidden flag: lets automation accept the EULA without prompting.
    install_subparser.add_argument(
        "--i-agree-to-the-contained-eula",
        action="store_true",
        default=False,
        help=argparse.SUPPRESS,
    )
    install_subparser.add_argument(
        "--noninteractive",
        action="store_true",
        default=False,
        help="Non interactive mode with automatic confirmation of every action",
    )
    install_subparser.add_argument(
        "name",
        nargs="?",
        help="Name of the backend to install",
    )

    subparsers.add_parser("status", help="Show backends status")
+
+
def add_evaluation_options(parser: argparse.ArgumentParser) -> None:
    """Add evaluation options.

    Registers ``--evaluate-on`` in an "evaluation options" group; the value
    is restricted to the available backends, with at most one Corstone
    backend allowed in the list.

    :param parser: parser to add the option to
    """
    available_backends = get_available_backends()
    default_backends = get_default_backends()

    def only_one_corstone_checker() -> Callable:
        """
        Return a callable to check that only one Corstone backend is passed.

        Raises an exception when more than one Corstone backend is passed.
        """
        num_corstones = 0

        def check(backend: str) -> str:
            """Count Corstone backends, rejecting the second one seen."""
            nonlocal num_corstones
            if is_corstone_backend(backend):
                num_corstones += 1
            if num_corstones > 1:
                raise argparse.ArgumentTypeError(
                    "There must be only one Corstone backend in the argument list."
                )
            return backend

        return check

    evaluation_group = parser.add_argument_group("evaluation options")
    evaluation_group.add_argument(
        "--evaluate-on",
        help="Backends to use for evaluation (default: %(default)s)",
        nargs="*",
        choices=available_backends,
        default=default_backends,
        # type= is called per value, so the counter persists across the list.
        type=only_one_corstone_checker(),
    )
+ )
+
+
def parse_optimization_parameters(
    optimization_type: str,
    optimization_target: str,
    sep: str = ",",
    layers_to_optimize: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
    """Parse provided optimization parameters.

    Splits ``optimization_type``/``optimization_target`` on ``sep`` and pairs
    them up into one parameter dict per optimization.

    :param optimization_type: optimization types separated by ``sep``
    :param optimization_target: numeric targets separated by ``sep``
    :param sep: separator used for both lists
    :param layers_to_optimize: optional list of layers, copied into each entry
    :return: list of dicts with keys optimization_type/optimization_target/
        layers_to_optimize
    :raise Exception: when either value is missing, the list lengths differ,
        or a target is not numeric
    """
    if not optimization_type:
        raise Exception("Optimization type is not provided")
    if not optimization_target:
        raise Exception("Optimization target is not provided")

    opt_types = optimization_type.split(sep)
    opt_targets = optimization_target.split(sep)

    if len(opt_types) != len(opt_targets):
        raise Exception("Wrong number of optimization targets and types")

    if any(not is_number(target) for target in opt_targets):
        raise Exception("Non numeric value for the optimization target")

    return [
        {
            "optimization_type": opt_type.strip(),
            "optimization_target": float(opt_target),
            "layers_to_optimize": layers_to_optimize,
        }
        for opt_type, opt_target in zip(opt_types, opt_targets)
    ]
+
+
def get_target_profile_opts(device_args: Optional[Dict]) -> List[str]:
    """Get non default values passed as parameters for the target profile.

    Rebuilds the target-option parser to learn the defaults, then renders
    every entry of ``device_args`` that differs from its default back into
    command line tokens.

    :param device_args: mapping of argparse destinations to their values
    :return: flat list of command line tokens (possibly empty)
    """
    if not device_args:
        return []

    # Throwaway parser: provides the defaults and the dest -> option-string
    # mapping for the target options.
    dummy_parser = argparse.ArgumentParser()
    add_target_options(dummy_parser)
    defaults = dummy_parser.parse_args([])

    option_strings = {
        action.dest: opt_name
        for opt_name, action in dummy_parser._option_string_actions.items()  # pylint: disable=protected-access
    }

    def as_tokens(name: str, value: Any) -> List[str]:
        """Render one option as CLI tokens (list values repeat the flag)."""
        if isinstance(value, list):
            return [str(token) for item in value for token in (name, item)]
        return [name, str(value)]

    result: List[str] = []
    for arg_name, arg_value in device_args.items():
        # Namespace.__contains__ checks known destinations; only emit
        # values that differ from the parser defaults.
        if arg_name in defaults and getattr(defaults, arg_name) != arg_value:
            result.extend(as_tokens(option_strings[arg_name], arg_value))

    return result