about summary refs log tree commit diff
path: root/src/mlia/cli
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/cli')
-rw-r--r--  src/mlia/cli/commands.py   24
-rw-r--r--  src/mlia/cli/logging.py   109
-rw-r--r--  src/mlia/cli/main.py        9
-rw-r--r--  src/mlia/cli/options.py    35
4 files changed, 21 insertions, 156 deletions
diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py
index d2242ba..c17d571 100644
--- a/src/mlia/cli/commands.py
+++ b/src/mlia/cli/commands.py
@@ -7,10 +7,10 @@ functionality.
Before running them from scripts 'logging' module should
be configured. Function 'setup_logging' from module
-'mli.cli.logging' could be used for that, e.g.
+'mlia.core.logging' could be used for that, e.g.
>>> from mlia.api import ExecutionContext
->>> from mlia.cli.logging import setup_logging
+>>> from mlia.core.logging import setup_logging
>>> setup_logging(verbose=True)
>>> import mlia.cli.commands as mlia
>>> mlia.check(ExecutionContext(), "ethos-u55-256",
@@ -27,7 +27,6 @@ from mlia.cli.command_validators import validate_backend
from mlia.cli.command_validators import validate_check_target_profile
from mlia.cli.config import get_installation_manager
from mlia.cli.options import parse_optimization_parameters
-from mlia.cli.options import parse_output_parameters
from mlia.utils.console import create_section_header
logger = logging.getLogger(__name__)
@@ -41,8 +40,6 @@ def check(
model: str | None = None,
compatibility: bool = False,
performance: bool = False,
- output: Path | None = None,
- json: bool = False,
backend: list[str] | None = None,
) -> None:
"""Generate a full report on the input model.
@@ -61,7 +58,6 @@ def check(
:param model: path to the Keras model
:param compatibility: flag that identifies whether to run compatibility checks
:param performance: flag that identifies whether to run performance checks
- :param output: path to the file where the report will be saved
:param backend: list of the backends to use for evaluation
Example:
@@ -69,18 +65,15 @@ def check(
and operator compatibility.
>>> from mlia.api import ExecutionContext
- >>> from mlia.cli.logging import setup_logging
+ >>> from mlia.core.logging import setup_logging
>>> setup_logging()
>>> from mlia.cli.commands import check
>>> check(ExecutionContext(), "ethos-u55-256",
- "model.h5", compatibility=True, performance=True,
- output="report.json")
+ "model.h5", compatibility=True, performance=True)
"""
if not model:
raise Exception("Model is not provided")
- formatted_output = parse_output_parameters(output, json)
-
# Set category based on checks to perform (i.e. "compatibility" and/or
# "performance").
# If no check type is specified, "compatibility" is the default category.
@@ -98,7 +91,6 @@ def check(
target_profile,
model,
category,
- output=formatted_output,
context=ctx,
backends=validated_backend,
)
@@ -113,8 +105,6 @@ def optimize( # pylint: disable=too-many-arguments
pruning_target: float | None,
clustering_target: int | None,
layers_to_optimize: list[str] | None = None,
- output: Path | None = None,
- json: bool = False,
backend: list[str] | None = None,
) -> None:
"""Show the performance improvements (if any) after applying the optimizations.
@@ -133,15 +123,13 @@ def optimize( # pylint: disable=too-many-arguments
:param pruning_target: pruning optimization target
:param layers_to_optimize: list of the layers of the model which should be
optimized, if None then all layers are used
- :param output: path to the file where the report will be saved
- :param json: set the output format to json
:param backend: list of the backends to use for evaluation
Example:
Run command for the target profile ethos-u55-256 and
the provided TensorFlow Lite model and print report on the standard output
- >>> from mlia.cli.logging import setup_logging
+ >>> from mlia.core.logging import setup_logging
>>> from mlia.api import ExecutionContext
>>> setup_logging()
>>> from mlia.cli.commands import optimize
@@ -161,7 +149,6 @@ def optimize( # pylint: disable=too-many-arguments
)
)
- formatted_output = parse_output_parameters(output, json)
validated_backend = validate_backend(target_profile, backend)
get_advice(
@@ -169,7 +156,6 @@ def optimize( # pylint: disable=too-many-arguments
model,
{"optimization"},
optimization_targets=opt_params,
- output=formatted_output,
context=ctx,
backends=validated_backend,
)
diff --git a/src/mlia/cli/logging.py b/src/mlia/cli/logging.py
deleted file mode 100644
index 5c5c4b8..0000000
--- a/src/mlia/cli/logging.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""CLI logging configuration."""
-from __future__ import annotations
-
-import logging
-import sys
-from pathlib import Path
-from typing import Iterable
-
-from mlia.utils.logging import attach_handlers
-from mlia.utils.logging import create_log_handler
-from mlia.utils.logging import LogFilter
-
-
-_CONSOLE_DEBUG_FORMAT = "%(name)s - %(message)s"
-_FILE_DEBUG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-
-
-def setup_logging(
- logs_dir: str | Path | None = None,
- verbose: bool = False,
- log_filename: str = "mlia.log",
-) -> None:
- """Set up logging.
-
- MLIA uses module 'logging' when it needs to produce output.
-
- :param logs_dir: path to the directory where application will save logs with
- debug information. If the path is not provided then no log files will
- be created during execution
- :param verbose: enable extended logging for the tools loggers
- :param log_filename: name of the log file in the logs directory
- """
- mlia_logger, tensorflow_logger, py_warnings_logger = (
- logging.getLogger(logger_name)
- for logger_name in ["mlia", "tensorflow", "py.warnings"]
- )
-
- # enable debug output, actual message filtering depends on
- # the provided parameters and being done at the handlers level
- for logger in [mlia_logger, tensorflow_logger]:
- logger.setLevel(logging.DEBUG)
-
- mlia_handlers = _get_mlia_handlers(logs_dir, log_filename, verbose)
- attach_handlers(mlia_handlers, [mlia_logger])
-
- tools_handlers = _get_tools_handlers(logs_dir, log_filename, verbose)
- attach_handlers(tools_handlers, [tensorflow_logger, py_warnings_logger])
-
-
-def _get_mlia_handlers(
- logs_dir: str | Path | None,
- log_filename: str,
- verbose: bool,
-) -> Iterable[logging.Handler]:
- """Get handlers for the MLIA loggers."""
- yield create_log_handler(
- stream=sys.stdout,
- log_level=logging.INFO,
- )
-
- if verbose:
- mlia_verbose_handler = create_log_handler(
- stream=sys.stdout,
- log_level=logging.DEBUG,
- log_format=_CONSOLE_DEBUG_FORMAT,
- log_filter=LogFilter.equals(logging.DEBUG),
- )
- yield mlia_verbose_handler
-
- if logs_dir:
- yield create_log_handler(
- file_path=_get_log_file(logs_dir, log_filename),
- log_level=logging.DEBUG,
- log_format=_FILE_DEBUG_FORMAT,
- log_filter=LogFilter.skip(logging.INFO),
- delay=True,
- )
-
-
-def _get_tools_handlers(
- logs_dir: str | Path | None,
- log_filename: str,
- verbose: bool,
-) -> Iterable[logging.Handler]:
- """Get handler for the tools loggers."""
- if verbose:
- yield create_log_handler(
- stream=sys.stdout,
- log_level=logging.DEBUG,
- log_format=_CONSOLE_DEBUG_FORMAT,
- )
-
- if logs_dir:
- yield create_log_handler(
- file_path=_get_log_file(logs_dir, log_filename),
- log_level=logging.DEBUG,
- log_format=_FILE_DEBUG_FORMAT,
- delay=True,
- )
-
-
-def _get_log_file(logs_dir: str | Path, log_filename: str) -> Path:
- """Get the log file path."""
- logs_dir_path = Path(logs_dir)
- logs_dir_path.mkdir(exist_ok=True)
-
- return logs_dir_path / log_filename
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py
index 1102d45..7ce7dc9 100644
--- a/src/mlia/cli/main.py
+++ b/src/mlia/cli/main.py
@@ -19,7 +19,6 @@ from mlia.cli.commands import check
from mlia.cli.commands import optimize
from mlia.cli.common import CommandInfo
from mlia.cli.helpers import CLIActionResolver
-from mlia.cli.logging import setup_logging
from mlia.cli.options import add_backend_install_options
from mlia.cli.options import add_backend_options
from mlia.cli.options import add_backend_uninstall_options
@@ -30,9 +29,11 @@ from mlia.cli.options import add_model_options
from mlia.cli.options import add_multi_optimization_options
from mlia.cli.options import add_output_options
from mlia.cli.options import add_target_options
+from mlia.cli.options import get_output_format
from mlia.core.context import ExecutionContext
from mlia.core.errors import ConfigurationError
from mlia.core.errors import InternalError
+from mlia.core.logging import setup_logging
from mlia.target.registry import registry as target_registry
@@ -162,10 +163,13 @@ def setup_context(
ctx = ExecutionContext(
verbose="debug" in args and args.debug,
action_resolver=CLIActionResolver(vars(args)),
+ output_format=get_output_format(args),
)
+ setup_logging(ctx.logs_path, ctx.verbose, ctx.output_format)
+
# these parameters should not be passed into command function
- skipped_params = ["func", "command", "debug"]
+ skipped_params = ["func", "command", "debug", "json"]
# pass these parameters only if command expects them
expected_params = [context_var_name]
@@ -186,7 +190,6 @@ def setup_context(
def run_command(args: argparse.Namespace) -> int:
"""Run command."""
ctx, func_args = setup_context(args)
- setup_logging(ctx.logs_path, ctx.verbose)
logger.debug(
"*** This is the beginning of the command '%s' execution ***", args.command
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py
index 5aca3b3..e01f107 100644
--- a/src/mlia/cli/options.py
+++ b/src/mlia/cli/options.py
@@ -12,7 +12,7 @@ from mlia.cli.config import DEFAULT_CLUSTERING_TARGET
from mlia.cli.config import DEFAULT_PRUNING_TARGET
from mlia.cli.config import get_available_backends
from mlia.cli.config import is_corstone_backend
-from mlia.core.common import FormattedFilePath
+from mlia.core.typing import OutputFormat
from mlia.utils.filesystem import get_supported_profile_names
@@ -89,16 +89,9 @@ def add_output_options(parser: argparse.ArgumentParser) -> None:
"""Add output specific options."""
output_group = parser.add_argument_group("output options")
output_group.add_argument(
- "-o",
- "--output",
- type=Path,
- help=("Name of the file where the report will be saved."),
- )
-
- output_group.add_argument(
"--json",
action="store_true",
- help=("Format to use for the output (requires --output argument to be set)."),
+ help=("Print the output in JSON format."),
)
@@ -209,22 +202,6 @@ def add_backend_options(
)
-def parse_output_parameters(path: Path | None, json: bool) -> FormattedFilePath | None:
- """Parse and return path and file format as FormattedFilePath."""
- if not path and json:
- raise argparse.ArgumentError(
- None,
- "To enable JSON output you need to specify the output path. "
- "(e.g. --output out.json --json)",
- )
- if not path:
- return None
- if json:
- return FormattedFilePath(path, "json")
-
- return FormattedFilePath(path, "plain_text")
-
-
def parse_optimization_parameters(
pruning: bool = False,
clustering: bool = False,
@@ -301,3 +278,11 @@ def get_target_profile_opts(device_args: dict | None) -> list[str]:
for name in non_default
for item in construct_param(params_name[name], device_args[name])
]
+
+
+def get_output_format(args: argparse.Namespace) -> OutputFormat:
+ """Return the OutputFormat depending on the CLI flags."""
+ output_format: OutputFormat = "plain_text"
+ if "json" in args and args.json:
+ output_format = "json"
+ return output_format