aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/utils/logging.py
blob: 86d756742959565c049c2621ffce6a4c2966f4dc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Logging utility functions."""
import logging
from contextlib import contextmanager
from contextlib import ExitStack
from contextlib import redirect_stderr
from contextlib import redirect_stdout
from pathlib import Path
from typing import Any
from typing import Callable
from typing import Generator
from typing import List
from typing import Optional


class LoggerWriter:
    """File-like adapter that forwards written text to a logger.

    An instance can be handed to ``contextlib.redirect_stdout`` /
    ``redirect_stderr`` so that anything printed ends up as log records.
    """

    def __init__(self, logger: logging.Logger, level: int):
        """Remember the target logger and the level used for every record."""
        self.logger = logger
        self.level = level

    def write(self, message: str) -> None:
        """Log *message* at the configured level unless it is only whitespace."""
        # print() emits the trailing newline as a separate write; skipping
        # whitespace-only chunks avoids empty log records.
        if not message.strip():
            return
        self.logger.log(self.level, message)

    def flush(self) -> None:
        """Do nothing: write() hands each message to the logger directly."""


@contextmanager
def redirect_output(
    logger: logging.Logger,
    stdout_level: int = logging.INFO,
    stderr_level: int = logging.INFO,
) -> Generator[None, None, None]:
    """Send everything written to stdout/stderr to *logger* while active.

    ``sys.stdout`` is replaced by a ``LoggerWriter`` logging at
    *stdout_level* and ``sys.stderr`` by one logging at *stderr_level*.
    Both streams are restored when the ``with`` block exits.
    """
    # Both writers are built up front; stdout is redirected before stderr,
    # so the ExitStack restores them in the reverse order.
    redirections = (
        (redirect_stdout, LoggerWriter(logger, stdout_level)),
        (redirect_stderr, LoggerWriter(logger, stderr_level)),
    )

    with ExitStack() as stack:
        for redirector, writer in redirections:
            # LoggerWriter is duck-typed, not a TextIO, hence the ignore.
            stack.enter_context(redirector(writer))  # type: ignore
        yield


class LogFilter(logging.Filter):
    """Logging filter whose accept/reject decision comes from a predicate."""

    def __init__(self, log_record_filter: Callable[[logging.LogRecord], bool]) -> None:
        """Wrap *log_record_filter*, a callable returning True to keep a record."""
        super().__init__()
        self.log_record_filter = log_record_filter

    def filter(self, record: logging.LogRecord) -> bool:
        """Return the wrapped predicate's verdict for *record*."""
        predicate = self.log_record_filter
        return predicate(record)

    @classmethod
    def equals(cls, log_level: int) -> "LogFilter":
        """Build a filter that keeps only records at exactly *log_level*."""
        return cls(lambda log_record: log_record.levelno == log_level)

    @classmethod
    def skip(cls, log_level: int) -> "LogFilter":
        """Build a filter that drops records at exactly *log_level*."""
        return cls(lambda log_record: log_record.levelno != log_level)


def create_log_handler(
    *,
    file_path: Optional[Path] = None,
    stream: Optional[Any] = None,
    log_level: Optional[int] = None,
    log_format: Optional[str] = None,
    log_filter: Optional[logging.Filter] = None,
    delay: bool = True,
) -> logging.Handler:
    """Create and configure a logging handler.

    The destination is chosen as follows: if *file_path* is given, a
    ``logging.FileHandler`` writing to it is created and *stream* is
    ignored. Otherwise *stream* is wrapped in a ``logging.StreamHandler``.

    Args:
        file_path: Log file to write to (takes precedence over *stream*).
        stream: File-like object for a stream handler.
        log_level: Handler level; left at the default when None.
        log_format: ``logging.Formatter`` format string; an empty or None
            value leaves the handler without a formatter.
        log_filter: Filter to attach to the handler, if any.
        delay: Forwarded to ``FileHandler``; when True the file is not opened
            until the first record is emitted.

    Returns:
        The configured handler.

    Raises:
        ValueError: If neither *file_path* nor *stream* is provided.
            (A subclass of ``Exception``, so existing broad handlers still
            catch it.)
    """
    handler: logging.Handler
    if file_path is not None:
        handler = logging.FileHandler(file_path, delay=delay)
    elif stream is not None:
        handler = logging.StreamHandler(stream)
    else:
        raise ValueError("Unable to create logging handler")

    if log_level is not None:
        handler.setLevel(log_level)

    # Truthiness on purpose: an empty format string means "no formatter".
    if log_format:
        handler.setFormatter(logging.Formatter(log_format))

    if log_filter is not None:
        handler.addFilter(log_filter)

    return handler


def attach_handlers(
    handlers: List[logging.Handler], loggers: List[logging.Logger]
) -> None:
    """Add every handler in *handlers* to every logger in *loggers*.

    Each logger receives the handlers in the order they appear in
    *handlers*.
    """
    pairings = (
        (target, new_handler) for new_handler in handlers for target in loggers
    )
    for target, new_handler in pairings:
        target.addHandler(new_handler)