aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDiego Russo <diego.russo@arm.com>2023-02-02 22:04:05 +0000
committerDiego Russo <diego.russo@arm.com>2023-02-03 17:48:11 +0000
commitf1eaff3c9790464bed3183ff76555cf815166f50 (patch)
treebba288bb051925a692e1e1f998f7cc48df755df6 /src
parentd1b2374cda6811a93d1174400fc2eecd7100a8c3 (diff)
downloadmlia-f1eaff3c9790464bed3183ff76555cf815166f50.tar.gz
MLIA-782 Remove --output parameter
* Remove --output parameter from argument parser * Remove FormattedFilePath class and its presence across the codebase * Move logging module from cli to core * The output format is now injected in the execution context and used across MLIA * Depending on the output format, TextReporter and JSONReporter have been created and used accordingly. * The whole output to standard output and/or logfile is driven via the logging module: the only case where the print is used is when the --json parameter is specified. This is needed becase all output (including third party application as well) needs to be disabled otherwise it might corrupt the json output in the standard output. * Debug information is logged into the log file and printed to stdout when the output format is plain_text. * Update E2E test and config to cope with the new mechanism of outputting json data to standard output. Change-Id: I4395800b0b1af4d24406a828d780bdeef98cd413
Diffstat (limited to 'src')
-rw-r--r--src/mlia/api.py13
-rw-r--r--src/mlia/backend/tosa_checker/compat.py6
-rw-r--r--src/mlia/cli/commands.py24
-rw-r--r--src/mlia/cli/main.py9
-rw-r--r--src/mlia/cli/options.py35
-rw-r--r--src/mlia/core/common.py33
-rw-r--r--src/mlia/core/context.py16
-rw-r--r--src/mlia/core/handlers.py31
-rw-r--r--src/mlia/core/logging.py (renamed from src/mlia/cli/logging.py)61
-rw-r--r--src/mlia/core/reporting.py166
-rw-r--r--src/mlia/core/typing.py7
-rw-r--r--src/mlia/core/workflow.py3
-rw-r--r--src/mlia/target/cortex_a/advisor.py4
-rw-r--r--src/mlia/target/cortex_a/handlers.py5
-rw-r--r--src/mlia/target/ethos_u/advisor.py4
-rw-r--r--src/mlia/target/ethos_u/handlers.py5
-rw-r--r--src/mlia/target/tosa/advisor.py4
-rw-r--r--src/mlia/target/tosa/handlers.py5
-rw-r--r--src/mlia/utils/logging.py17
19 files changed, 196 insertions, 252 deletions
diff --git a/src/mlia/api.py b/src/mlia/api.py
index 2cabf37..8105276 100644
--- a/src/mlia/api.py
+++ b/src/mlia/api.py
@@ -9,7 +9,6 @@ from typing import Any
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
-from mlia.core.common import FormattedFilePath
from mlia.core.context import ExecutionContext
from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor
from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor
@@ -24,7 +23,6 @@ def get_advice(
model: str | Path,
category: set[str],
optimization_targets: list[dict[str, Any]] | None = None,
- output: FormattedFilePath | None = None,
context: ExecutionContext | None = None,
backends: list[str] | None = None,
) -> None:
@@ -42,8 +40,6 @@ def get_advice(
category "compatibility" is used by default.
:param optimization_targets: optional model optimization targets that
could be used for generating advice in "optimization" category.
- :param output: path to the report file. If provided, MLIA will save
- report in this location.
:param context: optional parameter which represents execution context,
could be used for advanced use cases
:param backends: A list of backends that should be used for the given
@@ -57,11 +53,9 @@ def get_advice(
>>> get_advice("ethos-u55-256", "path/to/the/model",
{"optimization", "compatibility"})
- Getting the advice for the category "performance" and save result report in file
- "report.json"
+ Getting the advice for the category "performance".
- >>> get_advice("ethos-u55-256", "path/to/the/model", {"performance"},
- output=FormattedFilePath("report.json")
+ >>> get_advice("ethos-u55-256", "path/to/the/model", {"performance"})
"""
advice_category = AdviceCategory.from_string(category)
@@ -76,7 +70,6 @@ def get_advice(
context,
target_profile,
model,
- output,
optimization_targets=optimization_targets,
backends=backends,
)
@@ -88,7 +81,6 @@ def get_advisor(
context: ExecutionContext,
target_profile: str,
model: str | Path,
- output: FormattedFilePath | None = None,
**extra_args: Any,
) -> InferenceAdvisor:
"""Find appropriate advisor for the target."""
@@ -109,6 +101,5 @@ def get_advisor(
context,
target_profile,
model,
- output,
**extra_args,
)
diff --git a/src/mlia/backend/tosa_checker/compat.py b/src/mlia/backend/tosa_checker/compat.py
index 81f3015..1c410d3 100644
--- a/src/mlia/backend/tosa_checker/compat.py
+++ b/src/mlia/backend/tosa_checker/compat.py
@@ -5,12 +5,12 @@ from __future__ import annotations
import sys
from dataclasses import dataclass
+from pathlib import Path
from typing import Any
from typing import cast
from typing import Protocol
from mlia.backend.errors import BackendUnavailableError
-from mlia.core.typing import PathOrFileLike
from mlia.utils.logging import capture_raw_output
@@ -45,7 +45,7 @@ class TOSACompatibilityInfo:
def get_tosa_compatibility_info(
- tflite_model_path: PathOrFileLike,
+ tflite_model_path: str | Path,
) -> TOSACompatibilityInfo:
"""Return list of the operators."""
# Capture the possible exception in running get_tosa_checker
@@ -100,7 +100,7 @@ def get_tosa_compatibility_info(
)
-def get_tosa_checker(tflite_model_path: PathOrFileLike) -> TOSAChecker | None:
+def get_tosa_checker(tflite_model_path: str | Path) -> TOSAChecker | None:
"""Return instance of the TOSA checker."""
try:
import tosa_checker as tc # pylint: disable=import-outside-toplevel
diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py
index d2242ba..c17d571 100644
--- a/src/mlia/cli/commands.py
+++ b/src/mlia/cli/commands.py
@@ -7,10 +7,10 @@ functionality.
Before running them from scripts 'logging' module should
be configured. Function 'setup_logging' from module
-'mli.cli.logging' could be used for that, e.g.
+'mli.core.logging' could be used for that, e.g.
>>> from mlia.api import ExecutionContext
->>> from mlia.cli.logging import setup_logging
+>>> from mlia.core.logging import setup_logging
>>> setup_logging(verbose=True)
>>> import mlia.cli.commands as mlia
>>> mlia.check(ExecutionContext(), "ethos-u55-256",
@@ -27,7 +27,6 @@ from mlia.cli.command_validators import validate_backend
from mlia.cli.command_validators import validate_check_target_profile
from mlia.cli.config import get_installation_manager
from mlia.cli.options import parse_optimization_parameters
-from mlia.cli.options import parse_output_parameters
from mlia.utils.console import create_section_header
logger = logging.getLogger(__name__)
@@ -41,8 +40,6 @@ def check(
model: str | None = None,
compatibility: bool = False,
performance: bool = False,
- output: Path | None = None,
- json: bool = False,
backend: list[str] | None = None,
) -> None:
"""Generate a full report on the input model.
@@ -61,7 +58,6 @@ def check(
:param model: path to the Keras model
:param compatibility: flag that identifies whether to run compatibility checks
:param performance: flag that identifies whether to run performance checks
- :param output: path to the file where the report will be saved
:param backend: list of the backends to use for evaluation
Example:
@@ -69,18 +65,15 @@ def check(
and operator compatibility.
>>> from mlia.api import ExecutionContext
- >>> from mlia.cli.logging import setup_logging
+ >>> from mlia.core.logging import setup_logging
>>> setup_logging()
>>> from mlia.cli.commands import check
>>> check(ExecutionContext(), "ethos-u55-256",
- "model.h5", compatibility=True, performance=True,
- output="report.json")
+ "model.h5", compatibility=True, performance=True)
"""
if not model:
raise Exception("Model is not provided")
- formatted_output = parse_output_parameters(output, json)
-
# Set category based on checks to perform (i.e. "compatibility" and/or
# "performance").
# If no check type is specified, "compatibility" is the default category.
@@ -98,7 +91,6 @@ def check(
target_profile,
model,
category,
- output=formatted_output,
context=ctx,
backends=validated_backend,
)
@@ -113,8 +105,6 @@ def optimize( # pylint: disable=too-many-arguments
pruning_target: float | None,
clustering_target: int | None,
layers_to_optimize: list[str] | None = None,
- output: Path | None = None,
- json: bool = False,
backend: list[str] | None = None,
) -> None:
"""Show the performance improvements (if any) after applying the optimizations.
@@ -133,15 +123,13 @@ def optimize( # pylint: disable=too-many-arguments
:param pruning_target: pruning optimization target
:param layers_to_optimize: list of the layers of the model which should be
optimized, if None then all layers are used
- :param output: path to the file where the report will be saved
- :param json: set the output format to json
:param backend: list of the backends to use for evaluation
Example:
Run command for the target profile ethos-u55-256 and
the provided TensorFlow Lite model and print report on the standard output
- >>> from mlia.cli.logging import setup_logging
+ >>> from mlia.core.logging import setup_logging
>>> from mlia.api import ExecutionContext
>>> setup_logging()
>>> from mlia.cli.commands import optimize
@@ -161,7 +149,6 @@ def optimize( # pylint: disable=too-many-arguments
)
)
- formatted_output = parse_output_parameters(output, json)
validated_backend = validate_backend(target_profile, backend)
get_advice(
@@ -169,7 +156,6 @@ def optimize( # pylint: disable=too-many-arguments
model,
{"optimization"},
optimization_targets=opt_params,
- output=formatted_output,
context=ctx,
backends=validated_backend,
)
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py
index 1102d45..7ce7dc9 100644
--- a/src/mlia/cli/main.py
+++ b/src/mlia/cli/main.py
@@ -19,7 +19,6 @@ from mlia.cli.commands import check
from mlia.cli.commands import optimize
from mlia.cli.common import CommandInfo
from mlia.cli.helpers import CLIActionResolver
-from mlia.cli.logging import setup_logging
from mlia.cli.options import add_backend_install_options
from mlia.cli.options import add_backend_options
from mlia.cli.options import add_backend_uninstall_options
@@ -30,9 +29,11 @@ from mlia.cli.options import add_model_options
from mlia.cli.options import add_multi_optimization_options
from mlia.cli.options import add_output_options
from mlia.cli.options import add_target_options
+from mlia.cli.options import get_output_format
from mlia.core.context import ExecutionContext
from mlia.core.errors import ConfigurationError
from mlia.core.errors import InternalError
+from mlia.core.logging import setup_logging
from mlia.target.registry import registry as target_registry
@@ -162,10 +163,13 @@ def setup_context(
ctx = ExecutionContext(
verbose="debug" in args and args.debug,
action_resolver=CLIActionResolver(vars(args)),
+ output_format=get_output_format(args),
)
+ setup_logging(ctx.logs_path, ctx.verbose, ctx.output_format)
+
# these parameters should not be passed into command function
- skipped_params = ["func", "command", "debug"]
+ skipped_params = ["func", "command", "debug", "json"]
# pass these parameters only if command expects them
expected_params = [context_var_name]
@@ -186,7 +190,6 @@ def setup_context(
def run_command(args: argparse.Namespace) -> int:
"""Run command."""
ctx, func_args = setup_context(args)
- setup_logging(ctx.logs_path, ctx.verbose)
logger.debug(
"*** This is the beginning of the command '%s' execution ***", args.command
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py
index 5aca3b3..e01f107 100644
--- a/src/mlia/cli/options.py
+++ b/src/mlia/cli/options.py
@@ -12,7 +12,7 @@ from mlia.cli.config import DEFAULT_CLUSTERING_TARGET
from mlia.cli.config import DEFAULT_PRUNING_TARGET
from mlia.cli.config import get_available_backends
from mlia.cli.config import is_corstone_backend
-from mlia.core.common import FormattedFilePath
+from mlia.core.typing import OutputFormat
from mlia.utils.filesystem import get_supported_profile_names
@@ -89,16 +89,9 @@ def add_output_options(parser: argparse.ArgumentParser) -> None:
"""Add output specific options."""
output_group = parser.add_argument_group("output options")
output_group.add_argument(
- "-o",
- "--output",
- type=Path,
- help=("Name of the file where the report will be saved."),
- )
-
- output_group.add_argument(
"--json",
action="store_true",
- help=("Format to use for the output (requires --output argument to be set)."),
+ help=("Print the output in JSON format."),
)
@@ -209,22 +202,6 @@ def add_backend_options(
)
-def parse_output_parameters(path: Path | None, json: bool) -> FormattedFilePath | None:
- """Parse and return path and file format as FormattedFilePath."""
- if not path and json:
- raise argparse.ArgumentError(
- None,
- "To enable JSON output you need to specify the output path. "
- "(e.g. --output out.json --json)",
- )
- if not path:
- return None
- if json:
- return FormattedFilePath(path, "json")
-
- return FormattedFilePath(path, "plain_text")
-
-
def parse_optimization_parameters(
pruning: bool = False,
clustering: bool = False,
@@ -301,3 +278,11 @@ def get_target_profile_opts(device_args: dict | None) -> list[str]:
for name in non_default
for item in construct_param(params_name[name], device_args[name])
]
+
+
+def get_output_format(args: argparse.Namespace) -> OutputFormat:
+ """Return the OutputFormat depending on the CLI flags."""
+ output_format: OutputFormat = "plain_text"
+ if "json" in args and args.json:
+ output_format = "json"
+ return output_format
diff --git a/src/mlia/core/common.py b/src/mlia/core/common.py
index 53df001..baaed50 100644
--- a/src/mlia/core/common.py
+++ b/src/mlia/core/common.py
@@ -13,9 +13,6 @@ from enum import auto
from enum import Flag
from typing import Any
-from mlia.core.typing import OutputFormat
-from mlia.core.typing import PathOrFileLike
-
# This type is used as type alias for the items which are being passed around
# in advisor workflow. There are no restrictions on the type of the
# object. This alias used only to emphasize the nature of the input/output
@@ -23,36 +20,6 @@ from mlia.core.typing import PathOrFileLike
DataItem = Any
-class FormattedFilePath:
- """Class used to keep track of the format that a path points to."""
-
- def __init__(self, path: PathOrFileLike, fmt: OutputFormat = "plain_text") -> None:
- """Init FormattedFilePath."""
- self._path = path
- self._fmt = fmt
-
- @property
- def fmt(self) -> OutputFormat:
- """Return file format."""
- return self._fmt
-
- @property
- def path(self) -> PathOrFileLike:
- """Return file path."""
- return self._path
-
- def __eq__(self, other: object) -> bool:
- """Check for equality with other objects."""
- if isinstance(other, FormattedFilePath):
- return other.fmt == self.fmt and other.path == self.path
-
- return False
-
- def __repr__(self) -> str:
- """Represent object."""
- return f"FormattedFilePath {self.path=}, {self.fmt=}"
-
-
class AdviceCategory(Flag):
"""Advice category.
diff --git a/src/mlia/core/context.py b/src/mlia/core/context.py
index 94aa885..f8442a3 100644
--- a/src/mlia/core/context.py
+++ b/src/mlia/core/context.py
@@ -23,6 +23,7 @@ from mlia.core.events import EventHandler
from mlia.core.events import EventPublisher
from mlia.core.helpers import ActionResolver
from mlia.core.helpers import APIActionResolver
+from mlia.core.typing import OutputFormat
logger = logging.getLogger(__name__)
@@ -68,6 +69,11 @@ class Context(ABC):
def action_resolver(self) -> ActionResolver:
"""Return action resolver."""
+ @property
+ @abstractmethod
+ def output_format(self) -> OutputFormat:
+ """Return the output format."""
+
@abstractmethod
def update(
self,
@@ -106,6 +112,7 @@ class ExecutionContext(Context):
logs_dir: str = "logs",
models_dir: str = "models",
action_resolver: ActionResolver | None = None,
+ output_format: OutputFormat = "plain_text",
) -> None:
"""Init execution context.
@@ -139,6 +146,7 @@ class ExecutionContext(Context):
self.logs_dir = logs_dir
self.models_dir = models_dir
self._action_resolver = action_resolver or APIActionResolver()
+ self._output_format = output_format
@property
def working_dir(self) -> Path:
@@ -197,6 +205,11 @@ class ExecutionContext(Context):
"""Return path to the logs directory."""
return self._working_dir_path / self.logs_dir
+ @property
+ def output_format(self) -> OutputFormat:
+ """Return the output format."""
+ return self._output_format
+
def update(
self,
*,
@@ -221,7 +234,8 @@ class ExecutionContext(Context):
f"ExecutionContext: working_dir={self._working_dir_path}, "
f"advice_category={category}, "
f"config_parameters={self.config_parameters}, "
- f"verbose={self.verbose}"
+ f"verbose={self.verbose}, "
+ f"output_format={self.output_format}"
)
diff --git a/src/mlia/core/handlers.py b/src/mlia/core/handlers.py
index 6e50934..24d4881 100644
--- a/src/mlia/core/handlers.py
+++ b/src/mlia/core/handlers.py
@@ -9,7 +9,6 @@ from typing import Callable
from mlia.core.advice_generation import Advice
from mlia.core.advice_generation import AdviceEvent
-from mlia.core.common import FormattedFilePath
from mlia.core.events import ActionFinishedEvent
from mlia.core.events import ActionStartedEvent
from mlia.core.events import AdviceStageFinishedEvent
@@ -25,9 +24,11 @@ from mlia.core.events import EventDispatcher
from mlia.core.events import ExecutionFailedEvent
from mlia.core.events import ExecutionFinishedEvent
from mlia.core.events import ExecutionStartedEvent
+from mlia.core.mixins import ContextMixin
+from mlia.core.reporting import JSONReporter
from mlia.core.reporting import Report
from mlia.core.reporting import Reporter
-from mlia.core.typing import PathOrFileLike
+from mlia.core.reporting import TextReporter
from mlia.utils.console import create_section_header
@@ -92,26 +93,27 @@ _ADV_EXECUTION_STARTED = create_section_header("ML Inference Advisor started")
_MODEL_ANALYSIS_MSG = create_section_header("Model Analysis")
_MODEL_ANALYSIS_RESULTS_MSG = create_section_header("Model Analysis Results")
_ADV_GENERATION_MSG = create_section_header("Advice Generation")
-_REPORT_GENERATION_MSG = create_section_header("Report Generation")
-class WorkflowEventsHandler(SystemEventsHandler):
+class WorkflowEventsHandler(SystemEventsHandler, ContextMixin):
"""Event handler for the system events."""
+ reporter: Reporter
+
def __init__(
self,
formatter_resolver: Callable[[Any], Callable[[Any], Report]],
- output: FormattedFilePath | None = None,
) -> None:
"""Init event handler."""
- output_format = output.fmt if output else "plain_text"
- self.reporter = Reporter(formatter_resolver, output_format)
- self.output = output.path if output else None
-
+ self.formatter_resolver = formatter_resolver
self.advice: list[Advice] = []
def on_execution_started(self, event: ExecutionStartedEvent) -> None:
"""Handle ExecutionStarted event."""
+ if self.context.output_format == "json":
+ self.reporter = JSONReporter(self.formatter_resolver)
+ else:
+ self.reporter = TextReporter(self.formatter_resolver)
logger.info(_ADV_EXECUTION_STARTED)
def on_execution_failed(self, event: ExecutionFailedEvent) -> None:
@@ -132,12 +134,6 @@ class WorkflowEventsHandler(SystemEventsHandler):
"""Handle DataCollectorSkipped event."""
logger.info("Skipped: %s", event.reason)
- @staticmethod
- def report_generated(output: PathOrFileLike) -> None:
- """Log report generation."""
- logger.info(_REPORT_GENERATION_MSG)
- logger.info("Report(s) and advice list saved to: %s", output)
-
def on_data_analysis_stage_finished(
self, event: DataAnalysisStageFinishedEvent
) -> None:
@@ -160,7 +156,4 @@ class WorkflowEventsHandler(SystemEventsHandler):
table_style="no_borders",
)
- self.reporter.generate_report(self.output)
-
- if self.output is not None:
- self.report_generated(self.output)
+ self.reporter.generate_report()
diff --git a/src/mlia/cli/logging.py b/src/mlia/core/logging.py
index 5c5c4b8..686f8ab 100644
--- a/src/mlia/cli/logging.py
+++ b/src/mlia/core/logging.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""CLI logging configuration."""
from __future__ import annotations
@@ -8,18 +8,20 @@ import sys
from pathlib import Path
from typing import Iterable
+from mlia.core.typing import OutputFormat
from mlia.utils.logging import attach_handlers
from mlia.utils.logging import create_log_handler
-from mlia.utils.logging import LogFilter
+from mlia.utils.logging import NoASCIIFormatter
-_CONSOLE_DEBUG_FORMAT = "%(name)s - %(message)s"
+_CONSOLE_DEBUG_FORMAT = "%(name)s - %(levelname)s - %(message)s"
_FILE_DEBUG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
def setup_logging(
logs_dir: str | Path | None = None,
verbose: bool = False,
+ output_format: OutputFormat = "plain_text",
log_filename: str = "mlia.log",
) -> None:
"""Set up logging.
@@ -30,6 +32,8 @@ def setup_logging(
debug information. If the path is not provided then no log files will
be created during execution
:param verbose: enable extended logging for the tools loggers
+ :param output_format: specify the out format needed for setting up the right
+ logging system
:param log_filename: name of the log file in the logs directory
"""
mlia_logger, tensorflow_logger, py_warnings_logger = (
@@ -42,7 +46,7 @@ def setup_logging(
for logger in [mlia_logger, tensorflow_logger]:
logger.setLevel(logging.DEBUG)
- mlia_handlers = _get_mlia_handlers(logs_dir, log_filename, verbose)
+ mlia_handlers = _get_mlia_handlers(logs_dir, log_filename, verbose, output_format)
attach_handlers(mlia_handlers, [mlia_logger])
tools_handlers = _get_tools_handlers(logs_dir, log_filename, verbose)
@@ -53,28 +57,47 @@ def _get_mlia_handlers(
logs_dir: str | Path | None,
log_filename: str,
verbose: bool,
+ output_format: OutputFormat,
) -> Iterable[logging.Handler]:
"""Get handlers for the MLIA loggers."""
- yield create_log_handler(
- stream=sys.stdout,
- log_level=logging.INFO,
- )
-
- if verbose:
- mlia_verbose_handler = create_log_handler(
- stream=sys.stdout,
- log_level=logging.DEBUG,
- log_format=_CONSOLE_DEBUG_FORMAT,
- log_filter=LogFilter.equals(logging.DEBUG),
+ # MLIA needs output to standard output via the logging system only when the
+ # format is plain text. When the user specifies the "json" output format,
+ # MLIA disables completely the logging system for the console output and it
+ # relies on the print() function. This is needed because the output might
+ # be corrupted with spurious messages in the standard output.
+ if output_format == "plain_text":
+ if verbose:
+ log_level = logging.DEBUG
+ log_format = _CONSOLE_DEBUG_FORMAT
+ else:
+ log_level = logging.INFO
+ log_format = None
+
+ # Create log handler for stdout
+ yield create_log_handler(
+ stream=sys.stdout, log_level=log_level, log_format=log_format
+ )
+ else:
+ # In case of non plain text output, we need to inform the user if an
+ # error happens during execution.
+ yield create_log_handler(
+ stream=sys.stderr,
+ log_level=logging.ERROR,
)
- yield mlia_verbose_handler
+ # If the logs directory is specified, MLIA stores all output (according to
+ # the logging level) into the file and removing the colouring of the
+ # console output.
if logs_dir:
+ if verbose:
+ log_level = logging.DEBUG
+ else:
+ log_level = logging.INFO
+
yield create_log_handler(
file_path=_get_log_file(logs_dir, log_filename),
- log_level=logging.DEBUG,
- log_format=_FILE_DEBUG_FORMAT,
- log_filter=LogFilter.skip(logging.INFO),
+ log_level=log_level,
+ log_format=NoASCIIFormatter(fmt=_FILE_DEBUG_FORMAT),
delay=True,
)
diff --git a/src/mlia/core/reporting.py b/src/mlia/core/reporting.py
index ad63d62..7b9ce5c 100644
--- a/src/mlia/core/reporting.py
+++ b/src/mlia/core/reporting.py
@@ -8,30 +8,21 @@ import logging
from abc import ABC
from abc import abstractmethod
from collections import defaultdict
-from contextlib import contextmanager
-from contextlib import ExitStack
from dataclasses import dataclass
from enum import Enum
from functools import partial
-from io import TextIOWrapper
-from pathlib import Path
from textwrap import fill
from textwrap import indent
from typing import Any
from typing import Callable
-from typing import cast
from typing import Collection
-from typing import Generator
from typing import Iterable
import numpy as np
-from mlia.core.typing import FileLike
from mlia.core.typing import OutputFormat
-from mlia.core.typing import PathOrFileLike
from mlia.utils.console import apply_style
from mlia.utils.console import produce_table
-from mlia.utils.logging import LoggerWriter
from mlia.utils.types import is_list_of
logger = logging.getLogger(__name__)
@@ -505,76 +496,48 @@ class CustomJSONEncoder(json.JSONEncoder):
return json.JSONEncoder.default(self, o)
-def json_reporter(report: Report, output: FileLike, **kwargs: Any) -> None:
- """Produce report in json format."""
- json_str = json.dumps(report.to_json(**kwargs), indent=4, cls=CustomJSONEncoder)
- print(json_str, file=output)
-
-
-def text_reporter(report: Report, output: FileLike, **kwargs: Any) -> None:
- """Produce report in text format."""
- print(report.to_plain_text(**kwargs), file=output)
+class Reporter(ABC):
+ """Reporter class."""
+ def __init__(
+ self,
+ formatter_resolver: Callable[[Any], Callable[[Any], Report]],
+ ) -> None:
+ """Init reporter instance."""
+ self.formatter_resolver = formatter_resolver
+ self.data: list[tuple[Any, Callable[[Any], Report]]] = []
-def produce_report(
- data: Any,
- formatter: Callable[[Any], Report],
- fmt: OutputFormat = "plain_text",
- output: PathOrFileLike | None = None,
- **kwargs: Any,
-) -> None:
- """Produce report based on provided data."""
- # check if provided format value is supported
- formats = {"json": json_reporter, "plain_text": text_reporter}
- if fmt not in formats:
- raise Exception(f"Unknown format {fmt}")
+ @abstractmethod
+ def submit(self, data_item: Any, **kwargs: Any) -> None:
+ """Submit data for the report."""
- if output is None:
- output = cast(TextIOWrapper, LoggerWriter(logger, logging.INFO))
+ def print_delayed(self) -> None:
+ """Print delayed reports."""
- with ExitStack() as exit_stack:
- if isinstance(output, (str, Path)):
- # open file and add it to the ExitStack context manager
- # in that case it will be automatically closed
- stream = exit_stack.enter_context(open(output, "w", encoding="utf-8"))
- else:
- stream = cast(TextIOWrapper, output)
+ def generate_report(self) -> None:
+ """Generate report."""
- # convert data into serializable form
- formatted_data = formatter(data)
- # find handler for the format
- format_handler = formats[fmt]
- # produce report in requested format
- format_handler(formatted_data, stream, **kwargs)
+ @abstractmethod
+ def produce_report(
+ self, data: Any, formatter: Callable[[Any], Report], **kwargs: Any
+ ) -> None:
+ """Produce report based on provided data."""
-class Reporter:
+class TextReporter(Reporter):
"""Reporter class."""
def __init__(
self,
formatter_resolver: Callable[[Any], Callable[[Any], Report]],
- output_format: OutputFormat = "plain_text",
- print_as_submitted: bool = True,
) -> None:
"""Init reporter instance."""
- self.formatter_resolver = formatter_resolver
- self.output_format = output_format
- self.print_as_submitted = print_as_submitted
-
- self.data: list[tuple[Any, Callable[[Any], Report]]] = []
+ super().__init__(formatter_resolver)
self.delayed: list[tuple[Any, Callable[[Any], Report]]] = []
+ self.output_format: OutputFormat = "plain_text"
def submit(self, data_item: Any, delay_print: bool = False, **kwargs: Any) -> None:
"""Submit data for the report."""
- if self.print_as_submitted and not delay_print:
- produce_report(
- data_item,
- self.formatter_resolver(data_item),
- fmt="plain_text",
- **kwargs,
- )
-
formatter = _apply_format_parameters(
self.formatter_resolver(data_item), self.output_format, **kwargs
)
@@ -582,51 +545,70 @@ class Reporter:
if delay_print:
self.delayed.append((data_item, formatter))
+ else:
+ self.produce_report(
+ data_item,
+ self.formatter_resolver(data_item),
+ **kwargs,
+ )
def print_delayed(self) -> None:
"""Print delayed reports."""
- if not self.delayed:
- return
+ if self.delayed:
+ data, formatters = zip(*self.delayed)
+ self.produce_report(
+ data,
+ formatter=CompoundFormatter(formatters),
+ )
+ self.delayed = []
- data, formatters = zip(*self.delayed)
- produce_report(
- data,
- formatter=CompoundFormatter(formatters),
- fmt="plain_text",
+ def produce_report(
+ self, data: Any, formatter: Callable[[Any], Report], **kwargs: Any
+ ) -> None:
+ """Produce report based on provided data."""
+ formatted_data = formatter(data)
+ logger.info(formatted_data.to_plain_text(**kwargs))
+
+
+class JSONReporter(Reporter):
+ """Reporter class."""
+
+ def __init__(
+ self,
+ formatter_resolver: Callable[[Any], Callable[[Any], Report]],
+ ) -> None:
+ """Init reporter instance."""
+ super().__init__(formatter_resolver)
+ self.output_format: OutputFormat = "json"
+
+ def submit(self, data_item: Any, **kwargs: Any) -> None:
+ """Submit data for the report."""
+ formatter = _apply_format_parameters(
+ self.formatter_resolver(data_item), self.output_format, **kwargs
)
- self.delayed = []
+ self.data.append((data_item, formatter))
- def generate_report(self, output: PathOrFileLike | None) -> None:
+ def generate_report(self) -> None:
"""Generate report."""
- already_printed = (
- self.print_as_submitted
- and self.output_format == "plain_text"
- and output is None
- )
- if not self.data or already_printed:
+ if not self.data:
return
data, formatters = zip(*self.data)
- produce_report(
+ self.produce_report(
data,
formatter=CompoundFormatter(formatters),
- fmt=self.output_format,
- output=output,
)
-
-@contextmanager
-def get_reporter(
- output_format: OutputFormat,
- output: PathOrFileLike | None,
- formatter_resolver: Callable[[Any], Callable[[Any], Report]],
-) -> Generator[Reporter, None, None]:
- """Get reporter and generate report."""
- reporter = Reporter(formatter_resolver, output_format)
-
- yield reporter
-
- reporter.generate_report(output)
+ def produce_report(
+ self, data: Any, formatter: Callable[[Any], Report], **kwargs: Any
+ ) -> None:
+ """Produce report based on provided data."""
+ formatted_data = formatter(data)
+ print(
+ json.dumps(
+ formatted_data.to_json(**kwargs), indent=4, cls=CustomJSONEncoder
+ ),
+ )
def _apply_format_parameters(
diff --git a/src/mlia/core/typing.py b/src/mlia/core/typing.py
index 8244ff5..ea334c9 100644
--- a/src/mlia/core/typing.py
+++ b/src/mlia/core/typing.py
@@ -1,12 +1,7 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for custom type hints."""
-from pathlib import Path
from typing import Literal
-from typing import TextIO
-from typing import Union
-FileLike = TextIO
-PathOrFileLike = Union[str, Path, FileLike]
OutputFormat = Literal["plain_text", "json"]
diff --git a/src/mlia/core/workflow.py b/src/mlia/core/workflow.py
index d862a86..9f8ac83 100644
--- a/src/mlia/core/workflow.py
+++ b/src/mlia/core/workflow.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for executors.
@@ -198,6 +198,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor):
self.collectors,
self.analyzers,
self.producers,
+ self.context.event_handlers or [],
)
if isinstance(comp, ContextMixin)
)
diff --git a/src/mlia/target/cortex_a/advisor.py b/src/mlia/target/cortex_a/advisor.py
index b649f0d..1249d93 100644
--- a/src/mlia/target/cortex_a/advisor.py
+++ b/src/mlia/target/cortex_a/advisor.py
@@ -10,7 +10,6 @@ from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
-from mlia.core.common import FormattedFilePath
from mlia.core.context import Context
from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import DataAnalyzer
@@ -67,12 +66,11 @@ def configure_and_get_cortexa_advisor(
context: ExecutionContext,
target_profile: str,
model: str | Path,
- output: FormattedFilePath | None = None,
**_extra_args: Any,
) -> InferenceAdvisor:
"""Create and configure Cortex-A advisor."""
if context.event_handlers is None:
- context.event_handlers = [CortexAEventHandler(output)]
+ context.event_handlers = [CortexAEventHandler()]
if context.config_parameters is None:
context.config_parameters = _get_config_parameters(model, target_profile)
diff --git a/src/mlia/target/cortex_a/handlers.py b/src/mlia/target/cortex_a/handlers.py
index d6acde5..1a74da7 100644
--- a/src/mlia/target/cortex_a/handlers.py
+++ b/src/mlia/target/cortex_a/handlers.py
@@ -5,7 +5,6 @@ from __future__ import annotations
import logging
-from mlia.core.common import FormattedFilePath
from mlia.core.events import CollectedDataEvent
from mlia.core.handlers import WorkflowEventsHandler
from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
@@ -20,9 +19,9 @@ logger = logging.getLogger(__name__)
class CortexAEventHandler(WorkflowEventsHandler, CortexAAdvisorEventHandler):
"""CLI event handler."""
- def __init__(self, output: FormattedFilePath | None = None) -> None:
+ def __init__(self) -> None:
"""Init event handler."""
- super().__init__(cortex_a_formatters, output)
+ super().__init__(cortex_a_formatters)
def on_collected_data(self, event: CollectedDataEvent) -> None:
"""Handle CollectedDataEvent event."""
diff --git a/src/mlia/target/ethos_u/advisor.py b/src/mlia/target/ethos_u/advisor.py
index 640c3e1..ce4e0fc 100644
--- a/src/mlia/target/ethos_u/advisor.py
+++ b/src/mlia/target/ethos_u/advisor.py
@@ -10,7 +10,6 @@ from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
-from mlia.core.common import FormattedFilePath
from mlia.core.context import Context
from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import DataAnalyzer
@@ -126,12 +125,11 @@ def configure_and_get_ethosu_advisor(
context: ExecutionContext,
target_profile: str,
model: str | Path,
- output: FormattedFilePath | None = None,
**extra_args: Any,
) -> InferenceAdvisor:
"""Create and configure Ethos-U advisor."""
if context.event_handlers is None:
- context.event_handlers = [EthosUEventHandler(output)]
+ context.event_handlers = [EthosUEventHandler()]
if context.config_parameters is None:
context.config_parameters = _get_config_parameters(
diff --git a/src/mlia/target/ethos_u/handlers.py b/src/mlia/target/ethos_u/handlers.py
index 91f6015..9873014 100644
--- a/src/mlia/target/ethos_u/handlers.py
+++ b/src/mlia/target/ethos_u/handlers.py
@@ -6,7 +6,6 @@ from __future__ import annotations
import logging
from mlia.backend.vela.compat import Operators
-from mlia.core.common import FormattedFilePath
from mlia.core.events import CollectedDataEvent
from mlia.core.handlers import WorkflowEventsHandler
from mlia.target.ethos_u.events import EthosUAdvisorEventHandler
@@ -21,9 +20,9 @@ logger = logging.getLogger(__name__)
class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler):
"""CLI event handler."""
- def __init__(self, output: FormattedFilePath | None = None) -> None:
+ def __init__(self) -> None:
"""Init event handler."""
- super().__init__(ethos_u_formatters, output)
+ super().__init__(ethos_u_formatters)
def on_collected_data(self, event: CollectedDataEvent) -> None:
"""Handle CollectedDataEvent event."""
diff --git a/src/mlia/target/tosa/advisor.py b/src/mlia/target/tosa/advisor.py
index 0da44db..b60e824 100644
--- a/src/mlia/target/tosa/advisor.py
+++ b/src/mlia/target/tosa/advisor.py
@@ -10,7 +10,6 @@ from mlia.core.advice_generation import AdviceCategory
from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
from mlia.core.advisor import InferenceAdvisor
-from mlia.core.common import FormattedFilePath
from mlia.core.context import Context
from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import DataAnalyzer
@@ -81,12 +80,11 @@ def configure_and_get_tosa_advisor(
context: ExecutionContext,
target_profile: str,
model: str | Path,
- output: FormattedFilePath | None = None,
**_extra_args: Any,
) -> InferenceAdvisor:
"""Create and configure TOSA advisor."""
if context.event_handlers is None:
- context.event_handlers = [TOSAEventHandler(output)]
+ context.event_handlers = [TOSAEventHandler()]
if context.config_parameters is None:
context.config_parameters = _get_config_parameters(model, target_profile)
diff --git a/src/mlia/target/tosa/handlers.py b/src/mlia/target/tosa/handlers.py
index 7f80f77..f222823 100644
--- a/src/mlia/target/tosa/handlers.py
+++ b/src/mlia/target/tosa/handlers.py
@@ -7,7 +7,6 @@ from __future__ import annotations
import logging
from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo
-from mlia.core.common import FormattedFilePath
from mlia.core.events import CollectedDataEvent
from mlia.core.handlers import WorkflowEventsHandler
from mlia.target.tosa.events import TOSAAdvisorEventHandler
@@ -20,9 +19,9 @@ logger = logging.getLogger(__name__)
class TOSAEventHandler(WorkflowEventsHandler, TOSAAdvisorEventHandler):
"""Event handler for TOSA advisor."""
- def __init__(self, output: FormattedFilePath | None = None) -> None:
+ def __init__(self) -> None:
"""Init event handler."""
- super().__init__(tosa_formatters, output)
+ super().__init__(tosa_formatters)
def on_tosa_advisor_started(self, event: TOSAAdvisorStartedEvent) -> None:
"""Handle TOSAAdvisorStartedEvent event."""
diff --git a/src/mlia/utils/logging.py b/src/mlia/utils/logging.py
index 0659dcf..17f6cae 100644
--- a/src/mlia/utils/logging.py
+++ b/src/mlia/utils/logging.py
@@ -18,6 +18,8 @@ from typing import Generator
from typing import Iterable
from typing import TextIO
+from mlia.utils.console import remove_ascii_codes
+
class LoggerWriter:
"""Redirect printed messages to the logger."""
@@ -152,12 +154,21 @@ class LogFilter(logging.Filter):
return cls(skip_by_level)
+class NoASCIIFormatter(logging.Formatter):
+ """Custom Formatter for logging into file."""
+
+ def format(self, record: logging.LogRecord) -> str:
+ """Overwrite format method to remove ascii codes from record."""
+ result = super().format(record)
+ return remove_ascii_codes(result)
+
+
def create_log_handler(
*,
file_path: Path | None = None,
stream: Any | None = None,
log_level: int | None = None,
- log_format: str | None = None,
+ log_format: str | logging.Formatter | None = None,
log_filter: logging.Filter | None = None,
delay: bool = True,
) -> logging.Handler:
@@ -176,7 +187,9 @@ def create_log_handler(
handler.setLevel(log_level)
if log_format:
- handler.setFormatter(logging.Formatter(log_format))
+ if isinstance(log_format, str):
+ log_format = logging.Formatter(log_format)
+ handler.setFormatter(log_format)
if log_filter:
handler.addFilter(log_filter)