From 58a65fee574c00329cf92b387a6d2513dcbf6100 Mon Sep 17 00:00:00 2001 From: Dmitrii Agibov Date: Mon, 24 Oct 2022 15:08:08 +0100 Subject: MLIA-433 Add TensorFlow Lite compatibility check - Add ability to intercept low level TensorFlow output - Produce advice for the models that could not be converted to the TensorFlow Lite format - Refactor utility functions for TensorFlow Lite conversion - Add TensorFlow Lite compatibility checker Change-Id: I47d120d2619ced7b143bc92c5184515b81c0220d --- src/mlia/utils/logging.py | 60 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 58 insertions(+), 2 deletions(-) (limited to 'src/mlia/utils') diff --git a/src/mlia/utils/logging.py b/src/mlia/utils/logging.py index 793500a..cf7ad27 100644 --- a/src/mlia/utils/logging.py +++ b/src/mlia/utils/logging.py @@ -4,6 +4,9 @@ from __future__ import annotations import logging +import os +import sys +import tempfile from contextlib import contextmanager from contextlib import ExitStack from contextlib import redirect_stderr @@ -12,6 +15,8 @@ from pathlib import Path from typing import Any from typing import Callable from typing import Generator +from typing import Iterable +from typing import TextIO class LoggerWriter: @@ -35,7 +40,7 @@ class LoggerWriter: def redirect_output( logger: logging.Logger, stdout_level: int = logging.INFO, - stderr_level: int = logging.INFO, + stderr_level: int = logging.ERROR, ) -> Generator[None, None, None]: """Redirect standard output to the logger.""" stdout_to_log = LoggerWriter(logger, stdout_level) @@ -48,6 +53,47 @@ def redirect_output( yield +@contextmanager +def redirect_raw( + logger: logging.Logger, output: TextIO, log_level: int +) -> Generator[None, None, None]: + """Redirect output using file descriptors.""" + with tempfile.TemporaryFile(mode="r+") as tmp: + old_output_fd: int | None = None + try: + output_fd = output.fileno() + old_output_fd = os.dup(output_fd) + os.dup2(tmp.fileno(), output_fd) + + yield + finally: + if old_output_fd is not None: + os.dup2(old_output_fd, output_fd) + os.close(old_output_fd) + + tmp.seek(0) + for line in tmp.readlines(): + logger.log(log_level, line.rstrip()) + + +@contextmanager +def redirect_raw_output( + logger: logging.Logger, + stdout_level: int | None = logging.INFO, + stderr_level: int | None = logging.ERROR, +) -> Generator[None, None, None]: + """Redirect output on the process level.""" + with ExitStack() as exit_stack: + for level, output in [ + (stdout_level, sys.stdout), + (stderr_level, sys.stderr), + ]: + if level is not None: + exit_stack.enter_context(redirect_raw(logger, output, level)) + + yield + + class LogFilter(logging.Filter): """Configurable log filter.""" @@ -112,9 +158,19 @@ def create_log_handler( def attach_handlers( - handlers: list[logging.Handler], loggers: list[logging.Logger] + handlers: Iterable[logging.Handler], loggers: Iterable[logging.Logger] ) -> None: """Attach handlers to the loggers.""" for handler in handlers: for logger in loggers: logger.addHandler(handler) + + +@contextmanager +def log_action(action: str) -> Generator[None, None, None]: + """Log action.""" + logger = logging.getLogger(__name__) + + logger.info(action) + yield + logger.info("Done\n") -- cgit v1.2.1