diff options
Diffstat (limited to 'src/mlia/utils/logging.py')
-rw-r--r-- | src/mlia/utils/logging.py | 60 |
1 files changed, 58 insertions, 2 deletions
diff --git a/src/mlia/utils/logging.py b/src/mlia/utils/logging.py index 793500a..cf7ad27 100644 --- a/src/mlia/utils/logging.py +++ b/src/mlia/utils/logging.py @@ -4,6 +4,9 @@ from __future__ import annotations import logging +import os +import sys +import tempfile from contextlib import contextmanager from contextlib import ExitStack from contextlib import redirect_stderr @@ -12,6 +15,8 @@ from pathlib import Path from typing import Any from typing import Callable from typing import Generator +from typing import Iterable +from typing import TextIO class LoggerWriter: @@ -35,7 +40,7 @@ class LoggerWriter: def redirect_output( logger: logging.Logger, stdout_level: int = logging.INFO, - stderr_level: int = logging.INFO, + stderr_level: int = logging.ERROR, ) -> Generator[None, None, None]: """Redirect standard output to the logger.""" stdout_to_log = LoggerWriter(logger, stdout_level) @@ -48,6 +53,47 @@ def redirect_output( yield +@contextmanager +def redirect_raw( + logger: logging.Logger, output: TextIO, log_level: int +) -> Generator[None, None, None]: + """Redirect output using file descriptors.""" + with tempfile.TemporaryFile(mode="r+") as tmp: + old_output_fd: int | None = None + try: + output_fd = output.fileno() + old_output_fd = os.dup(output_fd) + os.dup2(tmp.fileno(), output_fd) + + yield + finally: + if old_output_fd is not None: + os.dup2(old_output_fd, output_fd) + os.close(old_output_fd) + + tmp.seek(0) + for line in tmp.readlines(): + logger.log(log_level, line.rstrip()) + + +@contextmanager +def redirect_raw_output( + logger: logging.Logger, + stdout_level: int | None = logging.INFO, + stderr_level: int | None = logging.ERROR, +) -> Generator[None, None, None]: + """Redirect output on the process level.""" + with ExitStack() as exit_stack: + for level, output in [ + (stdout_level, sys.stdout), + (stderr_level, sys.stderr), + ]: + if level is not None: + exit_stack.enter_context(redirect_raw(logger, output, level)) + + yield + + class LogFilter(logging.Filter): """Configurable log filter.""" @@ -112,9 +158,19 @@ def create_log_handler( def attach_handlers( - handlers: list[logging.Handler], loggers: list[logging.Logger] + handlers: Iterable[logging.Handler], loggers: Iterable[logging.Logger] ) -> None: """Attach handlers to the loggers.""" for handler in handlers: for logger in loggers: logger.addHandler(handler) + + +@contextmanager +def log_action(action: str) -> Generator[None, None, None]: + """Log action.""" + logger = logging.getLogger(__name__) + + logger.info(action) + yield + logger.info("Done\n") |