diff options
Diffstat (limited to 'src/mlia/core')
-rw-r--r-- | src/mlia/core/common.py | 33 | ||||
-rw-r--r-- | src/mlia/core/context.py | 16 | ||||
-rw-r--r-- | src/mlia/core/handlers.py | 31 | ||||
-rw-r--r-- | src/mlia/core/logging.py | 132 | ||||
-rw-r--r-- | src/mlia/core/reporting.py | 166 | ||||
-rw-r--r-- | src/mlia/core/typing.py | 7 | ||||
-rw-r--r-- | src/mlia/core/workflow.py | 3 |
7 files changed, 236 insertions, 152 deletions
diff --git a/src/mlia/core/common.py b/src/mlia/core/common.py index 53df001..baaed50 100644 --- a/src/mlia/core/common.py +++ b/src/mlia/core/common.py @@ -13,9 +13,6 @@ from enum import auto from enum import Flag from typing import Any -from mlia.core.typing import OutputFormat -from mlia.core.typing import PathOrFileLike - # This type is used as type alias for the items which are being passed around # in advisor workflow. There are no restrictions on the type of the # object. This alias used only to emphasize the nature of the input/output @@ -23,36 +20,6 @@ from mlia.core.typing import PathOrFileLike DataItem = Any -class FormattedFilePath: - """Class used to keep track of the format that a path points to.""" - - def __init__(self, path: PathOrFileLike, fmt: OutputFormat = "plain_text") -> None: - """Init FormattedFilePath.""" - self._path = path - self._fmt = fmt - - @property - def fmt(self) -> OutputFormat: - """Return file format.""" - return self._fmt - - @property - def path(self) -> PathOrFileLike: - """Return file path.""" - return self._path - - def __eq__(self, other: object) -> bool: - """Check for equality with other objects.""" - if isinstance(other, FormattedFilePath): - return other.fmt == self.fmt and other.path == self.path - - return False - - def __repr__(self) -> str: - """Represent object.""" - return f"FormattedFilePath {self.path=}, {self.fmt=}" - - class AdviceCategory(Flag): """Advice category. diff --git a/src/mlia/core/context.py b/src/mlia/core/context.py index 94aa885..f8442a3 100644 --- a/src/mlia/core/context.py +++ b/src/mlia/core/context.py @@ -23,6 +23,7 @@ from mlia.core.events import EventHandler from mlia.core.events import EventPublisher from mlia.core.helpers import ActionResolver from mlia.core.helpers import APIActionResolver +from mlia.core.typing import OutputFormat logger = logging.getLogger(__name__) @@ -68,6 +69,11 @@ class Context(ABC): def action_resolver(self) -> ActionResolver: """Return action resolver.""" + @property + @abstractmethod + def output_format(self) -> OutputFormat: + """Return the output format.""" + @abstractmethod def update( self, @@ -106,6 +112,7 @@ class ExecutionContext(Context): logs_dir: str = "logs", models_dir: str = "models", action_resolver: ActionResolver | None = None, + output_format: OutputFormat = "plain_text", ) -> None: """Init execution context. @@ -139,6 +146,7 @@ class ExecutionContext(Context): self.logs_dir = logs_dir self.models_dir = models_dir self._action_resolver = action_resolver or APIActionResolver() + self._output_format = output_format @property def working_dir(self) -> Path: @@ -197,6 +205,11 @@ class ExecutionContext(Context): """Return path to the logs directory.""" return self._working_dir_path / self.logs_dir + @property + def output_format(self) -> OutputFormat: + """Return the output format.""" + return self._output_format + def update( self, *, @@ -221,7 +234,8 @@ class ExecutionContext(Context): f"ExecutionContext: working_dir={self._working_dir_path}, " f"advice_category={category}, " f"config_parameters={self.config_parameters}, " - f"verbose={self.verbose}" + f"verbose={self.verbose}, " + f"output_format={self.output_format}" ) diff --git a/src/mlia/core/handlers.py b/src/mlia/core/handlers.py index 6e50934..24d4881 100644 --- a/src/mlia/core/handlers.py +++ b/src/mlia/core/handlers.py @@ -9,7 +9,6 @@ from typing import Callable from mlia.core.advice_generation import Advice from mlia.core.advice_generation import AdviceEvent -from mlia.core.common import FormattedFilePath from mlia.core.events import ActionFinishedEvent from mlia.core.events import ActionStartedEvent from mlia.core.events import AdviceStageFinishedEvent @@ -25,9 +24,11 @@ from mlia.core.events import EventDispatcher from mlia.core.events import ExecutionFailedEvent from mlia.core.events import ExecutionFinishedEvent from mlia.core.events import ExecutionStartedEvent +from mlia.core.mixins import ContextMixin +from mlia.core.reporting import JSONReporter from mlia.core.reporting import Report from mlia.core.reporting import Reporter -from mlia.core.typing import PathOrFileLike +from mlia.core.reporting import TextReporter from mlia.utils.console import create_section_header @@ -92,26 +93,27 @@ _ADV_EXECUTION_STARTED = create_section_header("ML Inference Advisor started") _MODEL_ANALYSIS_MSG = create_section_header("Model Analysis") _MODEL_ANALYSIS_RESULTS_MSG = create_section_header("Model Analysis Results") _ADV_GENERATION_MSG = create_section_header("Advice Generation") -_REPORT_GENERATION_MSG = create_section_header("Report Generation") -class WorkflowEventsHandler(SystemEventsHandler): +class WorkflowEventsHandler(SystemEventsHandler, ContextMixin): """Event handler for the system events.""" + reporter: Reporter + def __init__( self, formatter_resolver: Callable[[Any], Callable[[Any], Report]], - output: FormattedFilePath | None = None, ) -> None: """Init event handler.""" - output_format = output.fmt if output else "plain_text" - self.reporter = Reporter(formatter_resolver, output_format) - self.output = output.path if output else None - + self.formatter_resolver = formatter_resolver self.advice: list[Advice] = [] def on_execution_started(self, event: ExecutionStartedEvent) -> None: """Handle ExecutionStarted event.""" + if self.context.output_format == "json": + self.reporter = JSONReporter(self.formatter_resolver) + else: + self.reporter = TextReporter(self.formatter_resolver) logger.info(_ADV_EXECUTION_STARTED) def on_execution_failed(self, event: ExecutionFailedEvent) -> None: @@ -132,12 +134,6 @@ class WorkflowEventsHandler(SystemEventsHandler): """Handle DataCollectorSkipped event.""" logger.info("Skipped: %s", event.reason) - @staticmethod - def report_generated(output: PathOrFileLike) -> None: - """Log report generation.""" - logger.info(_REPORT_GENERATION_MSG) - logger.info("Report(s) and advice list saved to: %s", output) - def on_data_analysis_stage_finished( self, event: DataAnalysisStageFinishedEvent ) -> None: @@ -160,7 +156,4 @@ class WorkflowEventsHandler(SystemEventsHandler): table_style="no_borders", ) - self.reporter.generate_report(self.output) - - if self.output is not None: - self.report_generated(self.output) + self.reporter.generate_report() diff --git a/src/mlia/core/logging.py b/src/mlia/core/logging.py new file mode 100644 index 0000000..686f8ab --- /dev/null +++ b/src/mlia/core/logging.py @@ -0,0 +1,132 @@ +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""CLI logging configuration.""" +from __future__ import annotations + +import logging +import sys +from pathlib import Path +from typing import Iterable + +from mlia.core.typing import OutputFormat +from mlia.utils.logging import attach_handlers +from mlia.utils.logging import create_log_handler +from mlia.utils.logging import NoASCIIFormatter + + +_CONSOLE_DEBUG_FORMAT = "%(name)s - %(levelname)s - %(message)s" +_FILE_DEBUG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + + +def setup_logging( + logs_dir: str | Path | None = None, + verbose: bool = False, + output_format: OutputFormat = "plain_text", + log_filename: str = "mlia.log", +) -> None: + """Set up logging. + + MLIA uses module 'logging' when it needs to produce output. + + :param logs_dir: path to the directory where application will save logs with + debug information. If the path is not provided then no log files will + be created during execution + :param verbose: enable extended logging for the tools loggers + :param output_format: specify the out format needed for setting up the right + logging system + :param log_filename: name of the log file in the logs directory + """ + mlia_logger, tensorflow_logger, py_warnings_logger = ( + logging.getLogger(logger_name) + for logger_name in ["mlia", "tensorflow", "py.warnings"] + ) + + # enable debug output, actual message filtering depends on + # the provided parameters and being done at the handlers level + for logger in [mlia_logger, tensorflow_logger]: + logger.setLevel(logging.DEBUG) + + mlia_handlers = _get_mlia_handlers(logs_dir, log_filename, verbose, output_format) + attach_handlers(mlia_handlers, [mlia_logger]) + + tools_handlers = _get_tools_handlers(logs_dir, log_filename, verbose) + attach_handlers(tools_handlers, [tensorflow_logger, py_warnings_logger]) + + +def _get_mlia_handlers( + logs_dir: str | Path | None, + log_filename: str, + verbose: bool, + output_format: OutputFormat, +) -> Iterable[logging.Handler]: + """Get handlers for the MLIA loggers.""" + # MLIA needs output to standard output via the logging system only when the + # format is plain text. When the user specifies the "json" output format, + # MLIA disables completely the logging system for the console output and it + # relies on the print() function. This is needed because the output might + # be corrupted with spurious messages in the standard output. + if output_format == "plain_text": + if verbose: + log_level = logging.DEBUG + log_format = _CONSOLE_DEBUG_FORMAT + else: + log_level = logging.INFO + log_format = None + + # Create log handler for stdout + yield create_log_handler( + stream=sys.stdout, log_level=log_level, log_format=log_format + ) + else: + # In case of non plain text output, we need to inform the user if an + # error happens during execution. + yield create_log_handler( + stream=sys.stderr, + log_level=logging.ERROR, + ) + + # If the logs directory is specified, MLIA stores all output (according to + # the logging level) into the file and removing the colouring of the + # console output. + if logs_dir: + if verbose: + log_level = logging.DEBUG + else: + log_level = logging.INFO + + yield create_log_handler( + file_path=_get_log_file(logs_dir, log_filename), + log_level=log_level, + log_format=NoASCIIFormatter(fmt=_FILE_DEBUG_FORMAT), + delay=True, + ) + + +def _get_tools_handlers( + logs_dir: str | Path | None, + log_filename: str, + verbose: bool, +) -> Iterable[logging.Handler]: + """Get handler for the tools loggers.""" + if verbose: + yield create_log_handler( + stream=sys.stdout, + log_level=logging.DEBUG, + log_format=_CONSOLE_DEBUG_FORMAT, + ) + + if logs_dir: + yield create_log_handler( + file_path=_get_log_file(logs_dir, log_filename), + log_level=logging.DEBUG, + log_format=_FILE_DEBUG_FORMAT, + delay=True, + ) + + +def _get_log_file(logs_dir: str | Path, log_filename: str) -> Path: + """Get the log file path.""" + logs_dir_path = Path(logs_dir) + logs_dir_path.mkdir(exist_ok=True) + + return logs_dir_path / log_filename diff --git a/src/mlia/core/reporting.py b/src/mlia/core/reporting.py index ad63d62..7b9ce5c 100644 --- a/src/mlia/core/reporting.py +++ b/src/mlia/core/reporting.py @@ -8,30 +8,21 @@ import logging from abc import ABC from abc import abstractmethod from collections import defaultdict -from contextlib import contextmanager -from contextlib import ExitStack from dataclasses import dataclass from enum import Enum from functools import partial -from io import TextIOWrapper -from pathlib import Path from textwrap import fill from textwrap import indent from typing import Any from typing import Callable -from typing import cast from typing import Collection -from typing import Generator from typing import Iterable import numpy as np -from mlia.core.typing import FileLike from mlia.core.typing import OutputFormat -from mlia.core.typing import PathOrFileLike from mlia.utils.console import apply_style from mlia.utils.console import produce_table -from mlia.utils.logging import LoggerWriter from mlia.utils.types import is_list_of logger = logging.getLogger(__name__) @@ -505,76 +496,48 @@ class CustomJSONEncoder(json.JSONEncoder): return json.JSONEncoder.default(self, o) -def json_reporter(report: Report, output: FileLike, **kwargs: Any) -> None: - """Produce report in json format.""" - json_str = json.dumps(report.to_json(**kwargs), indent=4, cls=CustomJSONEncoder) - print(json_str, file=output) - - -def text_reporter(report: Report, output: FileLike, **kwargs: Any) -> None: - """Produce report in text format.""" - print(report.to_plain_text(**kwargs), file=output) +class Reporter(ABC): + """Reporter class.""" + def __init__( + self, + formatter_resolver: Callable[[Any], Callable[[Any], Report]], + ) -> None: + """Init reporter instance.""" + self.formatter_resolver = formatter_resolver + self.data: list[tuple[Any, Callable[[Any], Report]]] = [] -def produce_report( - data: Any, - formatter: Callable[[Any], Report], - fmt: OutputFormat = "plain_text", - output: PathOrFileLike | None = None, - **kwargs: Any, -) -> None: - """Produce report based on provided data.""" - # check if provided format value is supported - formats = {"json": json_reporter, "plain_text": text_reporter} - if fmt not in formats: - raise Exception(f"Unknown format {fmt}") + @abstractmethod + def submit(self, data_item: Any, **kwargs: Any) -> None: + """Submit data for the report.""" - if output is None: - output = cast(TextIOWrapper, LoggerWriter(logger, logging.INFO)) + def print_delayed(self) -> None: + """Print delayed reports.""" - with ExitStack() as exit_stack: - if isinstance(output, (str, Path)): - # open file and add it to the ExitStack context manager - # in that case it will be automatically closed - stream = exit_stack.enter_context(open(output, "w", encoding="utf-8")) - else: - stream = cast(TextIOWrapper, output) + def generate_report(self) -> None: + """Generate report.""" - # convert data into serializable form - formatted_data = formatter(data) - # find handler for the format - format_handler = formats[fmt] - # produce report in requested format - format_handler(formatted_data, stream, **kwargs) + @abstractmethod + def produce_report( + self, data: Any, formatter: Callable[[Any], Report], **kwargs: Any + ) -> None: + """Produce report based on provided data.""" -class Reporter: +class TextReporter(Reporter): """Reporter class.""" def __init__( self, formatter_resolver: Callable[[Any], Callable[[Any], Report]], - output_format: OutputFormat = "plain_text", - print_as_submitted: bool = True, ) -> None: """Init reporter instance.""" - self.formatter_resolver = formatter_resolver - self.output_format = output_format - self.print_as_submitted = print_as_submitted - - self.data: list[tuple[Any, Callable[[Any], Report]]] = [] + super().__init__(formatter_resolver) self.delayed: list[tuple[Any, Callable[[Any], Report]]] = [] + self.output_format: OutputFormat = "plain_text" def submit(self, data_item: Any, delay_print: bool = False, **kwargs: Any) -> None: """Submit data for the report.""" - if self.print_as_submitted and not delay_print: - produce_report( - data_item, - self.formatter_resolver(data_item), - fmt="plain_text", - **kwargs, - ) - formatter = _apply_format_parameters( self.formatter_resolver(data_item), self.output_format, **kwargs ) @@ -582,51 +545,70 @@ class Reporter: if delay_print: self.delayed.append((data_item, formatter)) + else: + self.produce_report( + data_item, + self.formatter_resolver(data_item), + **kwargs, + ) def print_delayed(self) -> None: """Print delayed reports.""" - if not self.delayed: - return + if self.delayed: + data, formatters = zip(*self.delayed) + self.produce_report( + data, + formatter=CompoundFormatter(formatters), + ) + self.delayed = [] - data, formatters = zip(*self.delayed) - produce_report( - data, - formatter=CompoundFormatter(formatters), - fmt="plain_text", + def produce_report( + self, data: Any, formatter: Callable[[Any], Report], **kwargs: Any + ) -> None: + """Produce report based on provided data.""" + formatted_data = formatter(data) + logger.info(formatted_data.to_plain_text(**kwargs)) + + +class JSONReporter(Reporter): + """Reporter class.""" + + def __init__( + self, + formatter_resolver: Callable[[Any], Callable[[Any], Report]], + ) -> None: + """Init reporter instance.""" + super().__init__(formatter_resolver) + self.output_format: OutputFormat = "json" + + def submit(self, data_item: Any, **kwargs: Any) -> None: + """Submit data for the report.""" + formatter = _apply_format_parameters( + self.formatter_resolver(data_item), self.output_format, **kwargs ) - self.delayed = [] + self.data.append((data_item, formatter)) - def generate_report(self, output: PathOrFileLike | None) -> None: + def generate_report(self) -> None: """Generate report.""" - already_printed = ( - self.print_as_submitted - and self.output_format == "plain_text" - and output is None - ) - if not self.data or already_printed: + if not self.data: return data, formatters = zip(*self.data) - produce_report( + self.produce_report( data, formatter=CompoundFormatter(formatters), - fmt=self.output_format, - output=output, ) - -@contextmanager -def get_reporter( - output_format: OutputFormat, - output: PathOrFileLike | None, - formatter_resolver: Callable[[Any], Callable[[Any], Report]], -) -> Generator[Reporter, None, None]: - """Get reporter and generate report.""" - reporter = Reporter(formatter_resolver, output_format) - - yield reporter - - reporter.generate_report(output) + def produce_report( + self, data: Any, formatter: Callable[[Any], Report], **kwargs: Any + ) -> None: + """Produce report based on provided data.""" + formatted_data = formatter(data) + print( + json.dumps( + formatted_data.to_json(**kwargs), indent=4, cls=CustomJSONEncoder + ), + ) def _apply_format_parameters( diff --git a/src/mlia/core/typing.py b/src/mlia/core/typing.py index 8244ff5..ea334c9 100644 --- a/src/mlia/core/typing.py +++ b/src/mlia/core/typing.py @@ -1,12 +1,7 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for custom type hints.""" -from pathlib import Path from typing import Literal -from typing import TextIO -from typing import Union -FileLike = TextIO -PathOrFileLike = Union[str, Path, FileLike] OutputFormat = Literal["plain_text", "json"] diff --git a/src/mlia/core/workflow.py b/src/mlia/core/workflow.py index d862a86..9f8ac83 100644 --- a/src/mlia/core/workflow.py +++ b/src/mlia/core/workflow.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for executors. @@ -198,6 +198,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor): self.collectors, self.analyzers, self.producers, + self.context.event_handlers or [], ) if isinstance(comp, ContextMixin) ) |