diff options
Diffstat (limited to 'src/mlia/core/reporting.py')
-rw-r--r-- | src/mlia/core/reporting.py | 166 |
1 files changed, 74 insertions, 92 deletions
diff --git a/src/mlia/core/reporting.py b/src/mlia/core/reporting.py index ad63d62..7b9ce5c 100644 --- a/src/mlia/core/reporting.py +++ b/src/mlia/core/reporting.py @@ -8,30 +8,21 @@ import logging from abc import ABC from abc import abstractmethod from collections import defaultdict -from contextlib import contextmanager -from contextlib import ExitStack from dataclasses import dataclass from enum import Enum from functools import partial -from io import TextIOWrapper -from pathlib import Path from textwrap import fill from textwrap import indent from typing import Any from typing import Callable -from typing import cast from typing import Collection -from typing import Generator from typing import Iterable import numpy as np -from mlia.core.typing import FileLike from mlia.core.typing import OutputFormat -from mlia.core.typing import PathOrFileLike from mlia.utils.console import apply_style from mlia.utils.console import produce_table -from mlia.utils.logging import LoggerWriter from mlia.utils.types import is_list_of logger = logging.getLogger(__name__) @@ -505,76 +496,48 @@ class CustomJSONEncoder(json.JSONEncoder): return json.JSONEncoder.default(self, o) -def json_reporter(report: Report, output: FileLike, **kwargs: Any) -> None: - """Produce report in json format.""" - json_str = json.dumps(report.to_json(**kwargs), indent=4, cls=CustomJSONEncoder) - print(json_str, file=output) - - -def text_reporter(report: Report, output: FileLike, **kwargs: Any) -> None: - """Produce report in text format.""" - print(report.to_plain_text(**kwargs), file=output) +class Reporter(ABC): + """Reporter class.""" + def __init__( + self, + formatter_resolver: Callable[[Any], Callable[[Any], Report]], + ) -> None: + """Init reporter instance.""" + self.formatter_resolver = formatter_resolver + self.data: list[tuple[Any, Callable[[Any], Report]]] = [] -def produce_report( - data: Any, - formatter: Callable[[Any], Report], - fmt: OutputFormat = "plain_text", - output: PathOrFileLike | None = None, - **kwargs: Any, -) -> None: - """Produce report based on provided data.""" - # check if provided format value is supported - formats = {"json": json_reporter, "plain_text": text_reporter} - if fmt not in formats: - raise Exception(f"Unknown format {fmt}") + @abstractmethod + def submit(self, data_item: Any, **kwargs: Any) -> None: + """Submit data for the report.""" - if output is None: - output = cast(TextIOWrapper, LoggerWriter(logger, logging.INFO)) + def print_delayed(self) -> None: + """Print delayed reports.""" - with ExitStack() as exit_stack: - if isinstance(output, (str, Path)): - # open file and add it to the ExitStack context manager - # in that case it will be automatically closed - stream = exit_stack.enter_context(open(output, "w", encoding="utf-8")) - else: - stream = cast(TextIOWrapper, output) + def generate_report(self) -> None: + """Generate report.""" - # convert data into serializable form - formatted_data = formatter(data) - # find handler for the format - format_handler = formats[fmt] - # produce report in requested format - format_handler(formatted_data, stream, **kwargs) + @abstractmethod + def produce_report( + self, data: Any, formatter: Callable[[Any], Report], **kwargs: Any + ) -> None: + """Produce report based on provided data.""" -class Reporter: +class TextReporter(Reporter): """Reporter class.""" def __init__( self, formatter_resolver: Callable[[Any], Callable[[Any], Report]], - output_format: OutputFormat = "plain_text", - print_as_submitted: bool = True, ) -> None: """Init reporter instance.""" - self.formatter_resolver = formatter_resolver - self.output_format = output_format - self.print_as_submitted = print_as_submitted - - self.data: list[tuple[Any, Callable[[Any], Report]]] = [] + super().__init__(formatter_resolver) self.delayed: list[tuple[Any, Callable[[Any], Report]]] = [] + self.output_format: OutputFormat = "plain_text" def submit(self, data_item: Any, delay_print: bool = False, **kwargs: Any) -> None: """Submit data for the report.""" - if self.print_as_submitted and not delay_print: - produce_report( - data_item, - self.formatter_resolver(data_item), - fmt="plain_text", - **kwargs, - ) - formatter = _apply_format_parameters( self.formatter_resolver(data_item), self.output_format, **kwargs ) @@ -582,51 +545,70 @@ class Reporter: if delay_print: self.delayed.append((data_item, formatter)) + else: + self.produce_report( + data_item, + self.formatter_resolver(data_item), + **kwargs, + ) def print_delayed(self) -> None: """Print delayed reports.""" - if not self.delayed: - return + if self.delayed: + data, formatters = zip(*self.delayed) + self.produce_report( + data, + formatter=CompoundFormatter(formatters), + ) + self.delayed = [] - data, formatters = zip(*self.delayed) - produce_report( - data, - formatter=CompoundFormatter(formatters), - fmt="plain_text", + def produce_report( + self, data: Any, formatter: Callable[[Any], Report], **kwargs: Any + ) -> None: + """Produce report based on provided data.""" + formatted_data = formatter(data) + logger.info(formatted_data.to_plain_text(**kwargs)) + + +class JSONReporter(Reporter): + """Reporter class.""" + + def __init__( + self, + formatter_resolver: Callable[[Any], Callable[[Any], Report]], + ) -> None: + """Init reporter instance.""" + super().__init__(formatter_resolver) + self.output_format: OutputFormat = "json" + + def submit(self, data_item: Any, **kwargs: Any) -> None: + """Submit data for the report.""" + formatter = _apply_format_parameters( + self.formatter_resolver(data_item), self.output_format, **kwargs ) - self.delayed = [] + self.data.append((data_item, formatter)) - def generate_report(self, output: PathOrFileLike | None) -> None: + def generate_report(self) -> None: """Generate report.""" - already_printed = ( - self.print_as_submitted - and self.output_format == "plain_text" - and output is None - ) - if not self.data or already_printed: + if not self.data: return data, formatters = zip(*self.data) - produce_report( + self.produce_report( data, formatter=CompoundFormatter(formatters), - fmt=self.output_format, - output=output, ) - -@contextmanager -def get_reporter( - output_format: OutputFormat, - output: PathOrFileLike | None, - formatter_resolver: Callable[[Any], Callable[[Any], Report]], -) -> Generator[Reporter, None, None]: - """Get reporter and generate report.""" - reporter = Reporter(formatter_resolver, output_format) - - yield reporter - - reporter.generate_report(output) + def produce_report( + self, data: Any, formatter: Callable[[Any], Report], **kwargs: Any + ) -> None: + """Produce report based on provided data.""" + formatted_data = formatter(data) + print( + json.dumps( + formatted_data.to_json(**kwargs), indent=4, cls=CustomJSONEncoder + ), + ) def _apply_format_parameters( |