diff options
Diffstat (limited to 'src/mlia/utils/logging.py')
-rw-r--r-- | src/mlia/utils/logging.py | 120 |
1 files changed, 120 insertions, 0 deletions
diff --git a/src/mlia/utils/logging.py b/src/mlia/utils/logging.py new file mode 100644 index 0000000..86d7567 --- /dev/null +++ b/src/mlia/utils/logging.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Logging utility functions.""" +import logging +from contextlib import contextmanager +from contextlib import ExitStack +from contextlib import redirect_stderr +from contextlib import redirect_stdout +from pathlib import Path +from typing import Any +from typing import Callable +from typing import Generator +from typing import List +from typing import Optional + + +class LoggerWriter: + """Redirect printed messages to the logger.""" + + def __init__(self, logger: logging.Logger, level: int): + """Init logger writer.""" + self.logger = logger + self.level = level + + def write(self, message: str) -> None: + """Write message.""" + if message.strip() != "": + self.logger.log(self.level, message) + + def flush(self) -> None: + """Flush buffers.""" + + +@contextmanager +def redirect_output( + logger: logging.Logger, + stdout_level: int = logging.INFO, + stderr_level: int = logging.INFO, +) -> Generator[None, None, None]: + """Redirect standard output to the logger.""" + stdout_to_log = LoggerWriter(logger, stdout_level) + stderr_to_log = LoggerWriter(logger, stderr_level) + + with ExitStack() as exit_stack: + exit_stack.enter_context(redirect_stdout(stdout_to_log)) # type: ignore + exit_stack.enter_context(redirect_stderr(stderr_to_log)) # type: ignore + + yield + + +class LogFilter(logging.Filter): + """Configurable log filter.""" + + def __init__(self, log_record_filter: Callable[[logging.LogRecord], bool]) -> None: + """Init log filter instance.""" + super().__init__() + self.log_record_filter = log_record_filter + + def filter(self, record: logging.LogRecord) -> bool: + """Filter log messages.""" + return self.log_record_filter(record) + + @classmethod + def equals(cls, log_level: int) -> "LogFilter": + """Return log filter that filters messages by log level.""" + + def filter_by_level(log_record: logging.LogRecord) -> bool: + return log_record.levelno == log_level + + return cls(filter_by_level) + + @classmethod + def skip(cls, log_level: int) -> "LogFilter": + """Return log filter that skips messages with particular level.""" + + def skip_by_level(log_record: logging.LogRecord) -> bool: + return log_record.levelno != log_level + + return cls(skip_by_level) + + +def create_log_handler( + *, + file_path: Optional[Path] = None, + stream: Optional[Any] = None, + log_level: Optional[int] = None, + log_format: Optional[str] = None, + log_filter: Optional[logging.Filter] = None, + delay: bool = True, +) -> logging.Handler: + """Create logger handler.""" + handler: Optional[logging.Handler] = None + + if file_path is not None: + handler = logging.FileHandler(file_path, delay=delay) + elif stream is not None: + handler = logging.StreamHandler(stream) + + if handler is None: + raise Exception("Unable to create logging handler") + + if log_level: + handler.setLevel(log_level) + + if log_format: + handler.setFormatter(logging.Formatter(log_format)) + + if log_filter: + handler.addFilter(log_filter) + + return handler + + +def attach_handlers( + handlers: List[logging.Handler], loggers: List[logging.Logger] +) -> None: + """Attach handlers to the loggers.""" + for handler in handlers: + for logger in loggers: + logger.addHandler(handler) |