aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/utils
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/utils')
-rw-r--r--src/mlia/utils/__init__.py3
-rw-r--r--src/mlia/utils/console.py97
-rw-r--r--src/mlia/utils/download.py89
-rw-r--r--src/mlia/utils/filesystem.py124
-rw-r--r--src/mlia/utils/logging.py120
-rw-r--r--src/mlia/utils/misc.py9
-rw-r--r--src/mlia/utils/proc.py164
-rw-r--r--src/mlia/utils/types.py37
8 files changed, 643 insertions, 0 deletions
diff --git a/src/mlia/utils/__init__.py b/src/mlia/utils/__init__.py
new file mode 100644
index 0000000..ecb5ca1
--- /dev/null
+++ b/src/mlia/utils/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Utils module."""
diff --git a/src/mlia/utils/console.py b/src/mlia/utils/console.py
new file mode 100644
index 0000000..7cb3d83
--- /dev/null
+++ b/src/mlia/utils/console.py
@@ -0,0 +1,97 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Console output utility functions."""
+from typing import Iterable
+from typing import List
+from typing import Optional
+
+from rich.console import Console
+from rich.console import RenderableType
+from rich.table import box
+from rich.table import Table
+from rich.text import Text
+
+
+def create_section_header(
+ section_name: Optional[str] = None, length: int = 80, sep: str = "-"
+) -> str:
+ """Return section header."""
+ if not section_name:
+ content = sep * length
+ else:
+ before = 3
+ spaces = 2
+ after = length - (len(section_name) + before + spaces)
+ if after < 0:
+ raise ValueError("Section name too long")
+ content = f"{sep * before} {section_name} {sep * after}"
+
+ return f"\n{content}\n"
+
+
+def apply_style(value: str, style: str) -> str:
+ """Apply style to the value."""
+ return f"[{style}]{value}"
+
+
+def style_improvement(result: bool) -> str:
+ """Return different text style based on result."""
+ return "green" if result else "yellow"
+
+
+def produce_table(
+ rows: Iterable,
+ headers: Optional[List[str]] = None,
+ table_style: str = "default",
+) -> str:
+ """Represent data in tabular form."""
+ table = _get_table(table_style)
+
+ if headers:
+ table.show_header = True
+ for header in headers:
+ table.add_column(header)
+
+ for row in rows:
+ table.add_row(*row)
+
+ return _convert_to_text(table)
+
+
+def _get_table(table_style: str) -> Table:
+ """Get Table instance for the provided style."""
+ if table_style == "default":
+ return Table(
+ show_header=False,
+ show_lines=True,
+ box=box.SQUARE_DOUBLE_HEAD,
+ )
+
+ if table_style == "nested":
+ return Table(
+ show_header=False,
+ box=None,
+ padding=(0, 1, 1, 0),
+ )
+
+ if table_style == "no_borders":
+ return Table(show_header=False, box=None)
+
+ raise Exception(f"Unsupported table style {table_style}")
+
+
+def _convert_to_text(*renderables: RenderableType) -> str:
+ """Convert renderable object to text."""
+ console = Console()
+ with console.capture() as capture:
+ for item in renderables:
+ console.print(item)
+
+ text = capture.get()
+ return text.rstrip()
+
+
+def remove_ascii_codes(value: str) -> str:
+ """Decode and remove ASCII codes."""
+ text = Text.from_ansi(value)
+ return text.plain
diff --git a/src/mlia/utils/download.py b/src/mlia/utils/download.py
new file mode 100644
index 0000000..4658738
--- /dev/null
+++ b/src/mlia/utils/download.py
@@ -0,0 +1,89 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Utils for files downloading."""
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Iterable
+from typing import List
+from typing import Optional
+
+import requests
+from rich.progress import BarColumn
+from rich.progress import DownloadColumn
+from rich.progress import FileSizeColumn
+from rich.progress import Progress
+from rich.progress import ProgressColumn
+from rich.progress import TextColumn
+
+from mlia.utils.filesystem import sha256
+from mlia.utils.types import parse_int
+
+
+def download_progress(
+ content_chunks: Iterable[bytes], content_length: Optional[int], label: Optional[str]
+) -> Iterable[bytes]:
+ """Show progress info while reading content."""
+ columns: List[ProgressColumn] = [TextColumn("{task.description}")]
+
+ if content_length is None:
+ total = float("inf")
+ columns.append(FileSizeColumn())
+ else:
+ total = content_length
+ columns.extend([BarColumn(), DownloadColumn(binary_units=True)])
+
+ with Progress(*columns) as progress:
+ task = progress.add_task(label or "Downloading", total=total)
+
+ for chunk in content_chunks:
+ progress.update(task, advance=len(chunk))
+ yield chunk
+
+
+def download(
+ url: str,
+ dest: Path,
+ show_progress: bool = False,
+ label: Optional[str] = None,
+ chunk_size: int = 8192,
+) -> None:
+ """Download the file."""
+ with requests.get(url, stream=True) as resp:
+ resp.raise_for_status()
+ content_chunks = resp.iter_content(chunk_size=chunk_size)
+
+ if show_progress:
+ content_length = parse_int(resp.headers.get("Content-Length"))
+ content_chunks = download_progress(content_chunks, content_length, label)
+
+ with open(dest, "wb") as file:
+ for chunk in content_chunks:
+ file.write(chunk)
+
+
+@dataclass
+class DownloadArtifact:
+ """Download artifact attributes."""
+
+ name: str
+ url: str
+ filename: str
+ version: str
+ sha256_hash: str
+
+ def download_to(self, dest_dir: Path, show_progress: bool = True) -> Path:
+ """Download artifact into destination directory."""
+ if (dest := dest_dir / self.filename).exists():
+ raise ValueError(f"{dest} already exists")
+
+ download(
+ self.url,
+ dest,
+ show_progress=show_progress,
+ label=f"Downloading {self.name} ver. {self.version}",
+ )
+
+ if sha256(dest) != self.sha256_hash:
+ raise ValueError("Digests do not match")
+
+ return dest
diff --git a/src/mlia/utils/filesystem.py b/src/mlia/utils/filesystem.py
new file mode 100644
index 0000000..73a88d9
--- /dev/null
+++ b/src/mlia/utils/filesystem.py
@@ -0,0 +1,124 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Utils related to file management."""
+import hashlib
+import importlib.resources as pkg_resources
+import json
+import os
+import shutil
+from contextlib import contextmanager
+from pathlib import Path
+from tempfile import mkstemp
+from tempfile import TemporaryDirectory
+from typing import Any
+from typing import Dict
+from typing import Generator
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Union
+
+
+def get_mlia_resources() -> Path:
+ """Get the path to the resources directory."""
+ with pkg_resources.path("mlia", "__init__.py") as init_path:
+ project_root = init_path.parent
+ return project_root / "resources"
+
+
+def get_vela_config() -> Path:
+ """Get the path to the default Vela config file."""
+ return get_mlia_resources() / "vela/vela.ini"
+
+
+def get_profiles_file() -> Path:
+ """Get the Ethos-U profiles file."""
+ return get_mlia_resources() / "profiles.json"
+
+
+def get_profiles_data() -> Dict[str, Dict[str, Any]]:
+ """Get the Ethos-U profile values as a dictionary."""
+ with open(get_profiles_file(), encoding="utf-8") as json_file:
+ profiles = json.load(json_file)
+
+ if not isinstance(profiles, dict):
+ raise Exception("Profiles data format is not valid")
+
+ return profiles
+
+
+def get_profile(target: str) -> Dict[str, Any]:
+ """Get settings for the provided target profile."""
+ profiles = get_profiles_data()
+
+ if target not in profiles:
+ raise Exception(f"Unable to find target profile {target}")
+
+ return profiles[target]
+
+
+def get_supported_profile_names() -> List[str]:
+ """Get the supported Ethos-U profile names."""
+ return list(get_profiles_data().keys())
+
+
+@contextmanager
+def temp_file(suffix: Optional[str] = None) -> Generator[Path, None, None]:
+ """Create temp file and remove it after."""
+ _, tmp_file = mkstemp(suffix=suffix)
+
+ try:
+ yield Path(tmp_file)
+ finally:
+ os.remove(tmp_file)
+
+
+@contextmanager
+def temp_directory(suffix: Optional[str] = None) -> Generator[Path, None, None]:
+ """Create temp directory and remove it after."""
+ with TemporaryDirectory(suffix=suffix) as tmpdir:
+ yield Path(tmpdir)
+
+
+def file_chunks(
+ filepath: Union[Path, str], chunk_size: int = 4096
+) -> Generator[bytes, None, None]:
+ """Return sequence of the file chunks."""
+ with open(filepath, "rb") as file:
+ while data := file.read(chunk_size):
+ yield data
+
+
+def hexdigest(filepath: Union[Path, str], hash_obj: "hashlib._Hash") -> str:
+ """Return hex digest of the file."""
+ for chunk in file_chunks(filepath):
+ hash_obj.update(chunk)
+
+ return hash_obj.hexdigest()
+
+
+def sha256(filepath: Path) -> str:
+ """Return SHA256 hash of the file."""
+ return hexdigest(filepath, hashlib.sha256())
+
+
+def all_files_exist(paths: Iterable[Path]) -> bool:
+ """Check if all files are exist."""
+ return all(item.is_file() for item in paths)
+
+
+def all_paths_valid(paths: Iterable[Path]) -> bool:
+ """Check if all paths are valid."""
+ return all(item.exists() for item in paths)
+
+
+def copy_all(*paths: Path, dest: Path) -> None:
+ """Copy files/directories into destination folder."""
+ dest.mkdir(exist_ok=True)
+
+ for path in paths:
+ if path.is_file():
+ shutil.copy2(path, dest)
+
+ if path.is_dir():
+ shutil.copytree(path, dest, dirs_exist_ok=True)
diff --git a/src/mlia/utils/logging.py b/src/mlia/utils/logging.py
new file mode 100644
index 0000000..86d7567
--- /dev/null
+++ b/src/mlia/utils/logging.py
@@ -0,0 +1,120 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Logging utility functions."""
+import logging
+from contextlib import contextmanager
+from contextlib import ExitStack
+from contextlib import redirect_stderr
+from contextlib import redirect_stdout
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import Generator
+from typing import List
+from typing import Optional
+
+
+class LoggerWriter:
+ """Redirect printed messages to the logger."""
+
+ def __init__(self, logger: logging.Logger, level: int):
+ """Init logger writer."""
+ self.logger = logger
+ self.level = level
+
+ def write(self, message: str) -> None:
+ """Write message."""
+ if message.strip() != "":
+ self.logger.log(self.level, message)
+
+ def flush(self) -> None:
+ """Flush buffers."""
+
+
+@contextmanager
+def redirect_output(
+ logger: logging.Logger,
+ stdout_level: int = logging.INFO,
+ stderr_level: int = logging.INFO,
+) -> Generator[None, None, None]:
+ """Redirect standard output to the logger."""
+ stdout_to_log = LoggerWriter(logger, stdout_level)
+ stderr_to_log = LoggerWriter(logger, stderr_level)
+
+ with ExitStack() as exit_stack:
+ exit_stack.enter_context(redirect_stdout(stdout_to_log)) # type: ignore
+ exit_stack.enter_context(redirect_stderr(stderr_to_log)) # type: ignore
+
+ yield
+
+
+class LogFilter(logging.Filter):
+ """Configurable log filter."""
+
+ def __init__(self, log_record_filter: Callable[[logging.LogRecord], bool]) -> None:
+ """Init log filter instance."""
+ super().__init__()
+ self.log_record_filter = log_record_filter
+
+ def filter(self, record: logging.LogRecord) -> bool:
+ """Filter log messages."""
+ return self.log_record_filter(record)
+
+ @classmethod
+ def equals(cls, log_level: int) -> "LogFilter":
+ """Return log filter that filters messages by log level."""
+
+ def filter_by_level(log_record: logging.LogRecord) -> bool:
+ return log_record.levelno == log_level
+
+ return cls(filter_by_level)
+
+ @classmethod
+ def skip(cls, log_level: int) -> "LogFilter":
+ """Return log filter that skips messages with particular level."""
+
+ def skip_by_level(log_record: logging.LogRecord) -> bool:
+ return log_record.levelno != log_level
+
+ return cls(skip_by_level)
+
+
+def create_log_handler(
+ *,
+ file_path: Optional[Path] = None,
+ stream: Optional[Any] = None,
+ log_level: Optional[int] = None,
+ log_format: Optional[str] = None,
+ log_filter: Optional[logging.Filter] = None,
+ delay: bool = True,
+) -> logging.Handler:
+ """Create logger handler."""
+ handler: Optional[logging.Handler] = None
+
+ if file_path is not None:
+ handler = logging.FileHandler(file_path, delay=delay)
+ elif stream is not None:
+ handler = logging.StreamHandler(stream)
+
+ if handler is None:
+ raise Exception("Unable to create logging handler")
+
+ if log_level:
+ handler.setLevel(log_level)
+
+ if log_format:
+ handler.setFormatter(logging.Formatter(log_format))
+
+ if log_filter:
+ handler.addFilter(log_filter)
+
+ return handler
+
+
+def attach_handlers(
+ handlers: List[logging.Handler], loggers: List[logging.Logger]
+) -> None:
+ """Attach handlers to the loggers."""
+ for handler in handlers:
+ for logger in loggers:
+ logger.addHandler(handler)
diff --git a/src/mlia/utils/misc.py b/src/mlia/utils/misc.py
new file mode 100644
index 0000000..de95448
--- /dev/null
+++ b/src/mlia/utils/misc.py
@@ -0,0 +1,9 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Various util functions."""
+
+
+def yes(prompt: str) -> bool:
+ """Return true if user confirms the action."""
+ response = input(f"{prompt} [y/n]: ")
+ return response in ["y", "Y"]
diff --git a/src/mlia/utils/proc.py b/src/mlia/utils/proc.py
new file mode 100644
index 0000000..39aca43
--- /dev/null
+++ b/src/mlia/utils/proc.py
@@ -0,0 +1,164 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Utils related to process management."""
+import os
+import signal
+import subprocess
+import time
+from abc import ABC
+from abc import abstractmethod
+from contextlib import contextmanager
+from contextlib import suppress
+from pathlib import Path
+from typing import Any
+from typing import Generator
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+
+class OutputConsumer(ABC):
+ """Base class for the output consumers."""
+
+ @abstractmethod
+ def feed(self, line: str) -> None:
+ """Feed new line to the consumerr."""
+
+
+class RunningCommand:
+ """Running command."""
+
+ def __init__(self, process: subprocess.Popen) -> None:
+ """Init running command instance."""
+ self.process = process
+ self._output_consumers: Optional[List[OutputConsumer]] = None
+
+ def is_alive(self) -> bool:
+ """Return true if process is still alive."""
+ return self.process.poll() is None
+
+ def exit_code(self) -> Optional[int]:
+ """Return process's return code."""
+ return self.process.poll()
+
+ def stdout(self) -> Iterable[str]:
+ """Return std output of the process."""
+ assert self.process.stdout is not None
+
+ for line in self.process.stdout:
+ yield line
+
+ def kill(self) -> None:
+ """Kill the process."""
+ self.process.kill()
+
+ def send_signal(self, signal_num: int) -> None:
+ """Send signal to the process."""
+ self.process.send_signal(signal_num)
+
+ @property
+ def output_consumers(self) -> Optional[List[OutputConsumer]]:
+ """Property output_consumers."""
+ return self._output_consumers
+
+ @output_consumers.setter
+ def output_consumers(self, output_consumers: List[OutputConsumer]) -> None:
+ """Set output consumers."""
+ self._output_consumers = output_consumers
+
+ def consume_output(self) -> None:
+ """Pass program's output to the consumers."""
+ if self.process is None or self.output_consumers is None:
+ return
+
+ for line in self.stdout():
+ for consumer in self.output_consumers:
+ with suppress():
+ consumer.feed(line)
+
+ def stop(
+ self, wait: bool = True, num_of_attempts: int = 5, interval: float = 0.5
+ ) -> None:
+ """Stop execution."""
+ try:
+ if not self.is_alive():
+ return
+
+ self.process.send_signal(signal.SIGINT)
+ self.consume_output()
+
+ if not wait:
+ return
+
+ for _ in range(num_of_attempts):
+ time.sleep(interval)
+ if not self.is_alive():
+ break
+ else:
+ raise Exception("Unable to stop running command")
+ finally:
+ self._close_fd()
+
+ def _close_fd(self) -> None:
+ """Close file descriptors."""
+
+ def close(file_descriptor: Any) -> None:
+ """Check and close file."""
+ if file_descriptor is not None and hasattr(file_descriptor, "close"):
+ file_descriptor.close()
+
+ close(self.process.stdout)
+ close(self.process.stderr)
+
+ def wait(self, redirect_output: bool = False) -> None:
+ """Redirect process output to stdout and wait for completion."""
+ if redirect_output:
+ for line in self.stdout():
+ print(line, end="")
+
+ self.process.wait()
+
+
+class CommandExecutor:
+ """Command executor."""
+
+ @staticmethod
+ def execute(command: List[str]) -> Tuple[int, bytes, bytes]:
+ """Execute the command."""
+ result = subprocess.run(
+ command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
+ )
+
+ return (result.returncode, result.stdout, result.stderr)
+
+ @staticmethod
+ def submit(command: List[str]) -> RunningCommand:
+ """Submit command for the execution."""
+ process = subprocess.Popen( # pylint: disable=consider-using-with
+ command,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT, # redirect command stderr to stdout
+ universal_newlines=True,
+ bufsize=1,
+ )
+
+ return RunningCommand(process)
+
+
+@contextmanager
+def working_directory(
+ working_dir: Path, create_dir: bool = False
+) -> Generator[Path, None, None]:
+ """Temporary change working directory."""
+ current_working_dir = Path.cwd()
+
+ if create_dir:
+ working_dir.mkdir()
+
+ os.chdir(working_dir)
+
+ try:
+ yield working_dir
+ finally:
+ os.chdir(current_working_dir)
diff --git a/src/mlia/utils/types.py b/src/mlia/utils/types.py
new file mode 100644
index 0000000..9b63928
--- /dev/null
+++ b/src/mlia/utils/types.py
@@ -0,0 +1,37 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Types related utility functions."""
+from typing import Any
+from typing import Optional
+
+
+def is_list_of(data: Any, cls: type, elem_num: Optional[int] = None) -> bool:
+ """Check if data is a list of object of the same class."""
+ return (
+ isinstance(data, (tuple, list))
+ and all(isinstance(item, cls) for item in data)
+ and (elem_num is None or len(data) == elem_num)
+ )
+
+
+def is_number(value: str) -> bool:
+ """Return true if string contains a number."""
+ try:
+ float(value)
+ except ValueError:
+ return False
+
+ return True
+
+
+def parse_int(value: Any, default: Optional[int] = None) -> Optional[int]:
+ """Parse integer value."""
+ try:
+ return int(value)
+ except (TypeError, ValueError):
+ return default
+
+
+def only_one_selected(*options: bool) -> bool:
+ """Return true if only one True value found."""
+ return sum(options) == 1