aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/core
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/core')
-rw-r--r--src/mlia/core/common.py33
-rw-r--r--src/mlia/core/context.py16
-rw-r--r--src/mlia/core/handlers.py31
-rw-r--r--src/mlia/core/logging.py132
-rw-r--r--src/mlia/core/reporting.py166
-rw-r--r--src/mlia/core/typing.py7
-rw-r--r--src/mlia/core/workflow.py3
7 files changed, 236 insertions, 152 deletions
diff --git a/src/mlia/core/common.py b/src/mlia/core/common.py
index 53df001..baaed50 100644
--- a/src/mlia/core/common.py
+++ b/src/mlia/core/common.py
@@ -13,9 +13,6 @@ from enum import auto
from enum import Flag
from typing import Any
-from mlia.core.typing import OutputFormat
-from mlia.core.typing import PathOrFileLike
-
# This type is used as type alias for the items which are being passed around
# in advisor workflow. There are no restrictions on the type of the
# object. This alias used only to emphasize the nature of the input/output
@@ -23,36 +20,6 @@ from mlia.core.typing import PathOrFileLike
DataItem = Any
-class FormattedFilePath:
- """Class used to keep track of the format that a path points to."""
-
- def __init__(self, path: PathOrFileLike, fmt: OutputFormat = "plain_text") -> None:
- """Init FormattedFilePath."""
- self._path = path
- self._fmt = fmt
-
- @property
- def fmt(self) -> OutputFormat:
- """Return file format."""
- return self._fmt
-
- @property
- def path(self) -> PathOrFileLike:
- """Return file path."""
- return self._path
-
- def __eq__(self, other: object) -> bool:
- """Check for equality with other objects."""
- if isinstance(other, FormattedFilePath):
- return other.fmt == self.fmt and other.path == self.path
-
- return False
-
- def __repr__(self) -> str:
- """Represent object."""
- return f"FormattedFilePath {self.path=}, {self.fmt=}"
-
-
class AdviceCategory(Flag):
"""Advice category.
diff --git a/src/mlia/core/context.py b/src/mlia/core/context.py
index 94aa885..f8442a3 100644
--- a/src/mlia/core/context.py
+++ b/src/mlia/core/context.py
@@ -23,6 +23,7 @@ from mlia.core.events import EventHandler
from mlia.core.events import EventPublisher
from mlia.core.helpers import ActionResolver
from mlia.core.helpers import APIActionResolver
+from mlia.core.typing import OutputFormat
logger = logging.getLogger(__name__)
@@ -68,6 +69,11 @@ class Context(ABC):
def action_resolver(self) -> ActionResolver:
"""Return action resolver."""
+ @property
+ @abstractmethod
+ def output_format(self) -> OutputFormat:
+ """Return the output format."""
+
@abstractmethod
def update(
self,
@@ -106,6 +112,7 @@ class ExecutionContext(Context):
logs_dir: str = "logs",
models_dir: str = "models",
action_resolver: ActionResolver | None = None,
+ output_format: OutputFormat = "plain_text",
) -> None:
"""Init execution context.
@@ -139,6 +146,7 @@ class ExecutionContext(Context):
self.logs_dir = logs_dir
self.models_dir = models_dir
self._action_resolver = action_resolver or APIActionResolver()
+ self._output_format = output_format
@property
def working_dir(self) -> Path:
@@ -197,6 +205,11 @@ class ExecutionContext(Context):
"""Return path to the logs directory."""
return self._working_dir_path / self.logs_dir
+ @property
+ def output_format(self) -> OutputFormat:
+ """Return the output format."""
+ return self._output_format
+
def update(
self,
*,
@@ -221,7 +234,8 @@ class ExecutionContext(Context):
f"ExecutionContext: working_dir={self._working_dir_path}, "
f"advice_category={category}, "
f"config_parameters={self.config_parameters}, "
- f"verbose={self.verbose}"
+ f"verbose={self.verbose}, "
+ f"output_format={self.output_format}"
)
diff --git a/src/mlia/core/handlers.py b/src/mlia/core/handlers.py
index 6e50934..24d4881 100644
--- a/src/mlia/core/handlers.py
+++ b/src/mlia/core/handlers.py
@@ -9,7 +9,6 @@ from typing import Callable
from mlia.core.advice_generation import Advice
from mlia.core.advice_generation import AdviceEvent
-from mlia.core.common import FormattedFilePath
from mlia.core.events import ActionFinishedEvent
from mlia.core.events import ActionStartedEvent
from mlia.core.events import AdviceStageFinishedEvent
@@ -25,9 +24,11 @@ from mlia.core.events import EventDispatcher
from mlia.core.events import ExecutionFailedEvent
from mlia.core.events import ExecutionFinishedEvent
from mlia.core.events import ExecutionStartedEvent
+from mlia.core.mixins import ContextMixin
+from mlia.core.reporting import JSONReporter
from mlia.core.reporting import Report
from mlia.core.reporting import Reporter
-from mlia.core.typing import PathOrFileLike
+from mlia.core.reporting import TextReporter
from mlia.utils.console import create_section_header
@@ -92,26 +93,27 @@ _ADV_EXECUTION_STARTED = create_section_header("ML Inference Advisor started")
_MODEL_ANALYSIS_MSG = create_section_header("Model Analysis")
_MODEL_ANALYSIS_RESULTS_MSG = create_section_header("Model Analysis Results")
_ADV_GENERATION_MSG = create_section_header("Advice Generation")
-_REPORT_GENERATION_MSG = create_section_header("Report Generation")
-class WorkflowEventsHandler(SystemEventsHandler):
+class WorkflowEventsHandler(SystemEventsHandler, ContextMixin):
"""Event handler for the system events."""
+ reporter: Reporter
+
def __init__(
self,
formatter_resolver: Callable[[Any], Callable[[Any], Report]],
- output: FormattedFilePath | None = None,
) -> None:
"""Init event handler."""
- output_format = output.fmt if output else "plain_text"
- self.reporter = Reporter(formatter_resolver, output_format)
- self.output = output.path if output else None
-
+ self.formatter_resolver = formatter_resolver
self.advice: list[Advice] = []
def on_execution_started(self, event: ExecutionStartedEvent) -> None:
"""Handle ExecutionStarted event."""
+ if self.context.output_format == "json":
+ self.reporter = JSONReporter(self.formatter_resolver)
+ else:
+ self.reporter = TextReporter(self.formatter_resolver)
logger.info(_ADV_EXECUTION_STARTED)
def on_execution_failed(self, event: ExecutionFailedEvent) -> None:
@@ -132,12 +134,6 @@ class WorkflowEventsHandler(SystemEventsHandler):
"""Handle DataCollectorSkipped event."""
logger.info("Skipped: %s", event.reason)
- @staticmethod
- def report_generated(output: PathOrFileLike) -> None:
- """Log report generation."""
- logger.info(_REPORT_GENERATION_MSG)
- logger.info("Report(s) and advice list saved to: %s", output)
-
def on_data_analysis_stage_finished(
self, event: DataAnalysisStageFinishedEvent
) -> None:
@@ -160,7 +156,4 @@ class WorkflowEventsHandler(SystemEventsHandler):
table_style="no_borders",
)
- self.reporter.generate_report(self.output)
-
- if self.output is not None:
- self.report_generated(self.output)
+ self.reporter.generate_report()
diff --git a/src/mlia/core/logging.py b/src/mlia/core/logging.py
new file mode 100644
index 0000000..686f8ab
--- /dev/null
+++ b/src/mlia/core/logging.py
@@ -0,0 +1,132 @@
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""CLI logging configuration."""
+from __future__ import annotations
+
+import logging
+import sys
+from pathlib import Path
+from typing import Iterable
+
+from mlia.core.typing import OutputFormat
+from mlia.utils.logging import attach_handlers
+from mlia.utils.logging import create_log_handler
+from mlia.utils.logging import NoASCIIFormatter
+
+
+_CONSOLE_DEBUG_FORMAT = "%(name)s - %(levelname)s - %(message)s"
+_FILE_DEBUG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+
+
+def setup_logging(
+ logs_dir: str | Path | None = None,
+ verbose: bool = False,
+ output_format: OutputFormat = "plain_text",
+ log_filename: str = "mlia.log",
+) -> None:
+ """Set up logging.
+
+ MLIA uses module 'logging' when it needs to produce output.
+
+ :param logs_dir: path to the directory where application will save logs with
+ debug information. If the path is not provided then no log files will
+ be created during execution
+ :param verbose: enable extended logging for the tools loggers
+ :param output_format: specify the out format needed for setting up the right
+ logging system
+ :param log_filename: name of the log file in the logs directory
+ """
+ mlia_logger, tensorflow_logger, py_warnings_logger = (
+ logging.getLogger(logger_name)
+ for logger_name in ["mlia", "tensorflow", "py.warnings"]
+ )
+
+ # enable debug output, actual message filtering depends on
+ # the provided parameters and being done at the handlers level
+ for logger in [mlia_logger, tensorflow_logger]:
+ logger.setLevel(logging.DEBUG)
+
+ mlia_handlers = _get_mlia_handlers(logs_dir, log_filename, verbose, output_format)
+ attach_handlers(mlia_handlers, [mlia_logger])
+
+ tools_handlers = _get_tools_handlers(logs_dir, log_filename, verbose)
+ attach_handlers(tools_handlers, [tensorflow_logger, py_warnings_logger])
+
+
+def _get_mlia_handlers(
+ logs_dir: str | Path | None,
+ log_filename: str,
+ verbose: bool,
+ output_format: OutputFormat,
+) -> Iterable[logging.Handler]:
+ """Get handlers for the MLIA loggers."""
+ # MLIA needs output to standard output via the logging system only when the
+ # format is plain text. When the user specifies the "json" output format,
+ # MLIA disables completely the logging system for the console output and it
+ # relies on the print() function. This is needed because the output might
+ # be corrupted with spurious messages in the standard output.
+ if output_format == "plain_text":
+ if verbose:
+ log_level = logging.DEBUG
+ log_format = _CONSOLE_DEBUG_FORMAT
+ else:
+ log_level = logging.INFO
+ log_format = None
+
+ # Create log handler for stdout
+ yield create_log_handler(
+ stream=sys.stdout, log_level=log_level, log_format=log_format
+ )
+ else:
+ # In case of non plain text output, we need to inform the user if an
+ # error happens during execution.
+ yield create_log_handler(
+ stream=sys.stderr,
+ log_level=logging.ERROR,
+ )
+
+ # If the logs directory is specified, MLIA stores all output (according to
+ # the logging level) into the file and removing the colouring of the
+ # console output.
+ if logs_dir:
+ if verbose:
+ log_level = logging.DEBUG
+ else:
+ log_level = logging.INFO
+
+ yield create_log_handler(
+ file_path=_get_log_file(logs_dir, log_filename),
+ log_level=log_level,
+ log_format=NoASCIIFormatter(fmt=_FILE_DEBUG_FORMAT),
+ delay=True,
+ )
+
+
+def _get_tools_handlers(
+ logs_dir: str | Path | None,
+ log_filename: str,
+ verbose: bool,
+) -> Iterable[logging.Handler]:
+ """Get handler for the tools loggers."""
+ if verbose:
+ yield create_log_handler(
+ stream=sys.stdout,
+ log_level=logging.DEBUG,
+ log_format=_CONSOLE_DEBUG_FORMAT,
+ )
+
+ if logs_dir:
+ yield create_log_handler(
+ file_path=_get_log_file(logs_dir, log_filename),
+ log_level=logging.DEBUG,
+ log_format=_FILE_DEBUG_FORMAT,
+ delay=True,
+ )
+
+
+def _get_log_file(logs_dir: str | Path, log_filename: str) -> Path:
+ """Get the log file path."""
+ logs_dir_path = Path(logs_dir)
+ logs_dir_path.mkdir(exist_ok=True)
+
+ return logs_dir_path / log_filename
diff --git a/src/mlia/core/reporting.py b/src/mlia/core/reporting.py
index ad63d62..7b9ce5c 100644
--- a/src/mlia/core/reporting.py
+++ b/src/mlia/core/reporting.py
@@ -8,30 +8,21 @@ import logging
from abc import ABC
from abc import abstractmethod
from collections import defaultdict
-from contextlib import contextmanager
-from contextlib import ExitStack
from dataclasses import dataclass
from enum import Enum
from functools import partial
-from io import TextIOWrapper
-from pathlib import Path
from textwrap import fill
from textwrap import indent
from typing import Any
from typing import Callable
-from typing import cast
from typing import Collection
-from typing import Generator
from typing import Iterable
import numpy as np
-from mlia.core.typing import FileLike
from mlia.core.typing import OutputFormat
-from mlia.core.typing import PathOrFileLike
from mlia.utils.console import apply_style
from mlia.utils.console import produce_table
-from mlia.utils.logging import LoggerWriter
from mlia.utils.types import is_list_of
logger = logging.getLogger(__name__)
@@ -505,76 +496,48 @@ class CustomJSONEncoder(json.JSONEncoder):
return json.JSONEncoder.default(self, o)
-def json_reporter(report: Report, output: FileLike, **kwargs: Any) -> None:
- """Produce report in json format."""
- json_str = json.dumps(report.to_json(**kwargs), indent=4, cls=CustomJSONEncoder)
- print(json_str, file=output)
-
-
-def text_reporter(report: Report, output: FileLike, **kwargs: Any) -> None:
- """Produce report in text format."""
- print(report.to_plain_text(**kwargs), file=output)
+class Reporter(ABC):
+ """Reporter class."""
+ def __init__(
+ self,
+ formatter_resolver: Callable[[Any], Callable[[Any], Report]],
+ ) -> None:
+ """Init reporter instance."""
+ self.formatter_resolver = formatter_resolver
+ self.data: list[tuple[Any, Callable[[Any], Report]]] = []
-def produce_report(
- data: Any,
- formatter: Callable[[Any], Report],
- fmt: OutputFormat = "plain_text",
- output: PathOrFileLike | None = None,
- **kwargs: Any,
-) -> None:
- """Produce report based on provided data."""
- # check if provided format value is supported
- formats = {"json": json_reporter, "plain_text": text_reporter}
- if fmt not in formats:
- raise Exception(f"Unknown format {fmt}")
+ @abstractmethod
+ def submit(self, data_item: Any, **kwargs: Any) -> None:
+ """Submit data for the report."""
- if output is None:
- output = cast(TextIOWrapper, LoggerWriter(logger, logging.INFO))
+ def print_delayed(self) -> None:
+ """Print delayed reports."""
- with ExitStack() as exit_stack:
- if isinstance(output, (str, Path)):
- # open file and add it to the ExitStack context manager
- # in that case it will be automatically closed
- stream = exit_stack.enter_context(open(output, "w", encoding="utf-8"))
- else:
- stream = cast(TextIOWrapper, output)
+ def generate_report(self) -> None:
+ """Generate report."""
- # convert data into serializable form
- formatted_data = formatter(data)
- # find handler for the format
- format_handler = formats[fmt]
- # produce report in requested format
- format_handler(formatted_data, stream, **kwargs)
+ @abstractmethod
+ def produce_report(
+ self, data: Any, formatter: Callable[[Any], Report], **kwargs: Any
+ ) -> None:
+ """Produce report based on provided data."""
-class Reporter:
+class TextReporter(Reporter):
"""Reporter class."""
def __init__(
self,
formatter_resolver: Callable[[Any], Callable[[Any], Report]],
- output_format: OutputFormat = "plain_text",
- print_as_submitted: bool = True,
) -> None:
"""Init reporter instance."""
- self.formatter_resolver = formatter_resolver
- self.output_format = output_format
- self.print_as_submitted = print_as_submitted
-
- self.data: list[tuple[Any, Callable[[Any], Report]]] = []
+ super().__init__(formatter_resolver)
self.delayed: list[tuple[Any, Callable[[Any], Report]]] = []
+ self.output_format: OutputFormat = "plain_text"
def submit(self, data_item: Any, delay_print: bool = False, **kwargs: Any) -> None:
"""Submit data for the report."""
- if self.print_as_submitted and not delay_print:
- produce_report(
- data_item,
- self.formatter_resolver(data_item),
- fmt="plain_text",
- **kwargs,
- )
-
formatter = _apply_format_parameters(
self.formatter_resolver(data_item), self.output_format, **kwargs
)
@@ -582,51 +545,70 @@ class Reporter:
if delay_print:
self.delayed.append((data_item, formatter))
+ else:
+ self.produce_report(
+ data_item,
+ self.formatter_resolver(data_item),
+ **kwargs,
+ )
def print_delayed(self) -> None:
"""Print delayed reports."""
- if not self.delayed:
- return
+ if self.delayed:
+ data, formatters = zip(*self.delayed)
+ self.produce_report(
+ data,
+ formatter=CompoundFormatter(formatters),
+ )
+ self.delayed = []
- data, formatters = zip(*self.delayed)
- produce_report(
- data,
- formatter=CompoundFormatter(formatters),
- fmt="plain_text",
+ def produce_report(
+ self, data: Any, formatter: Callable[[Any], Report], **kwargs: Any
+ ) -> None:
+ """Produce report based on provided data."""
+ formatted_data = formatter(data)
+ logger.info(formatted_data.to_plain_text(**kwargs))
+
+
+class JSONReporter(Reporter):
+ """Reporter class."""
+
+ def __init__(
+ self,
+ formatter_resolver: Callable[[Any], Callable[[Any], Report]],
+ ) -> None:
+ """Init reporter instance."""
+ super().__init__(formatter_resolver)
+ self.output_format: OutputFormat = "json"
+
+ def submit(self, data_item: Any, **kwargs: Any) -> None:
+ """Submit data for the report."""
+ formatter = _apply_format_parameters(
+ self.formatter_resolver(data_item), self.output_format, **kwargs
)
- self.delayed = []
+ self.data.append((data_item, formatter))
- def generate_report(self, output: PathOrFileLike | None) -> None:
+ def generate_report(self) -> None:
"""Generate report."""
- already_printed = (
- self.print_as_submitted
- and self.output_format == "plain_text"
- and output is None
- )
- if not self.data or already_printed:
+ if not self.data:
return
data, formatters = zip(*self.data)
- produce_report(
+ self.produce_report(
data,
formatter=CompoundFormatter(formatters),
- fmt=self.output_format,
- output=output,
)
-
-@contextmanager
-def get_reporter(
- output_format: OutputFormat,
- output: PathOrFileLike | None,
- formatter_resolver: Callable[[Any], Callable[[Any], Report]],
-) -> Generator[Reporter, None, None]:
- """Get reporter and generate report."""
- reporter = Reporter(formatter_resolver, output_format)
-
- yield reporter
-
- reporter.generate_report(output)
+ def produce_report(
+ self, data: Any, formatter: Callable[[Any], Report], **kwargs: Any
+ ) -> None:
+ """Produce report based on provided data."""
+ formatted_data = formatter(data)
+ print(
+ json.dumps(
+ formatted_data.to_json(**kwargs), indent=4, cls=CustomJSONEncoder
+ ),
+ )
def _apply_format_parameters(
diff --git a/src/mlia/core/typing.py b/src/mlia/core/typing.py
index 8244ff5..ea334c9 100644
--- a/src/mlia/core/typing.py
+++ b/src/mlia/core/typing.py
@@ -1,12 +1,7 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for custom type hints."""
-from pathlib import Path
from typing import Literal
-from typing import TextIO
-from typing import Union
-FileLike = TextIO
-PathOrFileLike = Union[str, Path, FileLike]
OutputFormat = Literal["plain_text", "json"]
diff --git a/src/mlia/core/workflow.py b/src/mlia/core/workflow.py
index d862a86..9f8ac83 100644
--- a/src/mlia/core/workflow.py
+++ b/src/mlia/core/workflow.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for executors.
@@ -198,6 +198,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor):
self.collectors,
self.analyzers,
self.producers,
+ self.context.event_handlers or [],
)
if isinstance(comp, ContextMixin)
)