aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/utils
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2022-10-24 15:08:08 +0100
committerDmitrii Agibov <dmitrii.agibov@arm.com>2022-10-26 17:08:13 +0100
commit58a65fee574c00329cf92b387a6d2513dcbf6100 (patch)
tree47e3185f78b4298ab029785ddee68456e44cac10 /src/mlia/utils
parent9d34cb72d45a6d0a2ec1063ebf32536c1efdba75 (diff)
downloadmlia-58a65fee574c00329cf92b387a6d2513dcbf6100.tar.gz
MLIA-433 Add TensorFlow Lite compatibility check
- Add ability to intercept low level TensorFlow output - Produce advice for the models that could not be converted to the TensorFlow Lite format - Refactor utility functions for TensorFlow Lite conversion - Add TensorFlow Lite compatibility checker Change-Id: I47d120d2619ced7b143bc92c5184515b81c0220d
Diffstat (limited to 'src/mlia/utils')
-rw-r--r--src/mlia/utils/logging.py60
1 files changed, 58 insertions, 2 deletions
diff --git a/src/mlia/utils/logging.py b/src/mlia/utils/logging.py
index 793500a..cf7ad27 100644
--- a/src/mlia/utils/logging.py
+++ b/src/mlia/utils/logging.py
@@ -4,6 +4,9 @@
from __future__ import annotations
import logging
+import os
+import sys
+import tempfile
from contextlib import contextmanager
from contextlib import ExitStack
from contextlib import redirect_stderr
@@ -12,6 +15,8 @@ from pathlib import Path
from typing import Any
from typing import Callable
from typing import Generator
+from typing import Iterable
+from typing import TextIO
class LoggerWriter:
@@ -35,7 +40,7 @@ class LoggerWriter:
def redirect_output(
logger: logging.Logger,
stdout_level: int = logging.INFO,
- stderr_level: int = logging.INFO,
+ stderr_level: int = logging.ERROR,
) -> Generator[None, None, None]:
"""Redirect standard output to the logger."""
stdout_to_log = LoggerWriter(logger, stdout_level)
@@ -48,6 +53,47 @@ def redirect_output(
yield
+@contextmanager
+def redirect_raw(
+ logger: logging.Logger, output: TextIO, log_level: int
+) -> Generator[None, None, None]:
+ """Redirect output using file descriptors."""
+ with tempfile.TemporaryFile(mode="r+") as tmp:
+ old_output_fd: int | None = None
+ try:
+ output_fd = output.fileno()
+ old_output_fd = os.dup(output_fd)
+ os.dup2(tmp.fileno(), output_fd)
+
+ yield
+ finally:
+ if old_output_fd is not None:
+ os.dup2(old_output_fd, output_fd)
+ os.close(old_output_fd)
+
+ tmp.seek(0)
+ for line in tmp.readlines():
+ logger.log(log_level, line.rstrip())
+
+
+@contextmanager
+def redirect_raw_output(
+ logger: logging.Logger,
+ stdout_level: int | None = logging.INFO,
+ stderr_level: int | None = logging.ERROR,
+) -> Generator[None, None, None]:
+ """Redirect output on the process level."""
+ with ExitStack() as exit_stack:
+ for level, output in [
+ (stdout_level, sys.stdout),
+ (stderr_level, sys.stderr),
+ ]:
+ if level is not None:
+ exit_stack.enter_context(redirect_raw(logger, output, level))
+
+ yield
+
+
class LogFilter(logging.Filter):
"""Configurable log filter."""
@@ -112,9 +158,19 @@ def create_log_handler(
def attach_handlers(
- handlers: list[logging.Handler], loggers: list[logging.Logger]
+ handlers: Iterable[logging.Handler], loggers: Iterable[logging.Logger]
) -> None:
"""Attach handlers to the loggers."""
for handler in handlers:
for logger in loggers:
logger.addHandler(handler)
+
+
+@contextmanager
+def log_action(action: str) -> Generator[None, None, None]:
+ """Log action."""
+ logger = logging.getLogger(__name__)
+
+ logger.info(action)
+ yield
+ logger.info("Done\n")