aboutsummaryrefslogtreecommitdiff
path: root/src/aiet/utils
diff options
context:
space:
mode:
Diffstat (limited to 'src/aiet/utils')
-rw-r--r--src/aiet/utils/__init__.py3
-rw-r--r--src/aiet/utils/fs.py116
-rw-r--r--src/aiet/utils/helpers.py17
-rw-r--r--src/aiet/utils/proc.py283
4 files changed, 419 insertions, 0 deletions
diff --git a/src/aiet/utils/__init__.py b/src/aiet/utils/__init__.py
new file mode 100644
index 0000000..fc7ef7c
--- /dev/null
+++ b/src/aiet/utils/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""This module contains all utils shared across aiet project."""
diff --git a/src/aiet/utils/fs.py b/src/aiet/utils/fs.py
new file mode 100644
index 0000000..ea99a69
--- /dev/null
+++ b/src/aiet/utils/fs.py
@@ -0,0 +1,116 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module to host all file system related functions."""
+import importlib.resources as pkg_resources
+import re
+import shutil
+from pathlib import Path
+from typing import Any
+from typing import Literal
+from typing import Optional
+
+ResourceType = Literal["applications", "systems", "tools"]
+
+
+def get_aiet_resources() -> Path:
+ """Get resources folder path."""
+ with pkg_resources.path("aiet", "__init__.py") as init_path:
+ project_root = init_path.parent
+ return project_root / "resources"
+
+
+def get_resources(name: ResourceType) -> Path:
+ """Return the absolute path of the specified resource.
+
+ It uses importlib to return resources packaged with MANIFEST.in.
+ """
+ if not name:
+ raise ResourceWarning("Resource name is not provided")
+
+ resource_path = get_aiet_resources() / name
+ if resource_path.is_dir():
+ return resource_path
+
+ raise ResourceWarning("Resource '{}' not found.".format(name))
+
+
+def copy_directory_content(source: Path, destination: Path) -> None:
+ """Copy content of the source directory into destination directory."""
+ for item in source.iterdir():
+ src = source / item.name
+ dest = destination / item.name
+
+ if src.is_dir():
+ shutil.copytree(src, dest)
+ else:
+ shutil.copy2(src, dest)
+
+
+def remove_resource(resource_directory: str, resource_type: ResourceType) -> None:
+ """Remove resource data."""
+ resources = get_resources(resource_type)
+
+ resource_location = resources / resource_directory
+ if not resource_location.exists():
+ raise Exception("Resource {} does not exist".format(resource_directory))
+
+ if not resource_location.is_dir():
+ raise Exception("Wrong resource {}".format(resource_directory))
+
+ shutil.rmtree(resource_location)
+
+
+def remove_directory(directory_path: Optional[Path]) -> None:
+ """Remove directory."""
+ if not directory_path or not directory_path.is_dir():
+ raise Exception("No directory path provided")
+
+ shutil.rmtree(directory_path)
+
+
+def recreate_directory(directory_path: Optional[Path]) -> None:
+ """Recreate directory."""
+ if not directory_path:
+ raise Exception("No directory path provided")
+
+ if directory_path.exists() and not directory_path.is_dir():
+ raise Exception(
+ "Path {} does exist and it is not a directory".format(str(directory_path))
+ )
+
+ if directory_path.is_dir():
+ remove_directory(directory_path)
+
+ directory_path.mkdir()
+
+
+def read_file(file_path: Path, mode: Optional[str] = None) -> Any:
+ """Read file as string or bytearray."""
+ if file_path.is_file():
+ if mode is not None:
+ # Ignore pylint warning because mode can be 'binary' as well which
+ # is not compatible with specifying encodings.
+ with open(file_path, mode) as file: # pylint: disable=unspecified-encoding
+ return file.read()
+ else:
+ with open(file_path, encoding="utf-8") as file:
+ return file.read()
+
+ if mode == "rb":
+ return b""
+ return ""
+
+
+def read_file_as_string(file_path: Path) -> str:
+ """Read file as string."""
+ return str(read_file(file_path))
+
+
+def read_file_as_bytearray(file_path: Path) -> bytearray:
+ """Read a file as bytearray."""
+ return bytearray(read_file(file_path, mode="rb"))
+
+
+def valid_for_filename(value: str, replacement: str = "") -> str:
+ """Replace non alpha numeric characters."""
+ return re.sub(r"[^\w.]", replacement, value, flags=re.ASCII)
diff --git a/src/aiet/utils/helpers.py b/src/aiet/utils/helpers.py
new file mode 100644
index 0000000..6d3cd22
--- /dev/null
+++ b/src/aiet/utils/helpers.py
@@ -0,0 +1,17 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Helpers functions."""
+import logging
+from typing import Any
+
+
+def set_verbosity(
+ ctx: Any, option: Any, verbosity: Any # pylint: disable=unused-argument
+) -> None:
+ """Set the logging level according to the verbosity."""
+ # Unused arguments must be present here in definition as these are required in
+ # function definition when set as a callback
+ if verbosity == 1:
+ logging.getLogger().setLevel(logging.INFO)
+ elif verbosity > 1:
+ logging.getLogger().setLevel(logging.DEBUG)
diff --git a/src/aiet/utils/proc.py b/src/aiet/utils/proc.py
new file mode 100644
index 0000000..b6f4357
--- /dev/null
+++ b/src/aiet/utils/proc.py
@@ -0,0 +1,283 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Processes module.
+
+This module contains all classes and functions for dealing with Linux
+processes.
+"""
+import csv
+import datetime
+import logging
+import shlex
+import signal
+import time
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import NamedTuple
+from typing import Optional
+from typing import Tuple
+
+import psutil
+from sh import Command
+from sh import CommandNotFound
+from sh import ErrorReturnCode
+from sh import RunningCommand
+
+from aiet.utils.fs import valid_for_filename
+
+
+class CommandFailedException(Exception):
+ """Exception for failed command execution."""
+
+
+class ShellCommand:
+ """Wrapper class for shell commands."""
+
+ def __init__(self, base_log_path: str = "/tmp") -> None:
+ """Initialise the class.
+
+ base_log_path: it is the base directory where logs will be stored
+ """
+ self.base_log_path = base_log_path
+
+ def run(
+ self,
+ cmd: str,
+ *args: str,
+ _cwd: Optional[Path] = None,
+ _tee: bool = True,
+ _bg: bool = True,
+ _out: Any = None,
+ _err: Any = None,
+ _search_paths: Optional[List[Path]] = None
+ ) -> RunningCommand:
+ """Run the shell command with the given arguments.
+
+ There are special arguments that modify the behaviour of the process.
+ _cwd: current working directory
+ _tee: it redirects the stdout both to console and file
+ _bg: if True, it runs the process in background and the command is not
+ blocking.
+ _out: use this object for stdout redirect,
+ _err: use this object for stderr redirect,
+ _search_paths: If presented used for searching executable
+ """
+ try:
+ kwargs = {}
+ if _cwd:
+ kwargs["_cwd"] = str(_cwd)
+ command = Command(cmd, _search_paths).bake(args, **kwargs)
+ except CommandNotFound as error:
+ logging.error("Command '%s' not found", error.args[0])
+ raise error
+
+ out, err = _out, _err
+ if not _out and not _err:
+ out, err = [
+ str(item)
+ for item in self.get_stdout_stderr_paths(self.base_log_path, cmd)
+ ]
+
+ return command(_out=out, _err=err, _tee=_tee, _bg=_bg, _bg_exc=False)
+
+ @classmethod
+ def get_stdout_stderr_paths(cls, base_log_path: str, cmd: str) -> Tuple[Path, Path]:
+ """Construct and returns the paths of stdout/stderr files."""
+ timestamp = datetime.datetime.now().timestamp()
+ base_path = Path(base_log_path)
+ base = base_path / "{}_{}".format(valid_for_filename(cmd, "_"), timestamp)
+ stdout = base.with_suffix(".out")
+ stderr = base.with_suffix(".err")
+ try:
+ stdout.touch()
+ stderr.touch()
+ except FileNotFoundError as error:
+ logging.error("File not found: %s", error.filename)
+ raise error
+ return stdout, stderr
+
+
+def parse_command(command: str, shell: str = "bash") -> List[str]:
+ """Parse command."""
+ cmd, *args = shlex.split(command, posix=True)
+
+ if is_shell_script(cmd):
+ args = [cmd] + args
+ cmd = shell
+
+ return [cmd] + args
+
+
+def get_stdout_stderr_paths(
+ command: str, base_log_path: str = "/tmp"
+) -> Tuple[Path, Path]:
+ """Construct and returns the paths of stdout/stderr files."""
+ cmd, *_ = parse_command(command)
+
+ return ShellCommand.get_stdout_stderr_paths(base_log_path, cmd)
+
+
+def execute_command( # pylint: disable=invalid-name
+ command: str,
+ cwd: Path,
+ bg: bool = False,
+ shell: str = "bash",
+ out: Any = None,
+ err: Any = None,
+) -> RunningCommand:
+ """Execute shell command."""
+ cmd, *args = parse_command(command, shell)
+
+ search_paths = None
+ if cmd != shell and (cwd / cmd).is_file():
+ search_paths = [cwd]
+
+ return ShellCommand().run(
+ cmd, *args, _cwd=cwd, _bg=bg, _search_paths=search_paths, _out=out, _err=err
+ )
+
+
+def is_shell_script(cmd: str) -> bool:
+ """Check if command is shell script."""
+ return cmd.endswith(".sh")
+
+
+def run_and_wait(
+ command: str,
+ cwd: Path,
+ terminate_on_error: bool = False,
+ out: Any = None,
+ err: Any = None,
+) -> Tuple[int, bytearray, bytearray]:
+ """
+ Run command and wait while it is executing.
+
+ Returns a tuple: (exit_code, stdout, stderr)
+ """
+ running_cmd: Optional[RunningCommand] = None
+ try:
+ running_cmd = execute_command(command, cwd, bg=True, out=out, err=err)
+ return running_cmd.exit_code, running_cmd.stdout, running_cmd.stderr
+ except ErrorReturnCode as cmd_failed:
+ raise CommandFailedException() from cmd_failed
+ except Exception as error:
+ is_running = running_cmd is not None and running_cmd.is_alive()
+ if terminate_on_error and is_running:
+ print("Terminating ...")
+ terminate_command(running_cmd)
+
+ raise error
+
+
+def terminate_command(
+ running_cmd: RunningCommand,
+ wait: bool = True,
+ wait_period: float = 0.5,
+ number_of_attempts: int = 20,
+) -> None:
+ """Terminate running command."""
+ try:
+ running_cmd.process.signal_group(signal.SIGINT)
+ if wait:
+ for _ in range(number_of_attempts):
+ time.sleep(wait_period)
+ if not running_cmd.is_alive():
+ return
+ print(
+ "Unable to terminate process {}. Sending SIGTERM...".format(
+ running_cmd.process.pid
+ )
+ )
+ running_cmd.process.signal_group(signal.SIGTERM)
+ except ProcessLookupError:
+ pass
+
+
+def terminate_external_process(
+ process: psutil.Process,
+ wait_period: float = 0.5,
+ number_of_attempts: int = 20,
+ wait_for_termination: float = 5.0,
+) -> None:
+ """Terminate external process."""
+ try:
+ process.terminate()
+ for _ in range(number_of_attempts):
+ if not process.is_running():
+ return
+ time.sleep(wait_period)
+
+ if process.is_running():
+ process.terminate()
+ time.sleep(wait_for_termination)
+ except psutil.Error:
+ print("Unable to terminate process")
+
+
+class ProcessInfo(NamedTuple):
+ """Process information."""
+
+ name: str
+ executable: str
+ cwd: str
+ pid: int
+
+
+def save_process_info(pid: int, pid_file: Path) -> None:
+ """Save process information to file."""
+ try:
+ parent = psutil.Process(pid)
+ children = parent.children(recursive=True)
+ family = [parent] + children
+
+ with open(pid_file, "w", encoding="utf-8") as file:
+ csv_writer = csv.writer(file)
+ for member in family:
+ process_info = ProcessInfo(
+ name=member.name(),
+ executable=member.exe(),
+ cwd=member.cwd(),
+ pid=member.pid,
+ )
+ csv_writer.writerow(process_info)
+ except psutil.NoSuchProcess:
+ # if process does not exist or finishes before
+ # function call then nothing could be saved
+ # just ignore this exception and exit
+ pass
+
+
+def read_process_info(pid_file: Path) -> List[ProcessInfo]:
+ """Read information about previous system processes."""
+ if not pid_file.is_file():
+ return []
+
+ result = []
+ with open(pid_file, encoding="utf-8") as file:
+ csv_reader = csv.reader(file)
+ for row in csv_reader:
+ name, executable, cwd, pid = row
+ result.append(
+ ProcessInfo(name=name, executable=executable, cwd=cwd, pid=int(pid))
+ )
+
+ return result
+
+
+def print_command_stdout(command: RunningCommand) -> None:
+ """Print the stdout of a command.
+
+ The command has 2 states: running and done.
+ If the command is running, the output is taken by the running process.
+ If the command has ended its execution, the stdout is taken from stdout
+ property
+ """
+ if command.is_alive():
+ while True:
+ try:
+ print(command.next(), end="")
+ except StopIteration:
+ break
+ else:
+ print(command.stdout)