aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/utils/logging.py
blob: 17f6caedbd4c2a373916fb7b10969141648e766b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Logging utility functions."""
from __future__ import annotations

import logging
import os
import sys
import tempfile
from contextlib import contextmanager
from contextlib import ExitStack
from contextlib import redirect_stderr
from contextlib import redirect_stdout
from pathlib import Path
from typing import Any
from typing import Callable
from typing import Generator
from typing import Iterable
from typing import TextIO

from mlia.utils.console import remove_ascii_codes


class LoggerWriter:
    """Redirect printed messages to the logger."""

    def __init__(self, logger: logging.Logger, level: int):
        """Init logger writer."""
        self.logger = logger
        self.level = level

    def write(self, message: str) -> None:
        """Write message."""
        if message.strip() != "":
            self.logger.log(self.level, message)

    def flush(self) -> None:
        """Flush buffers."""


@contextmanager
def redirect_output(
    logger: logging.Logger,
    stdout_level: int = logging.INFO,
    stderr_level: int = logging.ERROR,
) -> Generator[None, None, None]:
    """Redirect standard output to the logger."""
    stdout_to_log = LoggerWriter(logger, stdout_level)
    stderr_to_log = LoggerWriter(logger, stderr_level)

    with ExitStack() as exit_stack:
        exit_stack.enter_context(redirect_stdout(stdout_to_log))  # type: ignore
        exit_stack.enter_context(redirect_stderr(stderr_to_log))  # type: ignore

        yield


@contextmanager
def process_raw_output(
    consumer: Callable[[str], None], output: TextIO
) -> Generator[None, None, None]:
    """Process output on file descriptor level."""
    with tempfile.TemporaryFile(mode="r+") as tmp:
        old_output_fd: int | None = None
        try:
            output_fd = output.fileno()
            old_output_fd = os.dup(output_fd)
            os.dup2(tmp.fileno(), output_fd)

            yield
        finally:
            if old_output_fd is not None:
                os.dup2(old_output_fd, output_fd)
                os.close(old_output_fd)

            tmp.seek(0)
            for line in tmp.readlines():
                consumer(line)


@contextmanager
def redirect_raw(
    logger: logging.Logger, output: TextIO, log_level: int
) -> Generator[None, None, None]:
    """Redirect output using file descriptors."""

    def consumer(line: str) -> None:
        """Redirect output to the logger."""
        logger.log(log_level, line.rstrip())

    with process_raw_output(consumer, output):
        yield


@contextmanager
def redirect_raw_output(
    logger: logging.Logger,
    stdout_level: int | None = logging.INFO,
    stderr_level: int | None = logging.ERROR,
) -> Generator[None, None, None]:
    """Redirect output on the process level."""
    with ExitStack() as exit_stack:
        for level, output in [
            (stdout_level, sys.stdout),
            (stderr_level, sys.stderr),
        ]:
            if level is not None:
                exit_stack.enter_context(redirect_raw(logger, output, level))

        yield


@contextmanager
def capture_raw_output(output: TextIO) -> Generator[list[str], None, None]:
    """Capture output as list of strings."""
    result: list[str] = []

    def consumer(line: str) -> None:
        """Save output for later processing."""
        result.append(line)

    with process_raw_output(consumer, output):
        yield result


class LogFilter(logging.Filter):
    """Configurable log filter."""

    def __init__(self, log_record_filter: Callable[[logging.LogRecord], bool]) -> None:
        """Init log filter instance."""
        super().__init__()
        self.log_record_filter = log_record_filter

    def filter(self, record: logging.LogRecord) -> bool:
        """Filter log messages."""
        return self.log_record_filter(record)

    @classmethod
    def equals(cls, log_level: int) -> LogFilter:
        """Return log filter that filters messages by log level."""

        def filter_by_level(log_record: logging.LogRecord) -> bool:
            return log_record.levelno == log_level

        return cls(filter_by_level)

    @classmethod
    def skip(cls, log_level: int) -> LogFilter:
        """Return log filter that skips messages with particular level."""

        def skip_by_level(log_record: logging.LogRecord) -> bool:
            return log_record.levelno != log_level

        return cls(skip_by_level)


class NoASCIIFormatter(logging.Formatter):
    """Custom Formatter for logging into file."""

    def format(self, record: logging.LogRecord) -> str:
        """Overwrite format method to remove ascii codes from record."""
        result = super().format(record)
        return remove_ascii_codes(result)


def create_log_handler(
    *,
    file_path: Path | None = None,
    stream: Any | None = None,
    log_level: int | None = None,
    log_format: str | logging.Formatter | None = None,
    log_filter: logging.Filter | None = None,
    delay: bool = True,
) -> logging.Handler:
    """Create logger handler."""
    handler: logging.Handler | None = None

    if file_path is not None:
        handler = logging.FileHandler(file_path, delay=delay)
    elif stream is not None:
        handler = logging.StreamHandler(stream)

    if handler is None:
        raise Exception("Unable to create logging handler")

    if log_level:
        handler.setLevel(log_level)

    if log_format:
        if isinstance(log_format, str):
            log_format = logging.Formatter(log_format)
        handler.setFormatter(log_format)

    if log_filter:
        handler.addFilter(log_filter)

    return handler


def attach_handlers(
    handlers: Iterable[logging.Handler], loggers: Iterable[logging.Logger]
) -> None:
    """Attach handlers to the loggers."""
    for handler in handlers:
        for logger in loggers:
            logger.addHandler(handler)


@contextmanager
def log_action(action: str) -> Generator[None, None, None]:
    """Log action."""
    logger = logging.getLogger(__name__)

    logger.info(action)
    yield
    logger.info("Done\n")