From 0efca3cadbad5517a59884576ddb90cfe7ac30f8 Mon Sep 17 00:00:00 2001 From: Diego Russo Date: Mon, 30 May 2022 13:34:14 +0100 Subject: Add MLIA codebase Add MLIA codebase including sources and tests. Change-Id: Id41707559bd721edd114793618d12ccd188d8dbd --- src/mlia/utils/__init__.py | 3 + src/mlia/utils/console.py | 97 +++++++++++++++++++++++++ src/mlia/utils/download.py | 89 +++++++++++++++++++++++ src/mlia/utils/filesystem.py | 124 ++++++++++++++++++++++++++++++++ src/mlia/utils/logging.py | 120 +++++++++++++++++++++++++++++++ src/mlia/utils/misc.py | 9 +++ src/mlia/utils/proc.py | 164 +++++++++++++++++++++++++++++++++++++++++++ src/mlia/utils/types.py | 37 ++++++++++ 8 files changed, 643 insertions(+) create mode 100644 src/mlia/utils/__init__.py create mode 100644 src/mlia/utils/console.py create mode 100644 src/mlia/utils/download.py create mode 100644 src/mlia/utils/filesystem.py create mode 100644 src/mlia/utils/logging.py create mode 100644 src/mlia/utils/misc.py create mode 100644 src/mlia/utils/proc.py create mode 100644 src/mlia/utils/types.py (limited to 'src/mlia/utils') diff --git a/src/mlia/utils/__init__.py b/src/mlia/utils/__init__.py new file mode 100644 index 0000000..ecb5ca1 --- /dev/null +++ b/src/mlia/utils/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Utils module.""" diff --git a/src/mlia/utils/console.py b/src/mlia/utils/console.py new file mode 100644 index 0000000..7cb3d83 --- /dev/null +++ b/src/mlia/utils/console.py @@ -0,0 +1,97 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Console output utility functions.""" +from typing import Iterable +from typing import List +from typing import Optional + +from rich.console import Console +from rich.console import RenderableType +from rich.table import box +from rich.table import Table +from rich.text import Text + + +def create_section_header( + section_name: Optional[str] = None, length: int = 80, sep: str = "-" +) -> str: + """Return section header.""" + if not section_name: + content = sep * length + else: + before = 3 + spaces = 2 + after = length - (len(section_name) + before + spaces) + if after < 0: + raise ValueError("Section name too long") + content = f"{sep * before} {section_name} {sep * after}" + + return f"\n{content}\n" + + +def apply_style(value: str, style: str) -> str: + """Apply style to the value.""" + return f"[{style}]{value}" + + +def style_improvement(result: bool) -> str: + """Return different text style based on result.""" + return "green" if result else "yellow" + + +def produce_table( + rows: Iterable, + headers: Optional[List[str]] = None, + table_style: str = "default", +) -> str: + """Represent data in tabular form.""" + table = _get_table(table_style) + + if headers: + table.show_header = True + for header in headers: + table.add_column(header) + + for row in rows: + table.add_row(*row) + + return _convert_to_text(table) + + +def _get_table(table_style: str) -> Table: + """Get Table instance for the provided style.""" + if table_style == "default": + return Table( + show_header=False, + show_lines=True, + box=box.SQUARE_DOUBLE_HEAD, + ) + + if table_style == "nested": + return Table( + show_header=False, + box=None, + padding=(0, 1, 1, 0), + ) + + if table_style == "no_borders": + return Table(show_header=False, box=None) + + raise Exception(f"Unsupported table style {table_style}") + + +def _convert_to_text(*renderables: RenderableType) -> str: + """Convert renderable object to text.""" + console = Console() + with console.capture() as capture: + for item in renderables: + console.print(item) + + text = capture.get() + return text.rstrip() + + +def remove_ascii_codes(value: str) -> str: + """Decode and remove ASCII codes.""" + text = Text.from_ansi(value) + return text.plain diff --git a/src/mlia/utils/download.py b/src/mlia/utils/download.py new file mode 100644 index 0000000..4658738 --- /dev/null +++ b/src/mlia/utils/download.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Utils for files downloading.""" +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable +from typing import List +from typing import Optional + +import requests +from rich.progress import BarColumn +from rich.progress import DownloadColumn +from rich.progress import FileSizeColumn +from rich.progress import Progress +from rich.progress import ProgressColumn +from rich.progress import TextColumn + +from mlia.utils.filesystem import sha256 +from mlia.utils.types import parse_int + + +def download_progress( + content_chunks: Iterable[bytes], content_length: Optional[int], label: Optional[str] +) -> Iterable[bytes]: + """Show progress info while reading content.""" + columns: List[ProgressColumn] = [TextColumn("{task.description}")] + + if content_length is None: + total = float("inf") + columns.append(FileSizeColumn()) + else: + total = content_length + columns.extend([BarColumn(), DownloadColumn(binary_units=True)]) + + with Progress(*columns) as progress: + task = progress.add_task(label or "Downloading", total=total) + + for chunk in content_chunks: + progress.update(task, advance=len(chunk)) + yield chunk + + +def download( + url: str, + dest: Path, + show_progress: bool = False, + label: Optional[str] = None, + chunk_size: int = 8192, +) -> None: + """Download the file.""" + with requests.get(url, stream=True) as resp: + resp.raise_for_status() + content_chunks = resp.iter_content(chunk_size=chunk_size) + + if show_progress: + content_length = parse_int(resp.headers.get("Content-Length")) + content_chunks = download_progress(content_chunks, content_length, label) + + with open(dest, "wb") as file: + for chunk in content_chunks: + file.write(chunk) + + +@dataclass +class DownloadArtifact: + """Download artifact attributes.""" + + name: str + url: str + filename: str + version: str + sha256_hash: str + + def download_to(self, dest_dir: Path, show_progress: bool = True) -> Path: + """Download artifact into destination directory.""" + if (dest := dest_dir / self.filename).exists(): + raise ValueError(f"{dest} already exists") + + download( + self.url, + dest, + show_progress=show_progress, + label=f"Downloading {self.name} ver. {self.version}", + ) + + if sha256(dest) != self.sha256_hash: + raise ValueError("Digests do not match") + + return dest diff --git a/src/mlia/utils/filesystem.py b/src/mlia/utils/filesystem.py new file mode 100644 index 0000000..73a88d9 --- /dev/null +++ b/src/mlia/utils/filesystem.py @@ -0,0 +1,124 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Utils related to file management.""" +import hashlib +import importlib.resources as pkg_resources +import json +import os +import shutil +from contextlib import contextmanager +from pathlib import Path +from tempfile import mkstemp +from tempfile import TemporaryDirectory +from typing import Any +from typing import Dict +from typing import Generator +from typing import Iterable +from typing import List +from typing import Optional +from typing import Union + + +def get_mlia_resources() -> Path: + """Get the path to the resources directory.""" + with pkg_resources.path("mlia", "__init__.py") as init_path: + project_root = init_path.parent + return project_root / "resources" + + +def get_vela_config() -> Path: + """Get the path to the default Vela config file.""" + return get_mlia_resources() / "vela/vela.ini" + + +def get_profiles_file() -> Path: + """Get the Ethos-U profiles file.""" + return get_mlia_resources() / "profiles.json" + + +def get_profiles_data() -> Dict[str, Dict[str, Any]]: + """Get the Ethos-U profile values as a dictionary.""" + with open(get_profiles_file(), encoding="utf-8") as json_file: + profiles = json.load(json_file) + + if not isinstance(profiles, dict): + raise Exception("Profiles data format is not valid") + + return profiles + + +def get_profile(target: str) -> Dict[str, Any]: + """Get settings for the provided target profile.""" + profiles = get_profiles_data() + + if target not in profiles: + raise Exception(f"Unable to find target profile {target}") + + return profiles[target] + + +def get_supported_profile_names() -> List[str]: + """Get the supported Ethos-U profile names.""" + return list(get_profiles_data().keys()) + + +@contextmanager +def temp_file(suffix: Optional[str] = None) -> Generator[Path, None, None]: + """Create temp file and remove it after.""" + _, tmp_file = mkstemp(suffix=suffix) + + try: + yield Path(tmp_file) + finally: + os.remove(tmp_file) + + +@contextmanager +def temp_directory(suffix: Optional[str] = None) -> Generator[Path, None, None]: + """Create temp directory and remove it after.""" + with TemporaryDirectory(suffix=suffix) as tmpdir: + yield Path(tmpdir) + + +def file_chunks( + filepath: Union[Path, str], chunk_size: int = 4096 +) -> Generator[bytes, None, None]: + """Return sequence of the file chunks.""" + with open(filepath, "rb") as file: + while data := file.read(chunk_size): + yield data + + +def hexdigest(filepath: Union[Path, str], hash_obj: "hashlib._Hash") -> str: + """Return hex digest of the file.""" + for chunk in file_chunks(filepath): + hash_obj.update(chunk) + + return hash_obj.hexdigest() + + +def sha256(filepath: Path) -> str: + """Return SHA256 hash of the file.""" + return hexdigest(filepath, hashlib.sha256()) + + +def all_files_exist(paths: Iterable[Path]) -> bool: + """Check if all files are exist.""" + return all(item.is_file() for item in paths) + + +def all_paths_valid(paths: Iterable[Path]) -> bool: + """Check if all paths are valid.""" + return all(item.exists() for item in paths) + + +def copy_all(*paths: Path, dest: Path) -> None: + """Copy files/directories into destination folder.""" + dest.mkdir(exist_ok=True) + + for path in paths: + if path.is_file(): + shutil.copy2(path, dest) + + if path.is_dir(): + shutil.copytree(path, dest, dirs_exist_ok=True) diff --git a/src/mlia/utils/logging.py b/src/mlia/utils/logging.py new file mode 100644 index 0000000..86d7567 --- /dev/null +++ b/src/mlia/utils/logging.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Logging utility functions.""" +import logging +from contextlib import contextmanager +from contextlib import ExitStack +from contextlib import redirect_stderr +from contextlib import redirect_stdout +from pathlib import Path +from typing import Any +from typing import Callable +from typing import Generator +from typing import List +from typing import Optional + + +class LoggerWriter: + """Redirect printed messages to the logger.""" + + def __init__(self, logger: logging.Logger, level: int): + """Init logger writer.""" + self.logger = logger + self.level = level + + def write(self, message: str) -> None: + """Write message.""" + if message.strip() != "": + self.logger.log(self.level, message) + + def flush(self) -> None: + """Flush buffers.""" + + +@contextmanager +def redirect_output( + logger: logging.Logger, + stdout_level: int = logging.INFO, + stderr_level: int = logging.INFO, +) -> Generator[None, None, None]: + """Redirect standard output to the logger.""" + stdout_to_log = LoggerWriter(logger, stdout_level) + stderr_to_log = LoggerWriter(logger, stderr_level) + + with ExitStack() as exit_stack: + exit_stack.enter_context(redirect_stdout(stdout_to_log)) # type: ignore + exit_stack.enter_context(redirect_stderr(stderr_to_log)) # type: ignore + + yield + + +class LogFilter(logging.Filter): + """Configurable log filter.""" + + def __init__(self, log_record_filter: Callable[[logging.LogRecord], bool]) -> None: + """Init log filter instance.""" + super().__init__() + self.log_record_filter = log_record_filter + + def filter(self, record: logging.LogRecord) -> bool: + """Filter log messages.""" + return self.log_record_filter(record) + + @classmethod + def equals(cls, log_level: int) -> "LogFilter": + """Return log filter that filters messages by log level.""" + + def filter_by_level(log_record: logging.LogRecord) -> bool: + return log_record.levelno == log_level + + return cls(filter_by_level) + + @classmethod + def skip(cls, log_level: int) -> "LogFilter": + """Return log filter that skips messages with particular level.""" + + def skip_by_level(log_record: logging.LogRecord) -> bool: + return log_record.levelno != log_level + + return cls(skip_by_level) + + +def create_log_handler( + *, + file_path: Optional[Path] = None, + stream: Optional[Any] = None, + log_level: Optional[int] = None, + log_format: Optional[str] = None, + log_filter: Optional[logging.Filter] = None, + delay: bool = True, +) -> logging.Handler: + """Create logger handler.""" + handler: Optional[logging.Handler] = None + + if file_path is not None: + handler = logging.FileHandler(file_path, delay=delay) + elif stream is not None: + handler = logging.StreamHandler(stream) + + if handler is None: + raise Exception("Unable to create logging handler") + + if log_level: + handler.setLevel(log_level) + + if log_format: + handler.setFormatter(logging.Formatter(log_format)) + + if log_filter: + handler.addFilter(log_filter) + + return handler + + +def attach_handlers( + handlers: List[logging.Handler], loggers: List[logging.Logger] +) -> None: + """Attach handlers to the loggers.""" + for handler in handlers: + for logger in loggers: + logger.addHandler(handler) diff --git a/src/mlia/utils/misc.py b/src/mlia/utils/misc.py new file mode 100644 index 0000000..de95448 --- /dev/null +++ b/src/mlia/utils/misc.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Various util functions.""" + + +def yes(prompt: str) -> bool: + """Return true if user confirms the action.""" + response = input(f"{prompt} [y/n]: ") + return response in ["y", "Y"] diff --git a/src/mlia/utils/proc.py b/src/mlia/utils/proc.py new file mode 100644 index 0000000..39aca43 --- /dev/null +++ b/src/mlia/utils/proc.py @@ -0,0 +1,164 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Utils related to process management.""" +import os +import signal +import subprocess +import time +from abc import ABC +from abc import abstractmethod +from contextlib import contextmanager +from contextlib import suppress +from pathlib import Path +from typing import Any +from typing import Generator +from typing import Iterable +from typing import List +from typing import Optional +from typing import Tuple + + +class OutputConsumer(ABC): + """Base class for the output consumers.""" + + @abstractmethod + def feed(self, line: str) -> None: + """Feed new line to the consumerr.""" + + +class RunningCommand: + """Running command.""" + + def __init__(self, process: subprocess.Popen) -> None: + """Init running command instance.""" + self.process = process + self._output_consumers: Optional[List[OutputConsumer]] = None + + def is_alive(self) -> bool: + """Return true if process is still alive.""" + return self.process.poll() is None + + def exit_code(self) -> Optional[int]: + """Return process's return code.""" + return self.process.poll() + + def stdout(self) -> Iterable[str]: + """Return std output of the process.""" + assert self.process.stdout is not None + + for line in self.process.stdout: + yield line + + def kill(self) -> None: + """Kill the process.""" + self.process.kill() + + def send_signal(self, signal_num: int) -> None: + """Send signal to the process.""" + self.process.send_signal(signal_num) + + @property + def output_consumers(self) -> Optional[List[OutputConsumer]]: + """Property output_consumers.""" + return self._output_consumers + + @output_consumers.setter + def output_consumers(self, output_consumers: List[OutputConsumer]) -> None: + """Set output consumers.""" + self._output_consumers = output_consumers + + def consume_output(self) -> None: + """Pass program's output to the consumers.""" + if self.process is None or self.output_consumers is None: + return + + for line in self.stdout(): + for consumer in self.output_consumers: + with suppress(): + consumer.feed(line) + + def stop( + self, wait: bool = True, num_of_attempts: int = 5, interval: float = 0.5 + ) -> None: + """Stop execution.""" + try: + if not self.is_alive(): + return + + self.process.send_signal(signal.SIGINT) + self.consume_output() + + if not wait: + return + + for _ in range(num_of_attempts): + time.sleep(interval) + if not self.is_alive(): + break + else: + raise Exception("Unable to stop running command") + finally: + self._close_fd() + + def _close_fd(self) -> None: + """Close file descriptors.""" + + def close(file_descriptor: Any) -> None: + """Check and close file.""" + if file_descriptor is not None and hasattr(file_descriptor, "close"): + file_descriptor.close() + + close(self.process.stdout) + close(self.process.stderr) + + def wait(self, redirect_output: bool = False) -> None: + """Redirect process output to stdout and wait for completion.""" + if redirect_output: + for line in self.stdout(): + print(line, end="") + + self.process.wait() + + +class CommandExecutor: + """Command executor.""" + + @staticmethod + def execute(command: List[str]) -> Tuple[int, bytes, bytes]: + """Execute the command.""" + result = subprocess.run( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True + ) + + return (result.returncode, result.stdout, result.stderr) + + @staticmethod + def submit(command: List[str]) -> RunningCommand: + """Submit command for the execution.""" + process = subprocess.Popen( # pylint: disable=consider-using-with + command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, # redirect command stderr to stdout + universal_newlines=True, + bufsize=1, + ) + + return RunningCommand(process) + + +@contextmanager +def working_directory( + working_dir: Path, create_dir: bool = False +) -> Generator[Path, None, None]: + """Temporary change working directory.""" + current_working_dir = Path.cwd() + + if create_dir: + working_dir.mkdir() + + os.chdir(working_dir) + + try: + yield working_dir + finally: + os.chdir(current_working_dir) diff --git a/src/mlia/utils/types.py b/src/mlia/utils/types.py new file mode 100644 index 0000000..9b63928 --- /dev/null +++ b/src/mlia/utils/types.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Types related utility functions.""" +from typing import Any +from typing import Optional + + +def is_list_of(data: Any, cls: type, elem_num: Optional[int] = None) -> bool: + """Check if data is a list of object of the same class.""" + return ( + isinstance(data, (tuple, list)) + and all(isinstance(item, cls) for item in data) + and (elem_num is None or len(data) == elem_num) + ) + + +def is_number(value: str) -> bool: + """Return true if string contains a number.""" + try: + float(value) + except ValueError: + return False + + return True + + +def parse_int(value: Any, default: Optional[int] = None) -> Optional[int]: + """Parse integer value.""" + try: + return int(value) + except (TypeError, ValueError): + return default + + +def only_one_selected(*options: bool) -> bool: + """Return true if only one True value found.""" + return sum(options) == 1 -- cgit v1.2.1