aboutsummaryrefslogtreecommitdiff
path: root/src/aiet
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2022-06-28 10:29:35 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2022-07-08 10:57:19 +0100
commitc9b4089b3037b5943565d76242d3016b8776f8d2 (patch)
tree3de24f79dedf0f26f492a7fa1562bf684e13a055 /src/aiet
parentba2c7fcccf37e8c81946f0776714c64f73191787 (diff)
downloadmlia-c9b4089b3037b5943565d76242d3016b8776f8d2.tar.gz
MLIA-546 Merge AIET into MLIA
Merge the deprecated AIET interface for backend execution into MLIA: - Execute backends directly (without subprocess and the aiet CLI) - Fix issues with the unit tests - Remove src/aiet and tests/aiet - Re-factor code to replace 'aiet' with 'backend' - Adapt and improve unit tests after re-factoring - Remove dependencies that are not needed anymore (click and cloup) Change-Id: I450734c6a3f705ba9afde41862b29e797e511f7c
Diffstat (limited to 'src/aiet')
-rw-r--r--src/aiet/__init__.py7
-rw-r--r--src/aiet/backend/__init__.py3
-rw-r--r--src/aiet/backend/application.py187
-rw-r--r--src/aiet/backend/common.py532
-rw-r--r--src/aiet/backend/config.py107
-rw-r--r--src/aiet/backend/controller.py134
-rw-r--r--src/aiet/backend/execution.py859
-rw-r--r--src/aiet/backend/output_parser.py176
-rw-r--r--src/aiet/backend/protocol.py325
-rw-r--r--src/aiet/backend/source.py209
-rw-r--r--src/aiet/backend/system.py289
-rw-r--r--src/aiet/backend/tool.py109
-rw-r--r--src/aiet/cli/__init__.py28
-rw-r--r--src/aiet/cli/application.py362
-rw-r--r--src/aiet/cli/common.py173
-rw-r--r--src/aiet/cli/completion.py72
-rw-r--r--src/aiet/cli/system.py122
-rw-r--r--src/aiet/cli/tool.py143
-rw-r--r--src/aiet/main.py13
-rw-r--r--src/aiet/resources/applications/.gitignore6
-rw-r--r--src/aiet/resources/systems/.gitignore6
-rw-r--r--src/aiet/resources/tools/vela/aiet-config.json73
-rw-r--r--src/aiet/resources/tools/vela/aiet-config.json.license3
-rw-r--r--src/aiet/resources/tools/vela/check_model.py75
-rw-r--r--src/aiet/resources/tools/vela/run_vela.py65
-rw-r--r--src/aiet/resources/tools/vela/vela.ini53
-rw-r--r--src/aiet/utils/__init__.py3
-rw-r--r--src/aiet/utils/fs.py116
-rw-r--r--src/aiet/utils/helpers.py17
-rw-r--r--src/aiet/utils/proc.py283
30 files changed, 0 insertions, 4550 deletions
diff --git a/src/aiet/__init__.py b/src/aiet/__init__.py
deleted file mode 100644
index 49304b1..0000000
--- a/src/aiet/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Init of aiet."""
-import pkg_resources
-
-
-__version__ = pkg_resources.get_distribution("mlia").version
diff --git a/src/aiet/backend/__init__.py b/src/aiet/backend/__init__.py
deleted file mode 100644
index 3d60372..0000000
--- a/src/aiet/backend/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Backend module."""
diff --git a/src/aiet/backend/application.py b/src/aiet/backend/application.py
deleted file mode 100644
index f6ef815..0000000
--- a/src/aiet/backend/application.py
+++ /dev/null
@@ -1,187 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Application backend module."""
-import re
-from pathlib import Path
-from typing import Any
-from typing import cast
-from typing import Dict
-from typing import List
-from typing import Optional
-
-from aiet.backend.common import Backend
-from aiet.backend.common import ConfigurationException
-from aiet.backend.common import DataPaths
-from aiet.backend.common import get_backend_configs
-from aiet.backend.common import get_backend_directories
-from aiet.backend.common import load_application_or_tool_configs
-from aiet.backend.common import load_config
-from aiet.backend.common import remove_backend
-from aiet.backend.config import ApplicationConfig
-from aiet.backend.config import ExtendedApplicationConfig
-from aiet.backend.source import create_destination_and_install
-from aiet.backend.source import get_source
-from aiet.utils.fs import get_resources
-
-
-def get_available_application_directory_names() -> List[str]:
- """Return a list of directory names for all available applications."""
- return [entry.name for entry in get_backend_directories("applications")]
-
-
-def get_available_applications() -> List["Application"]:
- """Return a list with all available applications."""
- available_applications = []
- for config_json in get_backend_configs("applications"):
- config_entries = cast(List[ExtendedApplicationConfig], load_config(config_json))
- for config_entry in config_entries:
- config_entry["config_location"] = config_json.parent.absolute()
- applications = load_applications(config_entry)
- available_applications += applications
-
- return sorted(available_applications, key=lambda application: application.name)
-
-
-def get_application(
- application_name: str, system_name: Optional[str] = None
-) -> List["Application"]:
- """Return a list of application instances with provided name."""
- return [
- application
- for application in get_available_applications()
- if application.name == application_name
- and (not system_name or application.can_run_on(system_name))
- ]
-
-
-def install_application(source_path: Path) -> None:
- """Install application."""
- try:
- source = get_source(source_path)
- config = cast(List[ExtendedApplicationConfig], source.config())
- applications_to_install = [
- s for entry in config for s in load_applications(entry)
- ]
- except Exception as error:
- raise ConfigurationException("Unable to read application definition") from error
-
- if not applications_to_install:
- raise ConfigurationException("No application definition found")
-
- available_applications = get_available_applications()
- already_installed = [
- s for s in applications_to_install if s in available_applications
- ]
- if already_installed:
- names = {application.name for application in already_installed}
- raise ConfigurationException(
- "Applications [{}] are already installed".format(",".join(names))
- )
-
- create_destination_and_install(source, get_resources("applications"))
-
-
-def remove_application(directory_name: str) -> None:
- """Remove application directory."""
- remove_backend(directory_name, "applications")
-
-
-def get_unique_application_names(system_name: Optional[str] = None) -> List[str]:
- """Extract a list of unique application names of all application available."""
- return list(
- set(
- application.name
- for application in get_available_applications()
- if not system_name or application.can_run_on(system_name)
- )
- )
-
-
-class Application(Backend):
- """Class for representing a single application component."""
-
- def __init__(self, config: ApplicationConfig) -> None:
- """Construct a Application instance from a dict."""
- super().__init__(config)
-
- self.supported_systems = config.get("supported_systems", [])
- self.deploy_data = config.get("deploy_data", [])
-
- def __eq__(self, other: object) -> bool:
- """Overload operator ==."""
- if not isinstance(other, Application):
- return False
-
- return (
- super().__eq__(other)
- and self.name == other.name
- and set(self.supported_systems) == set(other.supported_systems)
- )
-
- def can_run_on(self, system_name: str) -> bool:
- """Check if the application can run on the system passed as argument."""
- return system_name in self.supported_systems
-
- def get_deploy_data(self) -> List[DataPaths]:
- """Validate and return data specified in the config file."""
- if self.config_location is None:
- raise ConfigurationException(
- "Unable to get application {} config location".format(self.name)
- )
-
- deploy_data = []
- for item in self.deploy_data:
- src, dst = item
- src_full_path = self.config_location / src
- assert src_full_path.exists(), "{} does not exists".format(src_full_path)
- deploy_data.append(DataPaths(src_full_path, dst))
- return deploy_data
-
- def get_details(self) -> Dict[str, Any]:
- """Return dictionary with information about the Application instance."""
- output = {
- "type": "application",
- "name": self.name,
- "description": self.description,
- "supported_systems": self.supported_systems,
- "commands": self._get_command_details(),
- }
-
- return output
-
- def remove_unused_params(self) -> None:
- """Remove unused params in commands.
-
- After merging default and system related configuration application
- could have parameters that are not being used in commands. They
- should be removed.
- """
- for command in self.commands.values():
- indexes_or_aliases = [
- m
- for cmd_str in command.command_strings
- for m in re.findall(r"{user_params:(?P<index_or_alias>\w+)}", cmd_str)
- ]
-
- only_aliases = all(not item.isnumeric() for item in indexes_or_aliases)
- if only_aliases:
- used_params = [
- param
- for param in command.params
- if param.alias in indexes_or_aliases
- ]
- command.params = used_params
-
-
-def load_applications(config: ExtendedApplicationConfig) -> List[Application]:
- """Load application.
-
- Application configuration could contain different parameters/commands for different
- supported systems. For each supported system this function will return separate
- Application instance with appropriate configuration.
- """
- configs = load_application_or_tool_configs(config, ApplicationConfig)
- applications = [Application(cfg) for cfg in configs]
- for application in applications:
- application.remove_unused_params()
- return applications
diff --git a/src/aiet/backend/common.py b/src/aiet/backend/common.py
deleted file mode 100644
index b887ee7..0000000
--- a/src/aiet/backend/common.py
+++ /dev/null
@@ -1,532 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Contain all common functions for the backends."""
-import json
-import logging
-import re
-from abc import ABC
-from collections import Counter
-from pathlib import Path
-from typing import Any
-from typing import Callable
-from typing import cast
-from typing import Dict
-from typing import Final
-from typing import IO
-from typing import Iterable
-from typing import List
-from typing import Match
-from typing import NamedTuple
-from typing import Optional
-from typing import Pattern
-from typing import Tuple
-from typing import Type
-from typing import Union
-
-from aiet.backend.config import BackendConfig
-from aiet.backend.config import BaseBackendConfig
-from aiet.backend.config import NamedExecutionConfig
-from aiet.backend.config import UserParamConfig
-from aiet.backend.config import UserParamsConfig
-from aiet.utils.fs import get_resources
-from aiet.utils.fs import remove_resource
-from aiet.utils.fs import ResourceType
-
-
-AIET_CONFIG_FILE: Final[str] = "aiet-config.json"
-
-
-class ConfigurationException(Exception):
- """Configuration exception."""
-
-
-def get_backend_config(dir_path: Path) -> Path:
- """Get path to backendir configuration file."""
- return dir_path / AIET_CONFIG_FILE
-
-
-def get_backend_configs(resource_type: ResourceType) -> Iterable[Path]:
- """Get path to the backend configs for provided resource_type."""
- return (
- get_backend_config(entry) for entry in get_backend_directories(resource_type)
- )
-
-
-def get_backend_directories(resource_type: ResourceType) -> Iterable[Path]:
- """Get path to the backend directories for provided resource_type."""
- return (
- entry
- for entry in get_resources(resource_type).iterdir()
- if is_backend_directory(entry)
- )
-
-
-def is_backend_directory(dir_path: Path) -> bool:
- """Check if path is backend's configuration directory."""
- return dir_path.is_dir() and get_backend_config(dir_path).is_file()
-
-
-def remove_backend(directory_name: str, resource_type: ResourceType) -> None:
- """Remove backend with provided type and directory_name."""
- if not directory_name:
- raise Exception("No directory name provided")
-
- remove_resource(directory_name, resource_type)
-
-
-def load_config(config: Union[None, Path, IO[bytes]]) -> BackendConfig:
- """Return a loaded json file."""
- if config is None:
- raise Exception("Unable to read config")
-
- if isinstance(config, Path):
- with config.open() as json_file:
- return cast(BackendConfig, json.load(json_file))
-
- return cast(BackendConfig, json.load(config))
-
-
-def parse_raw_parameter(parameter: str) -> Tuple[str, Optional[str]]:
- """Split the parameter string in name and optional value.
-
- It manages the following cases:
- --param=1 -> --param, 1
- --param 1 -> --param, 1
- --flag -> --flag, None
- """
- data = re.split(" |=", parameter)
- if len(data) == 1:
- param_name = data[0]
- param_value = None
- else:
- param_name = " ".join(data[0:-1])
- param_value = data[-1]
- return param_name, param_value
-
-
-class DataPaths(NamedTuple):
- """DataPaths class."""
-
- src: Path
- dst: str
-
-
-class Backend(ABC):
- """Backend class."""
-
- # pylint: disable=too-many-instance-attributes
-
- def __init__(self, config: BaseBackendConfig):
- """Initialize backend."""
- name = config.get("name")
- if not name:
- raise ConfigurationException("Name is empty")
-
- self.name = name
- self.description = config.get("description", "")
- self.config_location = config.get("config_location")
- self.variables = config.get("variables", {})
- self.build_dir = config.get("build_dir")
- self.lock = config.get("lock", False)
- if self.build_dir:
- self.build_dir = self._substitute_variables(self.build_dir)
- self.annotations = config.get("annotations", {})
-
- self._parse_commands_and_params(config)
-
- def validate_parameter(self, command_name: str, parameter: str) -> bool:
- """Validate the parameter string against the application configuration.
-
- We take the parameter string, extract the parameter name/value and
- check them against the current configuration.
- """
- param_name, param_value = parse_raw_parameter(parameter)
- valid_param_name = valid_param_value = False
-
- command = self.commands.get(command_name)
- if not command:
- raise AttributeError("Unknown command: '{}'".format(command_name))
-
- # Iterate over all available parameters until we have a match.
- for param in command.params:
- if self._same_parameter(param_name, param):
- valid_param_name = True
- # This is a non-empty list
- if param.values:
- # We check if the value is allowed in the configuration
- valid_param_value = param_value in param.values
- else:
- # In this case we don't validate the value and accept
- # whatever we have set.
- valid_param_value = True
- break
-
- return valid_param_name and valid_param_value
-
- def __eq__(self, other: object) -> bool:
- """Overload operator ==."""
- if not isinstance(other, Backend):
- return False
-
- return (
- self.name == other.name
- and self.description == other.description
- and self.commands == other.commands
- )
-
- def __repr__(self) -> str:
- """Represent the Backend instance by its name."""
- return self.name
-
- def _parse_commands_and_params(self, config: BaseBackendConfig) -> None:
- """Parse commands and user parameters."""
- self.commands: Dict[str, Command] = {}
-
- commands = config.get("commands")
- if commands:
- params = config.get("user_params")
-
- for command_name in commands.keys():
- command_params = self._parse_params(params, command_name)
- command_strings = [
- self._substitute_variables(cmd)
- for cmd in commands.get(command_name, [])
- ]
- self.commands[command_name] = Command(command_strings, command_params)
-
- def _substitute_variables(self, str_val: str) -> str:
- """Substitute variables in string.
-
- Variables is being substituted at backend's creation stage because
- they could contain references to other params which will be
- resolved later.
- """
- if not str_val:
- return str_val
-
- var_pattern: Final[Pattern] = re.compile(r"{variables:(?P<var_name>\w+)}")
-
- def var_value(match: Match) -> str:
- var_name = match["var_name"]
- if var_name not in self.variables:
- raise ConfigurationException("Unknown variable {}".format(var_name))
-
- return self.variables[var_name]
-
- return var_pattern.sub(var_value, str_val) # type: ignore
-
- @classmethod
- def _parse_params(
- cls, params: Optional[UserParamsConfig], command: str
- ) -> List["Param"]:
- if not params:
- return []
-
- return [cls._parse_param(p) for p in params.get(command, [])]
-
- @classmethod
- def _parse_param(cls, param: UserParamConfig) -> "Param":
- """Parse a single parameter."""
- name = param.get("name")
- if name is not None and not name:
- raise ConfigurationException("Parameter has an empty 'name' attribute.")
- values = param.get("values", None)
- default_value = param.get("default_value", None)
- description = param.get("description", "")
- alias = param.get("alias")
-
- return Param(
- name=name,
- description=description,
- values=values,
- default_value=default_value,
- alias=alias,
- )
-
- def _get_command_details(self) -> Dict:
- command_details = {
- command_name: command.get_details()
- for command_name, command in self.commands.items()
- }
- return command_details
-
- def _get_user_param_value(
- self, user_params: List[str], param: "Param"
- ) -> Optional[str]:
- """Get the user-specified value of a parameter."""
- for user_param in user_params:
- user_param_name, user_param_value = parse_raw_parameter(user_param)
- if user_param_name == param.name:
- warn_message = (
- "The direct use of parameter name is deprecated"
- " and might be removed in the future.\n"
- f"Please use alias '{param.alias}' instead of "
- "'{user_param_name}' to provide the parameter."
- )
- logging.warning(warn_message)
-
- if self._same_parameter(user_param_name, param):
- return user_param_value
-
- return None
-
- @staticmethod
- def _same_parameter(user_param_name_or_alias: str, param: "Param") -> bool:
- """Compare user parameter name with param name or alias."""
- # Strip the "=" sign in the param_name. This is needed just for
- # comparison with the parameters passed by the user.
- # The equal sign needs to be honoured when re-building the
- # parameter back.
- param_name = None if not param.name else param.name.rstrip("=")
- return user_param_name_or_alias in [param_name, param.alias]
-
- def resolved_parameters(
- self, command_name: str, user_params: List[str]
- ) -> List[Tuple[Optional[str], "Param"]]:
- """Return list of parameters with values."""
- result: List[Tuple[Optional[str], "Param"]] = []
- command = self.commands.get(command_name)
- if not command:
- return result
-
- for param in command.params:
- value = self._get_user_param_value(user_params, param)
- if not value:
- value = param.default_value
- result.append((value, param))
-
- return result
-
- def build_command(
- self,
- command_name: str,
- user_params: List[str],
- param_resolver: Callable[[str, str, List[Tuple[Optional[str], "Param"]]], str],
- ) -> List[str]:
- """
- Return a list of executable command strings.
-
- Given a command and associated parameters, returns a list of executable command
- strings.
- """
- command = self.commands.get(command_name)
- if not command:
- raise ConfigurationException(
- "Command '{}' could not be found.".format(command_name)
- )
-
- commands_to_run = []
-
- params_values = self.resolved_parameters(command_name, user_params)
- for cmd_str in command.command_strings:
- cmd_str = resolve_all_parameters(
- cmd_str, param_resolver, command_name, params_values
- )
- commands_to_run.append(cmd_str)
-
- return commands_to_run
-
-
-class Param:
- """Class for representing a generic application parameter."""
-
- def __init__( # pylint: disable=too-many-arguments
- self,
- name: Optional[str],
- description: str,
- values: Optional[List[str]] = None,
- default_value: Optional[str] = None,
- alias: Optional[str] = None,
- ) -> None:
- """Construct a Param instance."""
- if not name and not alias:
- raise ConfigurationException(
- "Either name, alias or both must be set to identify a parameter."
- )
- self.name = name
- self.values = values
- self.description = description
- self.default_value = default_value
- self.alias = alias
-
- def get_details(self) -> Dict:
- """Return a dictionary with all relevant information of a Param."""
- return {key: value for key, value in self.__dict__.items() if value}
-
- def __eq__(self, other: object) -> bool:
- """Overload operator ==."""
- if not isinstance(other, Param):
- return False
-
- return (
- self.name == other.name
- and self.values == other.values
- and self.default_value == other.default_value
- and self.description == other.description
- )
-
-
-class Command:
- """Class for representing a command."""
-
- def __init__(
- self, command_strings: List[str], params: Optional[List[Param]] = None
- ) -> None:
- """Construct a Command instance."""
- self.command_strings = command_strings
-
- if params:
- self.params = params
- else:
- self.params = []
-
- self._validate()
-
- def _validate(self) -> None:
- """Validate command."""
- if not self.params:
- return
-
- aliases = [param.alias for param in self.params if param.alias is not None]
- repeated_aliases = [
- alias for alias, count in Counter(aliases).items() if count > 1
- ]
-
- if repeated_aliases:
- raise ConfigurationException(
- "Non unique aliases {}".format(", ".join(repeated_aliases))
- )
-
- both_name_and_alias = [
- param.name
- for param in self.params
- if param.name in aliases and param.name != param.alias
- ]
- if both_name_and_alias:
- raise ConfigurationException(
- "Aliases {} could not be used as parameter name".format(
- ", ".join(both_name_and_alias)
- )
- )
-
- def get_details(self) -> Dict:
- """Return a dictionary with all relevant information of a Command."""
- output = {
- "command_strings": self.command_strings,
- "user_params": [param.get_details() for param in self.params],
- }
- return output
-
- def __eq__(self, other: object) -> bool:
- """Overload operator ==."""
- if not isinstance(other, Command):
- return False
-
- return (
- self.command_strings == other.command_strings
- and self.params == other.params
- )
-
-
-def resolve_all_parameters(
- str_val: str,
- param_resolver: Callable[[str, str, List[Tuple[Optional[str], Param]]], str],
- command_name: Optional[str] = None,
- params_values: Optional[List[Tuple[Optional[str], Param]]] = None,
-) -> str:
- """Resolve all parameters in the string."""
- if not str_val:
- return str_val
-
- param_pattern: Final[Pattern] = re.compile(r"{(?P<param_name>[\w.:]+)}")
- while param_pattern.findall(str_val):
- str_val = param_pattern.sub(
- lambda m: param_resolver(
- m["param_name"], command_name or "", params_values or []
- ),
- str_val,
- )
- return str_val
-
-
-def load_application_or_tool_configs(
- config: Any,
- config_type: Type[Any],
- is_system_required: bool = True,
-) -> Any:
- """Get one config for each system supported by the application/tool.
-
- The configuration could contain different parameters/commands for different
- supported systems. For each supported system this function will return separate
- config with appropriate configuration.
- """
- merged_configs = []
- supported_systems: Optional[List[NamedExecutionConfig]] = config.get(
- "supported_systems"
- )
- if not supported_systems:
- if is_system_required:
- raise ConfigurationException("No supported systems definition provided")
- # Create an empty system to be used in the parsing below
- supported_systems = [cast(NamedExecutionConfig, {})]
-
- default_user_params = config.get("user_params", {})
-
- def merge_config(system: NamedExecutionConfig) -> Any:
- system_name = system.get("name")
- if not system_name and is_system_required:
- raise ConfigurationException(
- "Unable to read supported system definition, name is missed"
- )
-
- merged_config = config_type(**config)
- merged_config["supported_systems"] = [system_name] if system_name else []
- # merge default configuration and specific to the system
- merged_config["commands"] = {
- **config.get("commands", {}),
- **system.get("commands", {}),
- }
-
- params = {}
- tool_user_params = system.get("user_params", {})
- command_names = tool_user_params.keys() | default_user_params.keys()
- for command_name in command_names:
- if command_name not in merged_config["commands"]:
- continue
-
- params_default = default_user_params.get(command_name, [])
- params_tool = tool_user_params.get(command_name, [])
- if not params_default or not params_tool:
- params[command_name] = params_tool or params_default
- if params_default and params_tool:
- if any(not p.get("alias") for p in params_default):
- raise ConfigurationException(
- "Default parameters for command {} should have aliases".format(
- command_name
- )
- )
- if any(not p.get("alias") for p in params_tool):
- raise ConfigurationException(
- "{} parameters for command {} should have aliases".format(
- system_name, command_name
- )
- )
-
- merged_by_alias = {
- **{p.get("alias"): p for p in params_default},
- **{p.get("alias"): p for p in params_tool},
- }
- params[command_name] = list(merged_by_alias.values())
-
- merged_config["user_params"] = params
- merged_config["build_dir"] = system.get("build_dir", config.get("build_dir"))
- merged_config["lock"] = system.get("lock", config.get("lock", False))
- merged_config["variables"] = {
- **config.get("variables", {}),
- **system.get("variables", {}),
- }
- return merged_config
-
- merged_configs = [merge_config(system) for system in supported_systems]
-
- return merged_configs
diff --git a/src/aiet/backend/config.py b/src/aiet/backend/config.py
deleted file mode 100644
index dd42012..0000000
--- a/src/aiet/backend/config.py
+++ /dev/null
@@ -1,107 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Contain definition of backend configuration."""
-from pathlib import Path
-from typing import Dict
-from typing import List
-from typing import Literal
-from typing import Optional
-from typing import Tuple
-from typing import TypedDict
-from typing import Union
-
-
-class UserParamConfig(TypedDict, total=False):
- """User parameter configuration."""
-
- name: Optional[str]
- default_value: str
- values: List[str]
- description: str
- alias: str
-
-
-UserParamsConfig = Dict[str, List[UserParamConfig]]
-
-
-class ExecutionConfig(TypedDict, total=False):
- """Execution configuration."""
-
- commands: Dict[str, List[str]]
- user_params: UserParamsConfig
- build_dir: str
- variables: Dict[str, str]
- lock: bool
-
-
-class NamedExecutionConfig(ExecutionConfig):
- """Execution configuration with name."""
-
- name: str
-
-
-class BaseBackendConfig(ExecutionConfig, total=False):
- """Base backend configuration."""
-
- name: str
- description: str
- config_location: Path
- annotations: Dict[str, Union[str, List[str]]]
-
-
-class ApplicationConfig(BaseBackendConfig, total=False):
- """Application configuration."""
-
- supported_systems: List[str]
- deploy_data: List[Tuple[str, str]]
-
-
-class ExtendedApplicationConfig(BaseBackendConfig, total=False):
- """Extended application configuration."""
-
- supported_systems: List[NamedExecutionConfig]
- deploy_data: List[Tuple[str, str]]
-
-
-class ProtocolConfig(TypedDict, total=False):
- """Protocol config."""
-
- protocol: Literal["local", "ssh"]
-
-
-class SSHConfig(ProtocolConfig, total=False):
- """SSH configuration."""
-
- username: str
- password: str
- hostname: str
- port: str
-
-
-class LocalProtocolConfig(ProtocolConfig, total=False):
- """Local protocol config."""
-
-
-class SystemConfig(BaseBackendConfig, total=False):
- """System configuration."""
-
- data_transfer: Union[SSHConfig, LocalProtocolConfig]
- reporting: Dict[str, Dict]
-
-
-class ToolConfig(BaseBackendConfig, total=False):
- """Tool configuration."""
-
- supported_systems: List[str]
-
-
-class ExtendedToolConfig(BaseBackendConfig, total=False):
- """Extended tool configuration."""
-
- supported_systems: List[NamedExecutionConfig]
-
-
-BackendItemConfig = Union[ApplicationConfig, SystemConfig, ToolConfig]
-BackendConfig = Union[
- List[ExtendedApplicationConfig], List[SystemConfig], List[ToolConfig]
-]
diff --git a/src/aiet/backend/controller.py b/src/aiet/backend/controller.py
deleted file mode 100644
index 2650902..0000000
--- a/src/aiet/backend/controller.py
+++ /dev/null
@@ -1,134 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Controller backend module."""
-import time
-from pathlib import Path
-from typing import List
-from typing import Optional
-from typing import Tuple
-
-import psutil
-import sh
-
-from aiet.backend.common import ConfigurationException
-from aiet.utils.fs import read_file_as_string
-from aiet.utils.proc import execute_command
-from aiet.utils.proc import get_stdout_stderr_paths
-from aiet.utils.proc import read_process_info
-from aiet.utils.proc import save_process_info
-from aiet.utils.proc import terminate_command
-from aiet.utils.proc import terminate_external_process
-
-
-class SystemController:
- """System controller class."""
-
- def __init__(self) -> None:
- """Create new instance of service controller."""
- self.cmd: Optional[sh.RunningCommand] = None
- self.out_path: Optional[Path] = None
- self.err_path: Optional[Path] = None
-
- def before_start(self) -> None:
- """Run actions before system start."""
-
- def after_start(self) -> None:
- """Run actions after system start."""
-
- def start(self, commands: List[str], cwd: Path) -> None:
- """Start system."""
- if not isinstance(cwd, Path) or not cwd.is_dir():
- raise ConfigurationException("Wrong working directory {}".format(cwd))
-
- if len(commands) != 1:
- raise ConfigurationException("System should have only one command to run")
-
- startup_command = commands[0]
- if not startup_command:
- raise ConfigurationException("No startup command provided")
-
- self.before_start()
-
- self.out_path, self.err_path = get_stdout_stderr_paths(startup_command)
-
- self.cmd = execute_command(
- startup_command,
- cwd,
- bg=True,
- out=str(self.out_path),
- err=str(self.err_path),
- )
-
- self.after_start()
-
- def stop(
- self, wait: bool = False, wait_period: float = 0.5, number_of_attempts: int = 20
- ) -> None:
- """Stop system."""
- if self.cmd is not None and self.is_running():
- terminate_command(self.cmd, wait, wait_period, number_of_attempts)
-
- def is_running(self) -> bool:
- """Check if underlying process is running."""
- return self.cmd is not None and self.cmd.is_alive()
-
- def get_output(self) -> Tuple[str, str]:
- """Return application output."""
- if self.cmd is None or self.out_path is None or self.err_path is None:
- return ("", "")
-
- return (read_file_as_string(self.out_path), read_file_as_string(self.err_path))
-
-
-class SystemControllerSingleInstance(SystemController):
- """System controller with support of system's single instance."""
-
- def __init__(self, pid_file_path: Optional[Path] = None) -> None:
- """Create new instance of the service controller."""
- super().__init__()
- self.pid_file_path = pid_file_path
-
- def before_start(self) -> None:
- """Run actions before system start."""
- self._check_if_previous_instance_is_running()
-
- def after_start(self) -> None:
- """Run actions after system start."""
- self._save_process_info()
-
- def _check_if_previous_instance_is_running(self) -> None:
- """Check if another instance of the system is running."""
- process_info = read_process_info(self._pid_file())
-
- for item in process_info:
- try:
- process = psutil.Process(item.pid)
- same_process = (
- process.name() == item.name
- and process.exe() == item.executable
- and process.cwd() == item.cwd
- )
- if same_process:
- print(
- "Stopping previous instance of the system [{}]".format(item.pid)
- )
- terminate_external_process(process)
- except psutil.NoSuchProcess:
- pass
-
- def _save_process_info(self, wait_period: float = 2) -> None:
- """Save information about system's processes."""
- if self.cmd is None or not self.is_running():
- return
-
- # give some time for the system to start
- time.sleep(wait_period)
-
- save_process_info(self.cmd.process.pid, self._pid_file())
-
- def _pid_file(self) -> Path:
- """Return path to file which is used for saving process info."""
- if not self.pid_file_path:
- raise Exception("No pid file path presented")
-
- return self.pid_file_path
diff --git a/src/aiet/backend/execution.py b/src/aiet/backend/execution.py
deleted file mode 100644
index 1653ee2..0000000
--- a/src/aiet/backend/execution.py
+++ /dev/null
@@ -1,859 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Application execution module."""
-import itertools
-import json
-import random
-import re
-import string
-import sys
-import time
-import warnings
-from collections import defaultdict
-from contextlib import contextmanager
-from contextlib import ExitStack
-from pathlib import Path
-from typing import Any
-from typing import Callable
-from typing import cast
-from typing import ContextManager
-from typing import Dict
-from typing import Generator
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import TypedDict
-from typing import Union
-
-from filelock import FileLock
-from filelock import Timeout
-
-from aiet.backend.application import Application
-from aiet.backend.application import get_application
-from aiet.backend.common import Backend
-from aiet.backend.common import ConfigurationException
-from aiet.backend.common import DataPaths
-from aiet.backend.common import Param
-from aiet.backend.common import parse_raw_parameter
-from aiet.backend.common import resolve_all_parameters
-from aiet.backend.output_parser import Base64OutputParser
-from aiet.backend.output_parser import OutputParser
-from aiet.backend.output_parser import RegexOutputParser
-from aiet.backend.system import ControlledSystem
-from aiet.backend.system import get_system
-from aiet.backend.system import StandaloneSystem
-from aiet.backend.system import System
-from aiet.backend.tool import get_tool
-from aiet.backend.tool import Tool
-from aiet.utils.fs import recreate_directory
-from aiet.utils.fs import remove_directory
-from aiet.utils.fs import valid_for_filename
-from aiet.utils.proc import run_and_wait
-
-
-class AnotherInstanceIsRunningException(Exception):
- """Concurrent execution error."""
-
-
-class ConnectionException(Exception):
- """Connection exception."""
-
-
-class ExecutionParams(TypedDict, total=False):
- """Execution parameters."""
-
- disable_locking: bool
- unique_build_dir: bool
-
-
-class ExecutionContext:
- """Command execution context."""
-
- # pylint: disable=too-many-arguments,too-many-instance-attributes
- def __init__(
- self,
- app: Union[Application, Tool],
- app_params: List[str],
- system: Optional[System],
- system_params: List[str],
- custom_deploy_data: Optional[List[DataPaths]] = None,
- execution_params: Optional[ExecutionParams] = None,
- report_file: Optional[Path] = None,
- ):
- """Init execution context."""
- self.app = app
- self.app_params = app_params
- self.custom_deploy_data = custom_deploy_data or []
- self.system = system
- self.system_params = system_params
- self.execution_params = execution_params or ExecutionParams()
- self.report_file = report_file
-
- self.reporter: Optional[Reporter]
- if self.report_file:
- # Create reporter with output parsers
- parsers: List[OutputParser] = []
- if system and system.reporting:
- # Add RegexOutputParser, if it is configured in the system
- parsers.append(RegexOutputParser("system", system.reporting["regex"]))
- # Add Base64 parser for applications
- parsers.append(Base64OutputParser("application"))
- self.reporter = Reporter(parsers=parsers)
- else:
- self.reporter = None # No reporter needed.
-
- self.param_resolver = ParamResolver(self)
- self._resolved_build_dir: Optional[Path] = None
-
- @property
- def is_deploy_needed(self) -> bool:
- """Check if application requires data deployment."""
- if isinstance(self.app, Application):
- return (
- len(self.app.get_deploy_data()) > 0 or len(self.custom_deploy_data) > 0
- )
- return False
-
- @property
- def is_locking_required(self) -> bool:
- """Return true if any form of locking required."""
- return not self._disable_locking() and (
- self.app.lock or (self.system is not None and self.system.lock)
- )
-
- @property
- def is_build_required(self) -> bool:
- """Return true if application build required."""
- return "build" in self.app.commands
-
- @property
- def is_unique_build_dir_required(self) -> bool:
- """Return true if unique build dir required."""
- return self.execution_params.get("unique_build_dir", False)
-
- def build_dir(self) -> Path:
- """Return resolved application build dir."""
- if self._resolved_build_dir is not None:
- return self._resolved_build_dir
-
- if (
- not isinstance(self.app.config_location, Path)
- or not self.app.config_location.is_dir()
- ):
- raise ConfigurationException(
- "Application {} has wrong config location".format(self.app.name)
- )
-
- _build_dir = self.app.build_dir
- if _build_dir:
- _build_dir = resolve_all_parameters(_build_dir, self.param_resolver)
-
- if not _build_dir:
- raise ConfigurationException(
- "No build directory defined for the app {}".format(self.app.name)
- )
-
- if self.is_unique_build_dir_required:
- random_suffix = "".join(
- random.choices(string.ascii_lowercase + string.digits, k=7)
- )
- _build_dir = "{}_{}".format(_build_dir, random_suffix)
-
- self._resolved_build_dir = self.app.config_location / _build_dir
- return self._resolved_build_dir
-
- def _disable_locking(self) -> bool:
- """Return true if locking should be disabled."""
- return self.execution_params.get("disable_locking", False)
-
-
-class ParamResolver:
- """Parameter resolver."""
-
- def __init__(self, context: ExecutionContext):
- """Init parameter resolver."""
- self.ctx = context
-
- @staticmethod
- def resolve_user_params(
- cmd_name: Optional[str],
- index_or_alias: str,
- resolved_params: Optional[List[Tuple[Optional[str], Param]]],
- ) -> str:
- """Resolve user params."""
- if not cmd_name or resolved_params is None:
- raise ConfigurationException("Unable to resolve user params")
-
- param_value: Optional[str] = None
- param: Optional[Param] = None
-
- if index_or_alias.isnumeric():
- i = int(index_or_alias)
- if i not in range(len(resolved_params)):
- raise ConfigurationException(
- "Invalid index {} for user params of command {}".format(i, cmd_name)
- )
- param_value, param = resolved_params[i]
- else:
- for val, par in resolved_params:
- if par.alias == index_or_alias:
- param_value, param = val, par
- break
-
- if param is None:
- raise ConfigurationException(
- "No user parameter for command '{}' with alias '{}'.".format(
- cmd_name, index_or_alias
- )
- )
-
- if param_value:
- # We need to handle to cases of parameters here:
- # 1) Optional parameters (non-positional with a name and value)
- # 2) Positional parameters (value only, no name needed)
- # Default to empty strings for positional arguments
- param_name = ""
- separator = ""
- if param.name is not None:
- # A valid param name means we have an optional/non-positional argument:
- # The separator is an empty string in case the param_name
- # has an equal sign as we have to honour it.
- # If the parameter doesn't end with an equal sign then a
- # space character is injected to split the parameter name
- # and its value
- param_name = param.name
- separator = "" if param.name.endswith("=") else " "
-
- return "{param_name}{separator}{param_value}".format(
- param_name=param_name,
- separator=separator,
- param_value=param_value,
- )
-
- if param.name is None:
- raise ConfigurationException(
- "Missing user parameter with alias '{}' for command '{}'.".format(
- index_or_alias, cmd_name
- )
- )
-
- return param.name # flag: just return the parameter name
-
- def resolve_commands_and_params(
- self, backend_type: str, cmd_name: str, return_params: bool, index_or_alias: str
- ) -> str:
- """Resolve command or command's param value."""
- if backend_type == "system":
- backend = cast(Backend, self.ctx.system)
- backend_params = self.ctx.system_params
- else: # Application or Tool backend
- backend = cast(Backend, self.ctx.app)
- backend_params = self.ctx.app_params
-
- if cmd_name not in backend.commands:
- raise ConfigurationException("Command {} not found".format(cmd_name))
-
- if return_params:
- params = backend.resolved_parameters(cmd_name, backend_params)
- if index_or_alias.isnumeric():
- i = int(index_or_alias)
- if i not in range(len(params)):
- raise ConfigurationException(
- "Invalid parameter index {} for command {}".format(i, cmd_name)
- )
-
- param_value = params[i][0]
- else:
- param_value = None
- for value, param in params:
- if param.alias == index_or_alias:
- param_value = value
- break
-
- if not param_value:
- raise ConfigurationException(
- (
- "No value for parameter with index or alias {} of command {}"
- ).format(index_or_alias, cmd_name)
- )
- return param_value
-
- if not index_or_alias.isnumeric():
- raise ConfigurationException("Bad command index {}".format(index_or_alias))
-
- i = int(index_or_alias)
- commands = backend.build_command(cmd_name, backend_params, self.param_resolver)
- if i not in range(len(commands)):
- raise ConfigurationException(
- "Invalid index {} for command {}".format(i, cmd_name)
- )
-
- return commands[i]
-
- def resolve_variables(self, backend_type: str, var_name: str) -> str:
- """Resolve variable value."""
- if backend_type == "system":
- backend = cast(Backend, self.ctx.system)
- else: # Application or Tool backend
- backend = cast(Backend, self.ctx.app)
-
- if var_name not in backend.variables:
- raise ConfigurationException("Unknown variable {}".format(var_name))
-
- return backend.variables[var_name]
-
- def param_matcher(
- self,
- param_name: str,
- cmd_name: Optional[str],
- resolved_params: Optional[List[Tuple[Optional[str], Param]]],
- ) -> str:
- """Regexp to resolve a param from the param_name."""
- # this pattern supports parameter names like "application.commands.run:0" and
- # "system.commands.run.params:0"
- # Note: 'software' is included for backward compatibility.
- commands_and_params_match = re.match(
- r"(?P<type>application|software|tool|system)[.]commands[.]"
- r"(?P<name>\w+)"
- r"(?P<params>[.]params|)[:]"
- r"(?P<index_or_alias>\w+)",
- param_name,
- )
-
- if commands_and_params_match:
- backend_type, cmd_name, return_params, index_or_alias = (
- commands_and_params_match["type"],
- commands_and_params_match["name"],
- commands_and_params_match["params"],
- commands_and_params_match["index_or_alias"],
- )
- return self.resolve_commands_and_params(
- backend_type, cmd_name, bool(return_params), index_or_alias
- )
-
- # Note: 'software' is included for backward compatibility.
- variables_match = re.match(
- r"(?P<type>application|software|tool|system)[.]variables:(?P<var_name>\w+)",
- param_name,
- )
- if variables_match:
- backend_type, var_name = (
- variables_match["type"],
- variables_match["var_name"],
- )
- return self.resolve_variables(backend_type, var_name)
-
- user_params_match = re.match(r"user_params:(?P<index_or_alias>\w+)", param_name)
- if user_params_match:
- index_or_alias = user_params_match["index_or_alias"]
- return self.resolve_user_params(cmd_name, index_or_alias, resolved_params)
-
- raise ConfigurationException(
- "Unable to resolve parameter {}".format(param_name)
- )
-
- def param_resolver(
- self,
- param_name: str,
- cmd_name: Optional[str] = None,
- resolved_params: Optional[List[Tuple[Optional[str], Param]]] = None,
- ) -> str:
- """Resolve parameter value based on current execution context."""
- # Note: 'software.*' is included for backward compatibility.
- resolved_param = None
- if param_name in ["application.name", "tool.name", "software.name"]:
- resolved_param = self.ctx.app.name
- elif param_name in [
- "application.description",
- "tool.description",
- "software.description",
- ]:
- resolved_param = self.ctx.app.description
- elif self.ctx.app.config_location and (
- param_name
- in ["application.config_dir", "tool.config_dir", "software.config_dir"]
- ):
- resolved_param = str(self.ctx.app.config_location.absolute())
- elif self.ctx.app.build_dir and (
- param_name
- in ["application.build_dir", "tool.build_dir", "software.build_dir"]
- ):
- resolved_param = str(self.ctx.build_dir().absolute())
- elif self.ctx.system is not None:
- if param_name == "system.name":
- resolved_param = self.ctx.system.name
- elif param_name == "system.description":
- resolved_param = self.ctx.system.description
- elif param_name == "system.config_dir" and self.ctx.system.config_location:
- resolved_param = str(self.ctx.system.config_location.absolute())
-
- if not resolved_param:
- resolved_param = self.param_matcher(param_name, cmd_name, resolved_params)
- return resolved_param
-
- def __call__(
- self,
- param_name: str,
- cmd_name: Optional[str] = None,
- resolved_params: Optional[List[Tuple[Optional[str], Param]]] = None,
- ) -> str:
- """Resolve provided parameter."""
- return self.param_resolver(param_name, cmd_name, resolved_params)
-
-
-class Reporter:
- """Report metrics from the simulation output."""
-
- def __init__(self, parsers: Optional[List[OutputParser]] = None) -> None:
- """Create an empty reporter (i.e. no parsers registered)."""
- self.parsers: List[OutputParser] = parsers if parsers is not None else []
- self._report: Dict[str, Any] = defaultdict(lambda: defaultdict(dict))
-
- def parse(self, output: bytearray) -> None:
- """Parse output and append parsed metrics to internal report dict."""
- for parser in self.parsers:
- # Merge metrics from different parsers (do not overwrite)
- self._report[parser.name]["metrics"].update(parser(output))
-
- def get_filtered_output(self, output: bytearray) -> bytearray:
- """Filter the output according to each parser."""
- for parser in self.parsers:
- output = parser.filter_out_parsed_content(output)
- return output
-
- def report(self, ctx: ExecutionContext) -> Dict[str, Any]:
- """Add static simulation info to parsed data and return the report."""
- report: Dict[str, Any] = defaultdict(dict)
- # Add static simulation info
- report.update(self._static_info(ctx))
- # Add metrics parsed from the output
- for key, val in self._report.items():
- report[key].update(val)
- return report
-
- @staticmethod
- def save(report: Dict[str, Any], report_file: Path) -> None:
- """Save the report to a JSON file."""
- with open(report_file, "w", encoding="utf-8") as file:
- json.dump(report, file, indent=4)
-
- @staticmethod
- def _compute_all_params(cli_params: List[str], backend: Backend) -> Dict[str, str]:
- """
- Build a dict of all parameters, {name:value}.
-
- Param values taken from command line if specified, defaults otherwise.
- """
- # map of params passed from the cli ["p1=v1","p2=v2"] -> {"p1":"v1", "p2":"v2"}
- app_params_map = dict(parse_raw_parameter(expr) for expr in cli_params)
-
- # a map of params declared in the application, with values taken from the CLI,
- # defaults otherwise
- all_params = {
- (p.alias or p.name): app_params_map.get(
- cast(str, p.name), cast(str, p.default_value)
- )
- for cmd in backend.commands.values()
- for p in cmd.params
- }
- return cast(Dict[str, str], all_params)
-
- @staticmethod
- def _static_info(ctx: ExecutionContext) -> Dict[str, Any]:
- """Extract static simulation information from the context."""
- if ctx.system is None:
- raise ValueError("No system available to report.")
-
- info = {
- "system": {
- "name": ctx.system.name,
- "params": Reporter._compute_all_params(ctx.system_params, ctx.system),
- },
- "application": {
- "name": ctx.app.name,
- "params": Reporter._compute_all_params(ctx.app_params, ctx.app),
- },
- }
- return info
-
-
-def validate_parameters(
- backend: Backend, command_names: List[str], params: List[str]
-) -> None:
- """Check parameters passed to backend."""
- for param in params:
- acceptable = any(
- backend.validate_parameter(command_name, param)
- for command_name in command_names
- if command_name in backend.commands
- )
-
- if not acceptable:
- backend_type = "System" if isinstance(backend, System) else "Application"
- raise ValueError(
- "{} parameter '{}' not valid for command '{}'".format(
- backend_type, param, " or ".join(command_names)
- )
- )
-
-
-def get_application_by_name_and_system(
- application_name: str, system_name: str
-) -> Application:
- """Get application."""
- applications = get_application(application_name, system_name)
- if not applications:
- raise ValueError(
- "Application '{}' doesn't support the system '{}'".format(
- application_name, system_name
- )
- )
-
- if len(applications) != 1:
- raise ValueError(
- "Error during getting application {} for the system {}".format(
- application_name, system_name
- )
- )
-
- return applications[0]
-
-
-def get_application_and_system(
- application_name: str, system_name: str
-) -> Tuple[Application, System]:
- """Return application and system by provided names."""
- system = get_system(system_name)
- if not system:
- raise ValueError("System {} is not found".format(system_name))
-
- application = get_application_by_name_and_system(application_name, system_name)
-
- return application, system
-
-
-def execute_application_command( # pylint: disable=too-many-arguments
- command_name: str,
- application_name: str,
- application_params: List[str],
- system_name: str,
- system_params: List[str],
- custom_deploy_data: List[DataPaths],
-) -> None:
- """Execute application command.
-
- .. deprecated:: 21.12
- """
- warnings.warn(
- "Use 'run_application()' instead. Use of 'execute_application_command()' is "
- "deprecated and might be removed in a future release.",
- DeprecationWarning,
- )
-
- if command_name not in ["build", "run"]:
- raise ConfigurationException("Unsupported command {}".format(command_name))
-
- application, system = get_application_and_system(application_name, system_name)
- validate_parameters(application, [command_name], application_params)
- validate_parameters(system, [command_name], system_params)
-
- ctx = ExecutionContext(
- app=application,
- app_params=application_params,
- system=system,
- system_params=system_params,
- custom_deploy_data=custom_deploy_data,
- )
-
- if command_name == "run":
- execute_application_command_run(ctx)
- else:
- execute_application_command_build(ctx)
-
-
-# pylint: disable=too-many-arguments
-def run_application(
- application_name: str,
- application_params: List[str],
- system_name: str,
- system_params: List[str],
- custom_deploy_data: List[DataPaths],
- report_file: Optional[Path] = None,
-) -> None:
- """Run application on the provided system."""
- application, system = get_application_and_system(application_name, system_name)
- validate_parameters(application, ["build", "run"], application_params)
- validate_parameters(system, ["build", "run"], system_params)
-
- execution_params = ExecutionParams()
- if isinstance(system, StandaloneSystem):
- execution_params["disable_locking"] = True
- execution_params["unique_build_dir"] = True
-
- ctx = ExecutionContext(
- app=application,
- app_params=application_params,
- system=system,
- system_params=system_params,
- custom_deploy_data=custom_deploy_data,
- execution_params=execution_params,
- report_file=report_file,
- )
-
- with build_dir_manager(ctx):
- if ctx.is_build_required:
- execute_application_command_build(ctx)
-
- execute_application_command_run(ctx)
-
-
-def execute_application_command_build(ctx: ExecutionContext) -> None:
- """Execute application command 'build'."""
- with ExitStack() as context_stack:
- for manager in get_context_managers("build", ctx):
- context_stack.enter_context(manager(ctx))
-
- build_dir = ctx.build_dir()
- recreate_directory(build_dir)
-
- build_commands = ctx.app.build_command(
- "build", ctx.app_params, ctx.param_resolver
- )
- execute_commands_locally(build_commands, build_dir)
-
-
-def execute_commands_locally(commands: List[str], cwd: Path) -> None:
- """Execute list of commands locally."""
- for command in commands:
- print("Running: {}".format(command))
- run_and_wait(
- command, cwd, terminate_on_error=True, out=sys.stdout, err=sys.stderr
- )
-
-
-def execute_application_command_run(ctx: ExecutionContext) -> None:
- """Execute application command."""
- assert ctx.system is not None, "System must be provided."
- if ctx.is_deploy_needed and not ctx.system.supports_deploy:
- raise ConfigurationException(
- "System {} does not support data deploy".format(ctx.system.name)
- )
-
- with ExitStack() as context_stack:
- for manager in get_context_managers("run", ctx):
- context_stack.enter_context(manager(ctx))
-
- print("Generating commands to execute")
- commands_to_run = build_run_commands(ctx)
-
- if ctx.system.connectable:
- establish_connection(ctx)
-
- if ctx.system.supports_deploy:
- deploy_data(ctx)
-
- for command in commands_to_run:
- print("Running: {}".format(command))
- exit_code, std_output, std_err = ctx.system.run(command)
-
- if exit_code != 0:
- print("Application exited with exit code {}".format(exit_code))
-
- if ctx.reporter:
- ctx.reporter.parse(std_output)
- std_output = ctx.reporter.get_filtered_output(std_output)
-
- print(std_output.decode("utf8"), end="")
- print(std_err.decode("utf8"), end="")
-
- if ctx.reporter:
- report = ctx.reporter.report(ctx)
- ctx.reporter.save(report, cast(Path, ctx.report_file))
-
-
-def establish_connection(
- ctx: ExecutionContext, retries: int = 90, interval: float = 15.0
-) -> None:
- """Establish connection with the system."""
- assert ctx.system is not None, "System is required."
- host, port = ctx.system.connection_details()
- print(
- "Trying to establish connection with '{}:{}' - "
- "{} retries every {} seconds ".format(host, port, retries, interval),
- end="",
- )
-
- try:
- for _ in range(retries):
- print(".", end="", flush=True)
-
- if ctx.system.establish_connection():
- break
-
- if isinstance(ctx.system, ControlledSystem) and not ctx.system.is_running():
- print(
- "\n\n---------- {} execution failed ----------".format(
- ctx.system.name
- )
- )
- stdout, stderr = ctx.system.get_output()
- print(stdout)
- print(stderr)
-
- raise Exception("System is not running")
-
- wait(interval)
- else:
- raise ConnectionException("Couldn't connect to '{}:{}'.".format(host, port))
- finally:
- print()
-
-
-def wait(interval: float) -> None:
- """Wait for a period of time."""
- time.sleep(interval)
-
-
-def deploy_data(ctx: ExecutionContext) -> None:
- """Deploy data to the system."""
- if isinstance(ctx.app, Application):
- # Only application can deploy data (tools can not)
- assert ctx.system is not None, "System is required."
- for item in itertools.chain(ctx.app.get_deploy_data(), ctx.custom_deploy_data):
- print("Deploying {} onto {}".format(item.src, item.dst))
- ctx.system.deploy(item.src, item.dst)
-
-
-def build_run_commands(ctx: ExecutionContext) -> List[str]:
- """Build commands to run application."""
- if isinstance(ctx.system, StandaloneSystem):
- return ctx.system.build_command("run", ctx.system_params, ctx.param_resolver)
-
- return ctx.app.build_command("run", ctx.app_params, ctx.param_resolver)
-
-
-@contextmanager
-def controlled_system_manager(ctx: ExecutionContext) -> Generator[None, None, None]:
- """Context manager used for system initialisation before run."""
- system = cast(ControlledSystem, ctx.system)
- commands = system.build_command("run", ctx.system_params, ctx.param_resolver)
- pid_file_path: Optional[Path] = None
- if ctx.is_locking_required:
- file_lock_path = get_file_lock_path(ctx)
- pid_file_path = file_lock_path.parent / "{}.pid".format(file_lock_path.stem)
-
- system.start(commands, ctx.is_locking_required, pid_file_path)
- try:
- yield
- finally:
- print("Shutting down sequence...")
- print("Stopping {}... (It could take few seconds)".format(system.name))
- system.stop(wait=True)
- print("{} stopped successfully.".format(system.name))
-
-
-@contextmanager
-def lock_execution_manager(ctx: ExecutionContext) -> Generator[None, None, None]:
- """Lock execution manager."""
- file_lock_path = get_file_lock_path(ctx)
- file_lock = FileLock(str(file_lock_path))
-
- try:
- file_lock.acquire(timeout=1)
- except Timeout as error:
- raise AnotherInstanceIsRunningException() from error
-
- try:
- yield
- finally:
- file_lock.release()
-
-
-def get_file_lock_path(ctx: ExecutionContext, lock_dir: Path = Path("/tmp")) -> Path:
- """Get file lock path."""
- lock_modules = []
- if ctx.app.lock:
- lock_modules.append(ctx.app.name)
- if ctx.system is not None and ctx.system.lock:
- lock_modules.append(ctx.system.name)
- lock_filename = ""
- if lock_modules:
- lock_filename = "_".join(["middleware"] + lock_modules) + ".lock"
-
- if lock_filename:
- lock_filename = resolve_all_parameters(lock_filename, ctx.param_resolver)
- lock_filename = valid_for_filename(lock_filename)
-
- if not lock_filename:
- raise ConfigurationException("No filename for lock provided")
-
- if not isinstance(lock_dir, Path) or not lock_dir.is_dir():
- raise ConfigurationException(
- "Invalid directory {} for lock files provided".format(lock_dir)
- )
-
- return lock_dir / lock_filename
-
-
-@contextmanager
-def build_dir_manager(ctx: ExecutionContext) -> Generator[None, None, None]:
- """Build directory manager."""
- try:
- yield
- finally:
- if (
- ctx.is_build_required
- and ctx.is_unique_build_dir_required
- and ctx.build_dir().is_dir()
- ):
- remove_directory(ctx.build_dir())
-
-
-def get_context_managers(
- command_name: str, ctx: ExecutionContext
-) -> Sequence[Callable[[ExecutionContext], ContextManager[None]]]:
- """Get context manager for the system."""
- managers = []
-
- if ctx.is_locking_required:
- managers.append(lock_execution_manager)
-
- if command_name == "run":
- if isinstance(ctx.system, ControlledSystem):
- managers.append(controlled_system_manager)
-
- return managers
-
-
-def get_tool_by_system(tool_name: str, system_name: Optional[str]) -> Tool:
- """Return tool (optionally by provided system name."""
- tools = get_tool(tool_name, system_name)
- if not tools:
- raise ConfigurationException(
- "Tool '{}' not found or doesn't support the system '{}'".format(
- tool_name, system_name
- )
- )
- if len(tools) != 1:
- raise ConfigurationException(
- "Please specify the system for tool {}.".format(tool_name)
- )
- tool = tools[0]
-
- return tool
-
-
-def execute_tool_command(
- tool_name: str,
- tool_params: List[str],
- system_name: Optional[str] = None,
-) -> None:
- """Execute the tool command locally calling the 'run' command."""
- tool = get_tool_by_system(tool_name, system_name)
- ctx = ExecutionContext(
- app=tool, app_params=tool_params, system=None, system_params=[]
- )
- commands = tool.build_command("run", tool_params, ctx.param_resolver)
-
- execute_commands_locally(commands, Path.cwd())
diff --git a/src/aiet/backend/output_parser.py b/src/aiet/backend/output_parser.py
deleted file mode 100644
index 111772a..0000000
--- a/src/aiet/backend/output_parser.py
+++ /dev/null
@@ -1,176 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Definition of output parsers (including base class OutputParser)."""
-import base64
-import json
-import re
-from abc import ABC
-from abc import abstractmethod
-from typing import Any
-from typing import Dict
-from typing import Union
-
-
-class OutputParser(ABC):
- """Abstract base class for output parsers."""
-
- def __init__(self, name: str) -> None:
- """Set up the name of the parser."""
- super().__init__()
- self.name = name
-
- @abstractmethod
- def __call__(self, output: bytearray) -> Dict[str, Any]:
- """Parse the output and return a map of names to metrics."""
- return {}
-
- # pylint: disable=no-self-use
- def filter_out_parsed_content(self, output: bytearray) -> bytearray:
- """
- Filter out the parsed content from the output.
-
- Does nothing by default. Can be overridden in subclasses.
- """
- return output
-
-
-class RegexOutputParser(OutputParser):
- """Parser of standard output data using regular expressions."""
-
- _TYPE_MAP = {"str": str, "float": float, "int": int}
-
- def __init__(
- self,
- name: str,
- regex_config: Dict[str, Dict[str, str]],
- ) -> None:
- """
- Set up the parser with the regular expressions.
-
- The regex_config is mapping from a name to a dict with keys 'pattern'
- and 'type':
- - The 'pattern' holds the regular expression that must contain exactly
- one capturing parenthesis
- - The 'type' can be one of ['str', 'float', 'int'].
-
- Example:
- ```
- {"Metric1": {"pattern": ".*= *(.*)", "type": "str"}}
- ```
-
- The different regular expressions from the config are combined using
- non-capturing parenthesis, i.e. regular expressions must not overlap
- if more than one match per line is expected.
- """
- super().__init__(name)
-
- self._verify_config(regex_config)
- self._regex_cfg = regex_config
-
- # Compile regular expression to match in the output
- self._regex = re.compile(
- "|".join("(?:{0})".format(x["pattern"]) for x in self._regex_cfg.values())
- )
-
- def __call__(self, output: bytearray) -> Dict[str, Union[str, float, int]]:
- """
- Parse the output and return a map of names to metrics.
-
- Example:
- Assuming a regex_config as used as example in `__init__()` and the
- following output:
- ```
- Simulation finished:
- SIMULATION_STATUS = SUCCESS
- Simulation DONE
- ```
- Then calling the parser should return the following dict:
- ```
- {
- "Metric1": "SUCCESS"
- }
- ```
- """
- metrics = {}
- output_str = output.decode("utf-8")
- results = self._regex.findall(output_str)
- for line_result in results:
- for idx, (name, cfg) in enumerate(self._regex_cfg.items()):
- # The result(s) returned by findall() are either a single string
- # or a tuple (depending on the number of groups etc.)
- result = (
- line_result if isinstance(line_result, str) else line_result[idx]
- )
- if result:
- mapped_result = self._TYPE_MAP[cfg["type"]](result)
- metrics[name] = mapped_result
- return metrics
-
- def _verify_config(self, regex_config: Dict[str, Dict[str, str]]) -> None:
- """Make sure we have a valid regex_config.
-
- I.e.
- - Exactly one capturing parenthesis per pattern
- - Correct types
- """
- for name, cfg in regex_config.items():
- # Check that there is one capturing group defined in the pattern.
- regex = re.compile(cfg["pattern"])
- if regex.groups != 1:
- raise ValueError(
- f"Pattern for metric '{name}' must have exactly one "
- f"capturing parenthesis, but it has {regex.groups}."
- )
- # Check if type is supported
- if not cfg["type"] in self._TYPE_MAP:
- raise TypeError(
- f"Type '{cfg['type']}' for metric '{name}' is not "
- f"supported. Choose from: {list(self._TYPE_MAP.keys())}."
- )
-
-
-class Base64OutputParser(OutputParser):
- """
- Parser to extract base64-encoded JSON from tagged standard output.
-
- Example of the tagged output:
- ```
- # Encoded JSON: {"test": 1234}
- <metrics>eyJ0ZXN0IjogMTIzNH0</metrics>
- ```
- """
-
- TAG_NAME = "metrics"
-
- def __init__(self, name: str) -> None:
- """Set up the regular expression to extract tagged strings."""
- super().__init__(name)
- self._regex = re.compile(rf"<{self.TAG_NAME}>(.*)</{self.TAG_NAME}>")
-
- def __call__(self, output: bytearray) -> Dict[str, Any]:
- """
- Parse the output and return a map of index (as string) to decoded JSON.
-
- Example:
- Using the tagged output from the class docs the parser should return
- the following dict:
- ```
- {
- "0": {"test": 1234}
- }
- ```
- """
- metrics = {}
- output_str = output.decode("utf-8")
- results = self._regex.findall(output_str)
- for idx, result_base64 in enumerate(results):
- result_json = base64.b64decode(result_base64, validate=True)
- result = json.loads(result_json)
- metrics[str(idx)] = result
-
- return metrics
-
- def filter_out_parsed_content(self, output: bytearray) -> bytearray:
- """Filter out base64-encoded content from the output."""
- output_str = self._regex.sub("", output.decode("utf-8"))
- return bytearray(output_str.encode("utf-8"))
diff --git a/src/aiet/backend/protocol.py b/src/aiet/backend/protocol.py
deleted file mode 100644
index c621436..0000000
--- a/src/aiet/backend/protocol.py
+++ /dev/null
@@ -1,325 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Contain protocol related classes and functions."""
-from abc import ABC
-from abc import abstractmethod
-from contextlib import closing
-from pathlib import Path
-from typing import Any
-from typing import cast
-from typing import Iterable
-from typing import Optional
-from typing import Tuple
-from typing import Union
-
-import paramiko
-
-from aiet.backend.common import ConfigurationException
-from aiet.backend.config import LocalProtocolConfig
-from aiet.backend.config import SSHConfig
-from aiet.utils.proc import run_and_wait
-
-
-# Redirect all paramiko thread exceptions to a file otherwise these will be
-# printed to stderr.
-paramiko.util.log_to_file("/tmp/main_paramiko_log.txt", level=paramiko.common.INFO)
-
-
-class SSHConnectionException(Exception):
- """SSH connection exception."""
-
-
-class SupportsClose(ABC):
- """Class indicates support of close operation."""
-
- @abstractmethod
- def close(self) -> None:
- """Close protocol session."""
-
-
-class SupportsDeploy(ABC):
- """Class indicates support of deploy operation."""
-
- @abstractmethod
- def deploy(self, src: Path, dst: str, retry: bool = True) -> None:
- """Abstract method for deploy data."""
-
-
-class SupportsConnection(ABC):
- """Class indicates that protocol uses network connections."""
-
- @abstractmethod
- def establish_connection(self) -> bool:
- """Establish connection with underlying system."""
-
- @abstractmethod
- def connection_details(self) -> Tuple[str, int]:
- """Return connection details (host, port)."""
-
-
-class Protocol(ABC):
- """Abstract class for representing the protocol."""
-
- def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None:
- """Initialize the class using a dict."""
- self.__dict__.update(iterable, **kwargs)
- self._validate()
-
- @abstractmethod
- def _validate(self) -> None:
- """Abstract method for config data validation."""
-
- @abstractmethod
- def run(
- self, command: str, retry: bool = False
- ) -> Tuple[int, bytearray, bytearray]:
- """
- Abstract method for running commands.
-
- Returns a tuple: (exit_code, stdout, stderr)
- """
-
-
-class CustomSFTPClient(paramiko.SFTPClient):
- """Class for creating a custom sftp client."""
-
- def put_dir(self, source: Path, target: str) -> None:
- """Upload the source directory to the target path.
-
- The target directory needs to exists and the last directory of the
- source will be created under the target with all its content.
- """
- # Create the target directory
- self._mkdir(target, ignore_existing=True)
- # Create the last directory in the source on the target
- self._mkdir("{}/{}".format(target, source.name), ignore_existing=True)
- # Go through the whole content of source
- for item in sorted(source.glob("**/*")):
- relative_path = item.relative_to(source.parent)
- remote_target = target / relative_path
- if item.is_file():
- self.put(str(item), str(remote_target))
- else:
- self._mkdir(str(remote_target), ignore_existing=True)
-
- def _mkdir(self, path: str, mode: int = 511, ignore_existing: bool = False) -> None:
- """Extend mkdir functionality.
-
- This version adds an option to not fail if the folder exists.
- """
- try:
- super().mkdir(path, mode)
- except IOError as error:
- if ignore_existing:
- pass
- else:
- raise error
-
-
-class LocalProtocol(Protocol):
- """Class for local protocol."""
-
- protocol: str
- cwd: Path
-
- def run(
- self, command: str, retry: bool = False
- ) -> Tuple[int, bytearray, bytearray]:
- """
- Run command locally.
-
- Returns a tuple: (exit_code, stdout, stderr)
- """
- if not isinstance(self.cwd, Path) or not self.cwd.is_dir():
- raise ConfigurationException("Wrong working directory {}".format(self.cwd))
-
- stdout = bytearray()
- stderr = bytearray()
-
- return run_and_wait(
- command, self.cwd, terminate_on_error=True, out=stdout, err=stderr
- )
-
- def _validate(self) -> None:
- """Validate protocol configuration."""
- assert hasattr(self, "protocol") and self.protocol == "local"
- assert hasattr(self, "cwd")
-
-
-class SSHProtocol(Protocol, SupportsClose, SupportsDeploy, SupportsConnection):
- """Class for SSH protocol."""
-
- protocol: str
- username: str
- password: str
- hostname: str
- port: int
-
- def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None:
- """Initialize the class using a dict."""
- super().__init__(iterable, **kwargs)
- # Internal state to store if the system is connectable. It will be set
- # to true at the first connection instance
- self.client: Optional[paramiko.client.SSHClient] = None
- self.port = int(self.port)
-
- def run(self, command: str, retry: bool = True) -> Tuple[int, bytearray, bytearray]:
- """
- Run command over SSH.
-
- Returns a tuple: (exit_code, stdout, stderr)
- """
- transport = self._get_transport()
- with closing(transport.open_session()) as channel:
- # Enable shell's .profile settings and execute command
- channel.exec_command("bash -l -c '{}'".format(command))
- exit_status = -1
- stdout = bytearray()
- stderr = bytearray()
- while True:
- if channel.exit_status_ready():
- exit_status = channel.recv_exit_status()
- # Call it one last time to read any leftover in the channel
- self._recv_stdout_err(channel, stdout, stderr)
- break
- self._recv_stdout_err(channel, stdout, stderr)
-
- return exit_status, stdout, stderr
-
- def deploy(self, src: Path, dst: str, retry: bool = True) -> None:
- """Deploy src to remote dst over SSH.
-
- src and dst should be path to a file or directory.
- """
- transport = self._get_transport()
- sftp = cast(CustomSFTPClient, CustomSFTPClient.from_transport(transport))
-
- with closing(sftp):
- if src.is_dir():
- sftp.put_dir(src, dst)
- elif src.is_file():
- sftp.put(str(src), dst)
- else:
- raise Exception("Deploy error: file type not supported")
-
- # After the deployment of files, sync the remote filesystem to flush
- # buffers to hard disk
- self.run("sync")
-
- def close(self) -> None:
- """Close protocol session."""
- if self.client is not None:
- print("Try syncing remote file system...")
- # Before stopping the system, we try to run sync to make sure all
- # data are flushed on disk.
- self.run("sync", retry=False)
- self._close_client(self.client)
-
- def establish_connection(self) -> bool:
- """Establish connection with underlying system."""
- if self.client is not None:
- return True
-
- self.client = self._connect()
- return self.client is not None
-
- def _get_transport(self) -> paramiko.transport.Transport:
- """Get transport."""
- self.establish_connection()
-
- if self.client is None:
- raise SSHConnectionException(
- "Couldn't connect to '{}:{}'.".format(self.hostname, self.port)
- )
-
- transport = self.client.get_transport()
- if not transport:
- raise Exception("Unable to get transport")
-
- return transport
-
- def connection_details(self) -> Tuple[str, int]:
- """Return connection details of underlying system."""
- return (self.hostname, self.port)
-
- def _connect(self) -> Optional[paramiko.client.SSHClient]:
- """Try to establish connection."""
- client: Optional[paramiko.client.SSHClient] = None
- try:
- client = paramiko.client.SSHClient()
- client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
- client.connect(
- self.hostname,
- self.port,
- self.username,
- self.password,
- # next parameters should be set to False to disable authentication
- # using ssh keys
- allow_agent=False,
- look_for_keys=False,
- )
- return client
- except (
- # OSError raised on first attempt to connect when running inside Docker
- OSError,
- paramiko.ssh_exception.NoValidConnectionsError,
- paramiko.ssh_exception.SSHException,
- ):
- # even if connection is not established socket could be still
- # open, it should be closed
- self._close_client(client)
-
- return None
-
- @staticmethod
- def _close_client(client: Optional[paramiko.client.SSHClient]) -> None:
- """Close ssh client."""
- try:
- if client is not None:
- client.close()
- except Exception: # pylint: disable=broad-except
- pass
-
- @classmethod
- def _recv_stdout_err(
- cls, channel: paramiko.channel.Channel, stdout: bytearray, stderr: bytearray
- ) -> None:
- """Read from channel to stdout/stder."""
- chunk_size = 512
- if channel.recv_ready():
- stdout_chunk = channel.recv(chunk_size)
- stdout.extend(stdout_chunk)
- if channel.recv_stderr_ready():
- stderr_chunk = channel.recv_stderr(chunk_size)
- stderr.extend(stderr_chunk)
-
- def _validate(self) -> None:
- """Check if there are all the info for establishing the connection."""
- assert hasattr(self, "protocol") and self.protocol == "ssh"
- assert hasattr(self, "username")
- assert hasattr(self, "password")
- assert hasattr(self, "hostname")
- assert hasattr(self, "port")
-
-
-class ProtocolFactory:
- """Factory class to return the appropriate Protocol class."""
-
- @staticmethod
- def get_protocol(
- config: Optional[Union[SSHConfig, LocalProtocolConfig]],
- **kwargs: Union[str, Path, None]
- ) -> Union[SSHProtocol, LocalProtocol]:
- """Return the right protocol instance based on the config."""
- if not config:
- raise ValueError("No protocol config provided")
-
- protocol = config["protocol"]
- if protocol == "ssh":
- return SSHProtocol(config)
-
- if protocol == "local":
- cwd = kwargs.get("cwd")
- return LocalProtocol(config, cwd=cwd)
-
- raise ValueError("Protocol not supported: '{}'".format(protocol))
diff --git a/src/aiet/backend/source.py b/src/aiet/backend/source.py
deleted file mode 100644
index dec175a..0000000
--- a/src/aiet/backend/source.py
+++ /dev/null
@@ -1,209 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Contain source related classes and functions."""
-import os
-import shutil
-import tarfile
-from abc import ABC
-from abc import abstractmethod
-from pathlib import Path
-from tarfile import TarFile
-from typing import Optional
-from typing import Union
-
-from aiet.backend.common import AIET_CONFIG_FILE
-from aiet.backend.common import ConfigurationException
-from aiet.backend.common import get_backend_config
-from aiet.backend.common import is_backend_directory
-from aiet.backend.common import load_config
-from aiet.backend.config import BackendConfig
-from aiet.utils.fs import copy_directory_content
-
-
-class Source(ABC):
- """Source class."""
-
- @abstractmethod
- def name(self) -> Optional[str]:
- """Get source name."""
-
- @abstractmethod
- def config(self) -> Optional[BackendConfig]:
- """Get configuration file content."""
-
- @abstractmethod
- def install_into(self, destination: Path) -> None:
- """Install source into destination directory."""
-
- @abstractmethod
- def create_destination(self) -> bool:
- """Return True if destination folder should be created before installation."""
-
-
-class DirectorySource(Source):
- """DirectorySource class."""
-
- def __init__(self, directory_path: Path) -> None:
- """Create the DirectorySource instance."""
- assert isinstance(directory_path, Path)
- self.directory_path = directory_path
-
- def name(self) -> str:
- """Return name of source."""
- return self.directory_path.name
-
- def config(self) -> Optional[BackendConfig]:
- """Return configuration file content."""
- if not is_backend_directory(self.directory_path):
- raise ConfigurationException("No configuration file found")
-
- config_file = get_backend_config(self.directory_path)
- return load_config(config_file)
-
- def install_into(self, destination: Path) -> None:
- """Install source into destination directory."""
- if not destination.is_dir():
- raise ConfigurationException("Wrong destination {}".format(destination))
-
- if not self.directory_path.is_dir():
- raise ConfigurationException(
- "Directory {} does not exist".format(self.directory_path)
- )
-
- copy_directory_content(self.directory_path, destination)
-
- def create_destination(self) -> bool:
- """Return True if destination folder should be created before installation."""
- return True
-
-
-class TarArchiveSource(Source):
- """TarArchiveSource class."""
-
- def __init__(self, archive_path: Path) -> None:
- """Create the TarArchiveSource class."""
- assert isinstance(archive_path, Path)
- self.archive_path = archive_path
- self._config: Optional[BackendConfig] = None
- self._has_top_level_folder: Optional[bool] = None
- self._name: Optional[str] = None
-
- def _read_archive_content(self) -> None:
- """Read various information about archive."""
- # get source name from archive name (everything without extensions)
- extensions = "".join(self.archive_path.suffixes)
- self._name = self.archive_path.name.rstrip(extensions)
-
- if not self.archive_path.exists():
- return
-
- with self._open(self.archive_path) as archive:
- try:
- config_entry = archive.getmember(AIET_CONFIG_FILE)
- self._has_top_level_folder = False
- except KeyError as error_no_config:
- try:
- archive_entries = archive.getnames()
- entries_common_prefix = os.path.commonprefix(archive_entries)
- top_level_dir = entries_common_prefix.rstrip("/")
-
- if not top_level_dir:
- raise RuntimeError(
- "Archive has no top level directory"
- ) from error_no_config
-
- config_path = "{}/{}".format(top_level_dir, AIET_CONFIG_FILE)
-
- config_entry = archive.getmember(config_path)
- self._has_top_level_folder = True
- self._name = top_level_dir
- except (KeyError, RuntimeError) as error_no_root_dir_or_config:
- raise ConfigurationException(
- "No configuration file found"
- ) from error_no_root_dir_or_config
-
- content = archive.extractfile(config_entry)
- self._config = load_config(content)
-
- def config(self) -> Optional[BackendConfig]:
- """Return configuration file content."""
- if self._config is None:
- self._read_archive_content()
-
- return self._config
-
- def name(self) -> Optional[str]:
- """Return name of the source."""
- if self._name is None:
- self._read_archive_content()
-
- return self._name
-
- def create_destination(self) -> bool:
- """Return True if destination folder must be created before installation."""
- if self._has_top_level_folder is None:
- self._read_archive_content()
-
- return not self._has_top_level_folder
-
- def install_into(self, destination: Path) -> None:
- """Install source into destination directory."""
- if not destination.is_dir():
- raise ConfigurationException("Wrong destination {}".format(destination))
-
- with self._open(self.archive_path) as archive:
- archive.extractall(destination)
-
- def _open(self, archive_path: Path) -> TarFile:
- """Open archive file."""
- if not archive_path.is_file():
- raise ConfigurationException("File {} does not exist".format(archive_path))
-
- if archive_path.name.endswith("tar.gz") or archive_path.name.endswith("tgz"):
- mode = "r:gz"
- else:
- raise ConfigurationException(
- "Unsupported archive type {}".format(archive_path)
- )
-
- # The returned TarFile object can be used as a context manager (using
- # 'with') by the calling instance.
- return tarfile.open( # pylint: disable=consider-using-with
- self.archive_path, mode=mode
- )
-
-
-def get_source(source_path: Path) -> Union[TarArchiveSource, DirectorySource]:
- """Return appropriate source instance based on provided source path."""
- if source_path.is_file():
- return TarArchiveSource(source_path)
-
- if source_path.is_dir():
- return DirectorySource(source_path)
-
- raise ConfigurationException("Unable to read {}".format(source_path))
-
-
-def create_destination_and_install(source: Source, resource_path: Path) -> None:
- """Create destination directory and install source.
-
- This function is used for actual installation of system/backend New
- directory will be created inside :resource_path: if needed If for example
- archive contains top level folder then no need to create new directory
- """
- destination = resource_path
- create_destination = source.create_destination()
-
- if create_destination:
- name = source.name()
- if not name:
- raise ConfigurationException("Unable to get source name")
-
- destination = resource_path / name
- destination.mkdir()
- try:
- source.install_into(destination)
- except Exception as error:
- if create_destination:
- shutil.rmtree(destination)
- raise error
diff --git a/src/aiet/backend/system.py b/src/aiet/backend/system.py
deleted file mode 100644
index 48f1bb1..0000000
--- a/src/aiet/backend/system.py
+++ /dev/null
@@ -1,289 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""System backend module."""
-from pathlib import Path
-from typing import Any
-from typing import cast
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-
-from aiet.backend.common import Backend
-from aiet.backend.common import ConfigurationException
-from aiet.backend.common import get_backend_configs
-from aiet.backend.common import get_backend_directories
-from aiet.backend.common import load_config
-from aiet.backend.common import remove_backend
-from aiet.backend.config import SystemConfig
-from aiet.backend.controller import SystemController
-from aiet.backend.controller import SystemControllerSingleInstance
-from aiet.backend.protocol import ProtocolFactory
-from aiet.backend.protocol import SupportsClose
-from aiet.backend.protocol import SupportsConnection
-from aiet.backend.protocol import SupportsDeploy
-from aiet.backend.source import create_destination_and_install
-from aiet.backend.source import get_source
-from aiet.utils.fs import get_resources
-
-
-def get_available_systems_directory_names() -> List[str]:
- """Return a list of directory names for all avialable systems."""
- return [entry.name for entry in get_backend_directories("systems")]
-
-
-def get_available_systems() -> List["System"]:
- """Return a list with all available systems."""
- available_systems = []
- for config_json in get_backend_configs("systems"):
- config_entries = cast(List[SystemConfig], (load_config(config_json)))
- for config_entry in config_entries:
- config_entry["config_location"] = config_json.parent.absolute()
- system = load_system(config_entry)
- available_systems.append(system)
-
- return sorted(available_systems, key=lambda system: system.name)
-
-
-def get_system(system_name: str) -> Optional["System"]:
- """Return a system instance with the same name passed as argument."""
- available_systems = get_available_systems()
- for system in available_systems:
- if system_name == system.name:
- return system
- return None
-
-
-def install_system(source_path: Path) -> None:
- """Install new system."""
- try:
- source = get_source(source_path)
- config = cast(List[SystemConfig], source.config())
- systems_to_install = [load_system(entry) for entry in config]
- except Exception as error:
- raise ConfigurationException("Unable to read system definition") from error
-
- if not systems_to_install:
- raise ConfigurationException("No system definition found")
-
- available_systems = get_available_systems()
- already_installed = [s for s in systems_to_install if s in available_systems]
- if already_installed:
- names = [system.name for system in already_installed]
- raise ConfigurationException(
- "Systems [{}] are already installed".format(",".join(names))
- )
-
- create_destination_and_install(source, get_resources("systems"))
-
-
-def remove_system(directory_name: str) -> None:
- """Remove system."""
- remove_backend(directory_name, "systems")
-
-
-class System(Backend):
- """System class."""
-
- def __init__(self, config: SystemConfig) -> None:
- """Construct the System class using the dictionary passed."""
- super().__init__(config)
-
- self._setup_data_transfer(config)
- self._setup_reporting(config)
-
- def _setup_data_transfer(self, config: SystemConfig) -> None:
- data_transfer_config = config.get("data_transfer")
- protocol = ProtocolFactory().get_protocol(
- data_transfer_config, cwd=self.config_location
- )
- self.protocol = protocol
-
- def _setup_reporting(self, config: SystemConfig) -> None:
- self.reporting = config.get("reporting")
-
- def run(self, command: str, retry: bool = True) -> Tuple[int, bytearray, bytearray]:
- """
- Run command on the system.
-
- Returns a tuple: (exit_code, stdout, stderr)
- """
- return self.protocol.run(command, retry)
-
- def deploy(self, src: Path, dst: str, retry: bool = True) -> None:
- """Deploy files to the system."""
- if isinstance(self.protocol, SupportsDeploy):
- self.protocol.deploy(src, dst, retry)
-
- @property
- def supports_deploy(self) -> bool:
- """Check if protocol supports deploy operation."""
- return isinstance(self.protocol, SupportsDeploy)
-
- @property
- def connectable(self) -> bool:
- """Check if protocol supports connection."""
- return isinstance(self.protocol, SupportsConnection)
-
- def establish_connection(self) -> bool:
- """Establish connection with the system."""
- if not isinstance(self.protocol, SupportsConnection):
- raise ConfigurationException(
- "System {} does not support connections".format(self.name)
- )
-
- return self.protocol.establish_connection()
-
- def connection_details(self) -> Tuple[str, int]:
- """Return connection details."""
- if not isinstance(self.protocol, SupportsConnection):
- raise ConfigurationException(
- "System {} does not support connections".format(self.name)
- )
-
- return self.protocol.connection_details()
-
- def __eq__(self, other: object) -> bool:
- """Overload operator ==."""
- if not isinstance(other, System):
- return False
-
- return super().__eq__(other) and self.name == other.name
-
- def get_details(self) -> Dict[str, Any]:
- """Return a dictionary with all relevant information of a System."""
- output = {
- "type": "system",
- "name": self.name,
- "description": self.description,
- "data_transfer_protocol": self.protocol.protocol,
- "commands": self._get_command_details(),
- "annotations": self.annotations,
- }
-
- return output
-
-
-class StandaloneSystem(System):
- """StandaloneSystem class."""
-
-
-def get_controller(
- single_instance: bool, pid_file_path: Optional[Path] = None
-) -> SystemController:
- """Get system controller."""
- if single_instance:
- return SystemControllerSingleInstance(pid_file_path)
-
- return SystemController()
-
-
-class ControlledSystem(System):
- """ControlledSystem class."""
-
- def __init__(self, config: SystemConfig):
- """Construct the ControlledSystem class using the dictionary passed."""
- super().__init__(config)
- self.controller: Optional[SystemController] = None
-
- def start(
- self,
- commands: List[str],
- single_instance: bool = True,
- pid_file_path: Optional[Path] = None,
- ) -> None:
- """Launch the system."""
- if (
- not isinstance(self.config_location, Path)
- or not self.config_location.is_dir()
- ):
- raise ConfigurationException(
- "System {} has wrong config location".format(self.name)
- )
-
- self.controller = get_controller(single_instance, pid_file_path)
- self.controller.start(commands, self.config_location)
-
- def is_running(self) -> bool:
- """Check if system is running."""
- if not self.controller:
- return False
-
- return self.controller.is_running()
-
- def get_output(self) -> Tuple[str, str]:
- """Return system output."""
- if not self.controller:
- return "", ""
-
- return self.controller.get_output()
-
- def stop(self, wait: bool = False) -> None:
- """Stop the system."""
- if not self.controller:
- raise Exception("System has not been started")
-
- if isinstance(self.protocol, SupportsClose):
- try:
- self.protocol.close()
- except Exception as error: # pylint: disable=broad-except
- print(error)
- self.controller.stop(wait)
-
-
-def load_system(config: SystemConfig) -> Union[StandaloneSystem, ControlledSystem]:
- """Load system based on it's execution type."""
- data_transfer = config.get("data_transfer", {})
- protocol = data_transfer.get("protocol")
- populate_shared_params(config)
-
- if protocol == "ssh":
- return ControlledSystem(config)
-
- if protocol == "local":
- return StandaloneSystem(config)
-
- raise ConfigurationException(
- "Unsupported execution type for protocol {}".format(protocol)
- )
-
-
-def populate_shared_params(config: SystemConfig) -> None:
- """Populate command parameters with shared parameters."""
- user_params = config.get("user_params")
- if not user_params or "shared" not in user_params:
- return
-
- shared_user_params = user_params["shared"]
- if not shared_user_params:
- return
-
- only_aliases = all(p.get("alias") for p in shared_user_params)
- if not only_aliases:
- raise ConfigurationException("All shared parameters should have aliases")
-
- commands = config.get("commands", {})
- for cmd_name in ["build", "run"]:
- command = commands.get(cmd_name)
- if command is None:
- commands[cmd_name] = []
- cmd_user_params = user_params.get(cmd_name)
- if not cmd_user_params:
- cmd_user_params = shared_user_params
- else:
- only_aliases = all(p.get("alias") for p in cmd_user_params)
- if not only_aliases:
- raise ConfigurationException(
- "All parameters for command {} should have aliases".format(cmd_name)
- )
- merged_by_alias = {
- **{p.get("alias"): p for p in shared_user_params},
- **{p.get("alias"): p for p in cmd_user_params},
- }
- cmd_user_params = list(merged_by_alias.values())
-
- user_params[cmd_name] = cmd_user_params
-
- config["commands"] = commands
- del user_params["shared"]
diff --git a/src/aiet/backend/tool.py b/src/aiet/backend/tool.py
deleted file mode 100644
index d643665..0000000
--- a/src/aiet/backend/tool.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Tool backend module."""
-from typing import Any
-from typing import cast
-from typing import Dict
-from typing import List
-from typing import Optional
-
-from aiet.backend.common import Backend
-from aiet.backend.common import ConfigurationException
-from aiet.backend.common import get_backend_configs
-from aiet.backend.common import get_backend_directories
-from aiet.backend.common import load_application_or_tool_configs
-from aiet.backend.common import load_config
-from aiet.backend.config import ExtendedToolConfig
-from aiet.backend.config import ToolConfig
-
-
-def get_available_tool_directory_names() -> List[str]:
- """Return a list of directory names for all available tools."""
- return [entry.name for entry in get_backend_directories("tools")]
-
-
-def get_available_tools() -> List["Tool"]:
- """Return a list with all available tools."""
- available_tools = []
- for config_json in get_backend_configs("tools"):
- config_entries = cast(List[ExtendedToolConfig], load_config(config_json))
- for config_entry in config_entries:
- config_entry["config_location"] = config_json.parent.absolute()
- tools = load_tools(config_entry)
- available_tools += tools
-
- return sorted(available_tools, key=lambda tool: tool.name)
-
-
-def get_tool(tool_name: str, system_name: Optional[str] = None) -> List["Tool"]:
- """Return a tool instance with the same name passed as argument."""
- return [
- tool
- for tool in get_available_tools()
- if tool.name == tool_name and (not system_name or tool.can_run_on(system_name))
- ]
-
-
-def get_unique_tool_names(system_name: Optional[str] = None) -> List[str]:
- """Extract a list of unique tool names of all tools available."""
- return list(
- set(
- tool.name
- for tool in get_available_tools()
- if not system_name or tool.can_run_on(system_name)
- )
- )
-
-
-class Tool(Backend):
- """Class for representing a single tool component."""
-
- def __init__(self, config: ToolConfig) -> None:
- """Construct a Tool instance from a dict."""
- super().__init__(config)
-
- self.supported_systems = config.get("supported_systems", [])
-
- if "run" not in self.commands:
- raise ConfigurationException("A Tool must have a 'run' command.")
-
- def __eq__(self, other: object) -> bool:
- """Overload operator ==."""
- if not isinstance(other, Tool):
- return False
-
- return (
- super().__eq__(other)
- and self.name == other.name
- and set(self.supported_systems) == set(other.supported_systems)
- )
-
- def can_run_on(self, system_name: str) -> bool:
- """Check if the tool can run on the system passed as argument."""
- return system_name in self.supported_systems
-
- def get_details(self) -> Dict[str, Any]:
- """Return dictionary with all relevant information of the Tool instance."""
- output = {
- "type": "tool",
- "name": self.name,
- "description": self.description,
- "supported_systems": self.supported_systems,
- "commands": self._get_command_details(),
- }
-
- return output
-
-
-def load_tools(config: ExtendedToolConfig) -> List[Tool]:
- """Load tool.
-
- Tool configuration could contain different parameters/commands for different
- supported systems. For each supported system this function will return separate
- Tool instance with appropriate configuration.
- """
- configs = load_application_or_tool_configs(
- config, ToolConfig, is_system_required=False
- )
- tools = [Tool(cfg) for cfg in configs]
- return tools
diff --git a/src/aiet/cli/__init__.py b/src/aiet/cli/__init__.py
deleted file mode 100644
index bcd17c3..0000000
--- a/src/aiet/cli/__init__.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Module to mange the CLI interface."""
-import click
-
-from aiet import __version__
-from aiet.cli.application import application_cmd
-from aiet.cli.completion import completion_cmd
-from aiet.cli.system import system_cmd
-from aiet.cli.tool import tool_cmd
-from aiet.utils.helpers import set_verbosity
-
-
-@click.group()
-@click.version_option(__version__)
-@click.option(
- "-v", "--verbose", default=0, count=True, callback=set_verbosity, expose_value=False
-)
-@click.pass_context
-def cli(ctx: click.Context) -> None: # pylint: disable=unused-argument
- """AIET: AI Evaluation Toolkit."""
- # Unused arguments must be present here in definition to pass click context.
-
-
-cli.add_command(application_cmd)
-cli.add_command(system_cmd)
-cli.add_command(tool_cmd)
-cli.add_command(completion_cmd)
diff --git a/src/aiet/cli/application.py b/src/aiet/cli/application.py
deleted file mode 100644
index 59b652d..0000000
--- a/src/aiet/cli/application.py
+++ /dev/null
@@ -1,362 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-FileCopyrightText: Copyright (c) 2021, Gianluca Gippetto. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause
-"""Module to manage the CLI interface of applications."""
-import json
-import logging
-import re
-from pathlib import Path
-from typing import Any
-from typing import IO
-from typing import List
-from typing import Optional
-from typing import Tuple
-
-import click
-import cloup
-
-from aiet.backend.application import get_application
-from aiet.backend.application import get_available_application_directory_names
-from aiet.backend.application import get_unique_application_names
-from aiet.backend.application import install_application
-from aiet.backend.application import remove_application
-from aiet.backend.common import DataPaths
-from aiet.backend.execution import execute_application_command
-from aiet.backend.execution import run_application
-from aiet.backend.system import get_available_systems
-from aiet.cli.common import get_format
-from aiet.cli.common import middleware_exception_handler
-from aiet.cli.common import middleware_signal_handler
-from aiet.cli.common import print_command_details
-from aiet.cli.common import set_format
-
-
-@click.group(name="application")
-@click.option(
- "-f",
- "--format",
- "format_",
- type=click.Choice(["cli", "json"]),
- default="cli",
- show_default=True,
-)
-@click.pass_context
-def application_cmd(ctx: click.Context, format_: str) -> None:
- """Sub command to manage applications."""
- set_format(ctx, format_)
-
-
-@application_cmd.command(name="list")
-@click.pass_context
-@click.option(
- "-s",
- "--system",
- "system_name",
- type=click.Choice([s.name for s in get_available_systems()]),
- required=False,
-)
-def list_cmd(ctx: click.Context, system_name: str) -> None:
- """List all available applications."""
- unique_application_names = get_unique_application_names(system_name)
- unique_application_names.sort()
- if get_format(ctx) == "json":
- data = {"type": "application", "available": unique_application_names}
- print(json.dumps(data))
- else:
- print("Available applications:\n")
- print(*unique_application_names, sep="\n")
-
-
-@application_cmd.command(name="details")
-@click.option(
- "-n",
- "--name",
- "application_name",
- type=click.Choice(get_unique_application_names()),
- required=True,
-)
-@click.option(
- "-s",
- "--system",
- "system_name",
- type=click.Choice([s.name for s in get_available_systems()]),
- required=False,
-)
-@click.pass_context
-def details_cmd(ctx: click.Context, application_name: str, system_name: str) -> None:
- """Details of a specific application."""
- applications = get_application(application_name, system_name)
- if not applications:
- raise click.UsageError(
- "Application '{}' doesn't support the system '{}'".format(
- application_name, system_name
- )
- )
-
- if get_format(ctx) == "json":
- applications_details = [s.get_details() for s in applications]
- print(json.dumps(applications_details))
- else:
- for application in applications:
- application_details = application.get_details()
- application_details_template = (
- 'Application "{name}" details\nDescription: {description}'
- )
-
- print(
- application_details_template.format(
- name=application_details["name"],
- description=application_details["description"],
- )
- )
-
- print(
- "\nSupported systems: {}".format(
- ", ".join(application_details["supported_systems"])
- )
- )
-
- command_details = application_details["commands"]
-
- for command, details in command_details.items():
- print("\n{} commands:".format(command))
- print_command_details(details)
-
-
-# pylint: disable=too-many-arguments
-@application_cmd.command(name="execute")
-@click.option(
- "-n",
- "--name",
- "application_name",
- type=click.Choice(get_unique_application_names()),
- required=True,
-)
-@click.option(
- "-s",
- "--system",
- "system_name",
- type=click.Choice([s.name for s in get_available_systems()]),
- required=True,
-)
-@click.option(
- "-c",
- "--command",
- "command_name",
- type=click.Choice(["build", "run"]),
- required=True,
-)
-@click.option("-p", "--param", "application_params", multiple=True)
-@click.option("--system-param", "system_params", multiple=True)
-@click.option("-d", "--deploy", "deploy_params", multiple=True)
-@middleware_signal_handler
-@middleware_exception_handler
-def execute_cmd(
- application_name: str,
- system_name: str,
- command_name: str,
- application_params: List[str],
- system_params: List[str],
- deploy_params: List[str],
-) -> None:
- """Execute application commands. DEPRECATED! Use 'aiet application run' instead."""
- logging.warning(
- "Please use 'aiet application run' instead. Use of 'aiet application "
- "execute' is deprecated and might be removed in a future release."
- )
-
- custom_deploy_data = get_custom_deploy_data(command_name, deploy_params)
-
- execute_application_command(
- command_name,
- application_name,
- application_params,
- system_name,
- system_params,
- custom_deploy_data,
- )
-
-
-@cloup.command(name="run")
-@cloup.option(
- "-n",
- "--name",
- "application_name",
- type=click.Choice(get_unique_application_names()),
-)
-@cloup.option(
- "-s",
- "--system",
- "system_name",
- type=click.Choice([s.name for s in get_available_systems()]),
-)
-@cloup.option("-p", "--param", "application_params", multiple=True)
-@cloup.option("--system-param", "system_params", multiple=True)
-@cloup.option("-d", "--deploy", "deploy_params", multiple=True)
-@click.option(
- "-r",
- "--report",
- "report_file",
- type=Path,
- help="Create a report file in JSON format containing metrics parsed from "
- "the simulation output as specified in the aiet-config.json.",
-)
-@cloup.option(
- "--config",
- "config_file",
- type=click.File("r"),
- help="Read options from a config file rather than from the command line. "
- "The config file is a json file.",
-)
-@cloup.constraint(
- cloup.constraints.If(
- cloup.constraints.conditions.Not(
- cloup.constraints.conditions.IsSet("config_file")
- ),
- then=cloup.constraints.require_all,
- ),
- ["system_name", "application_name"],
-)
-@cloup.constraint(
- cloup.constraints.If("config_file", then=cloup.constraints.accept_none),
- [
- "system_name",
- "application_name",
- "application_params",
- "system_params",
- "deploy_params",
- ],
-)
-@middleware_signal_handler
-@middleware_exception_handler
-def run_cmd(
- application_name: str,
- system_name: str,
- application_params: List[str],
- system_params: List[str],
- deploy_params: List[str],
- report_file: Optional[Path],
- config_file: Optional[IO[str]],
-) -> None:
- """Execute application commands."""
- if config_file:
- payload_data = json.load(config_file)
- (
- system_name,
- application_name,
- application_params,
- system_params,
- deploy_params,
- report_file,
- ) = parse_payload_run_config(payload_data)
-
- custom_deploy_data = get_custom_deploy_data("run", deploy_params)
-
- run_application(
- application_name,
- application_params,
- system_name,
- system_params,
- custom_deploy_data,
- report_file,
- )
-
-
-application_cmd.add_command(run_cmd)
-
-
-def parse_payload_run_config(
- payload_data: dict,
-) -> Tuple[str, str, List[str], List[str], List[str], Optional[Path]]:
- """Parse the payload into a tuple."""
- system_id = payload_data.get("id")
- arguments: Optional[Any] = payload_data.get("arguments")
-
- if not isinstance(system_id, str):
- raise click.ClickException("invalid payload json: no system 'id'")
- if not isinstance(arguments, dict):
- raise click.ClickException("invalid payload json: no arguments object")
-
- application_name = arguments.pop("application", None)
- if not isinstance(application_name, str):
- raise click.ClickException("invalid payload json: no application_id")
-
- report_path = arguments.pop("report_path", None)
-
- application_params = []
- system_params = []
- deploy_params = []
-
- for (param_key, value) in arguments.items():
- (par, _) = re.subn("^application/", "", param_key)
- (par, found_sys_param) = re.subn("^system/", "", par)
- (par, found_deploy_param) = re.subn("^deploy/", "", par)
-
- param_expr = par + "=" + value
- if found_sys_param:
- system_params.append(param_expr)
- elif found_deploy_param:
- deploy_params.append(par)
- else:
- application_params.append(param_expr)
-
- return (
- system_id,
- application_name,
- application_params,
- system_params,
- deploy_params,
- report_path,
- )
-
-
-def get_custom_deploy_data(
- command_name: str, deploy_params: List[str]
-) -> List[DataPaths]:
- """Get custom deploy data information."""
- custom_deploy_data: List[DataPaths] = []
- if not deploy_params:
- return custom_deploy_data
-
- for param in deploy_params:
- parts = param.split(":")
- if not len(parts) == 2 or any(not part.strip() for part in parts):
- raise click.ClickException(
- "Invalid deploy parameter '{}' for command {}".format(
- param, command_name
- )
- )
- data_path = DataPaths(Path(parts[0]), parts[1])
- if not data_path.src.exists():
- raise click.ClickException("Path {} does not exist".format(data_path.src))
- custom_deploy_data.append(data_path)
-
- return custom_deploy_data
-
-
-@application_cmd.command(name="install")
-@click.option(
- "-s",
- "--source",
- "source",
- required=True,
- help="Path to the directory or archive with application definition",
-)
-def install_cmd(source: str) -> None:
- """Install new application."""
- source_path = Path(source)
- install_application(source_path)
-
-
-@application_cmd.command(name="remove")
-@click.option(
- "-d",
- "--directory_name",
- "directory_name",
- type=click.Choice(get_available_application_directory_names()),
- required=True,
- help="Name of the directory with application",
-)
-def remove_cmd(directory_name: str) -> None:
- """Remove application."""
- remove_application(directory_name)
diff --git a/src/aiet/cli/common.py b/src/aiet/cli/common.py
deleted file mode 100644
index 1d157b6..0000000
--- a/src/aiet/cli/common.py
+++ /dev/null
@@ -1,173 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Common functions for cli module."""
-import enum
-import logging
-from functools import wraps
-from signal import SIG_IGN
-from signal import SIGINT
-from signal import signal as signal_handler
-from signal import SIGTERM
-from typing import Any
-from typing import Callable
-from typing import cast
-from typing import Dict
-
-from click import ClickException
-from click import Context
-from click import UsageError
-
-from aiet.backend.common import ConfigurationException
-from aiet.backend.execution import AnotherInstanceIsRunningException
-from aiet.backend.execution import ConnectionException
-from aiet.backend.protocol import SSHConnectionException
-from aiet.utils.proc import CommandFailedException
-
-
-class MiddlewareExitCode(enum.IntEnum):
- """Middleware exit codes."""
-
- SUCCESS = 0
- # exit codes 1 and 2 are used by click
- SHUTDOWN_REQUESTED = 3
- BACKEND_ERROR = 4
- CONCURRENT_ERROR = 5
- CONNECTION_ERROR = 6
- CONFIGURATION_ERROR = 7
- MODEL_OPTIMISED_ERROR = 8
- INVALID_TFLITE_FILE_ERROR = 9
-
-
-class CustomClickException(ClickException):
- """Custom click exception."""
-
- def show(self, file: Any = None) -> None:
- """Override show method."""
- super().show(file)
-
- logging.debug("Execution failed with following exception: ", exc_info=self)
-
-
-class MiddlewareShutdownException(CustomClickException):
- """Exception indicates that user requested middleware shutdown."""
-
- exit_code = int(MiddlewareExitCode.SHUTDOWN_REQUESTED)
-
-
-class BackendException(CustomClickException):
- """Exception indicates that command failed."""
-
- exit_code = int(MiddlewareExitCode.BACKEND_ERROR)
-
-
-class ConcurrentErrorException(CustomClickException):
- """Exception indicates concurrent execution error."""
-
- exit_code = int(MiddlewareExitCode.CONCURRENT_ERROR)
-
-
-class BackendConnectionException(CustomClickException):
- """Exception indicates that connection could not be established."""
-
- exit_code = int(MiddlewareExitCode.CONNECTION_ERROR)
-
-
-class BackendConfigurationException(CustomClickException):
- """Exception indicates some configuration issue."""
-
- exit_code = int(MiddlewareExitCode.CONFIGURATION_ERROR)
-
-
-class ModelOptimisedException(CustomClickException):
- """Exception indicates input file has previously been Vela optimised."""
-
- exit_code = int(MiddlewareExitCode.MODEL_OPTIMISED_ERROR)
-
-
-class InvalidTFLiteFileError(CustomClickException):
- """Exception indicates input TFLite file is misformatted."""
-
- exit_code = int(MiddlewareExitCode.INVALID_TFLITE_FILE_ERROR)
-
-
-def print_command_details(command: Dict) -> None:
- """Print command details including parameters."""
- command_strings = command["command_strings"]
- print("Commands: {}".format(command_strings))
- user_params = command["user_params"]
- for i, param in enumerate(user_params, 1):
- print("User parameter #{}".format(i))
- print("\tName: {}".format(param.get("name", "-")))
- print("\tDescription: {}".format(param["description"]))
- print("\tPossible values: {}".format(param.get("values", "-")))
- print("\tDefault value: {}".format(param.get("default_value", "-")))
- print("\tAlias: {}".format(param.get("alias", "-")))
-
-
-def raise_exception_at_signal(
- signum: int, frame: Any # pylint: disable=unused-argument
-) -> None:
- """Handle signals."""
- # Disable both SIGINT and SIGTERM signals. Further SIGINT and SIGTERM
- # signals will be ignored as we allow a graceful shutdown.
- # Unused arguments must be present here in definition as used in signal handler
- # callback
-
- signal_handler(SIGINT, SIG_IGN)
- signal_handler(SIGTERM, SIG_IGN)
- raise MiddlewareShutdownException("Middleware shutdown requested")
-
-
-def middleware_exception_handler(func: Callable) -> Callable:
- """Handle backend exceptions decorator."""
-
- @wraps(func)
- def wrapper(*args: Any, **kwargs: Any) -> Any:
- try:
- return func(*args, **kwargs)
- except (MiddlewareShutdownException, UsageError, ClickException) as error:
- # click should take care of these exceptions
- raise error
- except ValueError as error:
- raise ClickException(str(error)) from error
- except AnotherInstanceIsRunningException as error:
- raise ConcurrentErrorException(
- "Another instance of the system is running"
- ) from error
- except (SSHConnectionException, ConnectionException) as error:
- raise BackendConnectionException(str(error)) from error
- except ConfigurationException as error:
- raise BackendConfigurationException(str(error)) from error
- except (CommandFailedException, Exception) as error:
- raise BackendException(
- "Execution failed. Please check output for the details."
- ) from error
-
- return wrapper
-
-
-def middleware_signal_handler(func: Callable) -> Callable:
- """Handle signals decorator."""
-
- @wraps(func)
- def wrapper(*args: Any, **kwargs: Any) -> Any:
- # Set up signal handlers for SIGINT (ctrl-c) and SIGTERM (kill command)
- # The handler ignores further signals and it raises an exception
- signal_handler(SIGINT, raise_exception_at_signal)
- signal_handler(SIGTERM, raise_exception_at_signal)
-
- return func(*args, **kwargs)
-
- return wrapper
-
-
-def set_format(ctx: Context, format_: str) -> None:
- """Save format in click context."""
- ctx_obj = ctx.ensure_object(dict)
- ctx_obj["format"] = format_
-
-
-def get_format(ctx: Context) -> str:
- """Get format from click context."""
- ctx_obj = cast(Dict[str, str], ctx.ensure_object(dict))
- return ctx_obj["format"]
diff --git a/src/aiet/cli/completion.py b/src/aiet/cli/completion.py
deleted file mode 100644
index 71f054f..0000000
--- a/src/aiet/cli/completion.py
+++ /dev/null
@@ -1,72 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""
-Add auto completion to different shells with these helpers.
-
-See: https://click.palletsprojects.com/en/8.0.x/shell-completion/
-"""
-import click
-
-
-def _get_package_name() -> str:
- return __name__.split(".", maxsplit=1)[0]
-
-
-# aiet completion bash
-@click.group(name="completion")
-def completion_cmd() -> None:
- """Enable auto completion for your shell."""
-
-
-@completion_cmd.command(name="bash")
-def bash_cmd() -> None:
- """
- Enable auto completion for bash.
-
- Use this command to activate completion in the current bash:
-
- eval "`aiet completion bash`"
-
- Use this command to add auto completion to bash globally, if you have aiet
- installed globally (requires starting a new shell afterwards):
-
- aiet completion bash >> ~/.bashrc
- """
- package_name = _get_package_name()
- print(f'eval "$(_{package_name.upper()}_COMPLETE=bash_source {package_name})"')
-
-
-@completion_cmd.command(name="zsh")
-def zsh_cmd() -> None:
- """
- Enable auto completion for zsh.
-
- Use this command to activate completion in the current zsh:
-
- eval "`aiet completion zsh`"
-
- Use this command to add auto completion to zsh globally, if you have aiet
- installed globally (requires starting a new shell afterwards):
-
- aiet completion zsh >> ~/.zshrc
- """
- package_name = _get_package_name()
- print(f'eval "$(_{package_name.upper()}_COMPLETE=zsh_source {package_name})"')
-
-
-@completion_cmd.command(name="fish")
-def fish_cmd() -> None:
- """
- Enable auto completion for fish.
-
- Use this command to activate completion in the current fish:
-
- eval "`aiet completion fish`"
-
- Use this command to add auto completion to fish globally, if you have aiet
- installed globally (requires starting a new shell afterwards):
-
- aiet completion fish >> ~/.config/fish/completions/aiet.fish
- """
- package_name = _get_package_name()
- print(f'eval "(env _{package_name.upper()}_COMPLETE=fish_source {package_name})"')
diff --git a/src/aiet/cli/system.py b/src/aiet/cli/system.py
deleted file mode 100644
index f1f7637..0000000
--- a/src/aiet/cli/system.py
+++ /dev/null
@@ -1,122 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Module to manage the CLI interface of systems."""
-import json
-from pathlib import Path
-from typing import cast
-
-import click
-
-from aiet.backend.application import get_available_applications
-from aiet.backend.system import get_available_systems
-from aiet.backend.system import get_available_systems_directory_names
-from aiet.backend.system import get_system
-from aiet.backend.system import install_system
-from aiet.backend.system import remove_system
-from aiet.backend.system import System
-from aiet.cli.common import get_format
-from aiet.cli.common import print_command_details
-from aiet.cli.common import set_format
-
-
-@click.group(name="system")
-@click.option(
- "-f",
- "--format",
- "format_",
- type=click.Choice(["cli", "json"]),
- default="cli",
- show_default=True,
-)
-@click.pass_context
-def system_cmd(ctx: click.Context, format_: str) -> None:
- """Sub command to manage systems."""
- set_format(ctx, format_)
-
-
-@system_cmd.command(name="list")
-@click.pass_context
-def list_cmd(ctx: click.Context) -> None:
- """List all available systems."""
- available_systems = get_available_systems()
- system_names = [system.name for system in available_systems]
- if get_format(ctx) == "json":
- data = {"type": "system", "available": system_names}
- print(json.dumps(data))
- else:
- print("Available systems:\n")
- print(*system_names, sep="\n")
-
-
-@system_cmd.command(name="details")
-@click.option(
- "-n",
- "--name",
- "system_name",
- type=click.Choice([s.name for s in get_available_systems()]),
- required=True,
-)
-@click.pass_context
-def details_cmd(ctx: click.Context, system_name: str) -> None:
- """Details of a specific system."""
- system = cast(System, get_system(system_name))
- applications = [
- s.name for s in get_available_applications() if s.can_run_on(system.name)
- ]
- system_details = system.get_details()
- if get_format(ctx) == "json":
- system_details["available_application"] = applications
- print(json.dumps(system_details))
- else:
- system_details_template = (
- 'System "{name}" details\n'
- "Description: {description}\n"
- "Data Transfer Protocol: {protocol}\n"
- "Available Applications: {available_application}"
- )
- print(
- system_details_template.format(
- name=system_details["name"],
- description=system_details["description"],
- protocol=system_details["data_transfer_protocol"],
- available_application=", ".join(applications),
- )
- )
-
- if system_details["annotations"]:
- print("Annotations:")
- for ann_name, ann_value in system_details["annotations"].items():
- print("\t{}: {}".format(ann_name, ann_value))
-
- command_details = system_details["commands"]
- for command, details in command_details.items():
- print("\n{} commands:".format(command))
- print_command_details(details)
-
-
-@system_cmd.command(name="install")
-@click.option(
- "-s",
- "--source",
- "source",
- required=True,
- help="Path to the directory or archive with system definition",
-)
-def install_cmd(source: str) -> None:
- """Install new system."""
- source_path = Path(source)
- install_system(source_path)
-
-
-@system_cmd.command(name="remove")
-@click.option(
- "-d",
- "--directory_name",
- "directory_name",
- type=click.Choice(get_available_systems_directory_names()),
- required=True,
- help="Name of the directory with system",
-)
-def remove_cmd(directory_name: str) -> None:
- """Remove system by given name."""
- remove_system(directory_name)
diff --git a/src/aiet/cli/tool.py b/src/aiet/cli/tool.py
deleted file mode 100644
index 2c80821..0000000
--- a/src/aiet/cli/tool.py
+++ /dev/null
@@ -1,143 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Module to manage the CLI interface of tools."""
-import json
-from typing import Any
-from typing import List
-from typing import Optional
-
-import click
-
-from aiet.backend.execution import execute_tool_command
-from aiet.backend.tool import get_tool
-from aiet.backend.tool import get_unique_tool_names
-from aiet.cli.common import get_format
-from aiet.cli.common import middleware_exception_handler
-from aiet.cli.common import middleware_signal_handler
-from aiet.cli.common import print_command_details
-from aiet.cli.common import set_format
-
-
-@click.group(name="tool")
-@click.option(
- "-f",
- "--format",
- "format_",
- type=click.Choice(["cli", "json"]),
- default="cli",
- show_default=True,
-)
-@click.pass_context
-def tool_cmd(ctx: click.Context, format_: str) -> None:
- """Sub command to manage tools."""
- set_format(ctx, format_)
-
-
-@tool_cmd.command(name="list")
-@click.pass_context
-def list_cmd(ctx: click.Context) -> None:
- """List all available tools."""
- # raise NotImplementedError("TODO")
- tool_names = get_unique_tool_names()
- tool_names.sort()
- if get_format(ctx) == "json":
- data = {"type": "tool", "available": tool_names}
- print(json.dumps(data))
- else:
- print("Available tools:\n")
- print(*tool_names, sep="\n")
-
-
-def validate_system(
- ctx: click.Context,
- _: click.Parameter, # param is not used
- value: Any,
-) -> Any:
- """Validate provided system name depending on the the tool name."""
- tool_name = ctx.params["tool_name"]
- tools = get_tool(tool_name, value)
- if not tools:
- supported_systems = [tool.supported_systems[0] for tool in get_tool(tool_name)]
- raise click.BadParameter(
- message="'{}' is not one of {}.".format(
- value,
- ", ".join("'{}'".format(system) for system in supported_systems),
- ),
- ctx=ctx,
- )
- return value
-
-
-@tool_cmd.command(name="details")
-@click.option(
- "-n",
- "--name",
- "tool_name",
- type=click.Choice(get_unique_tool_names()),
- required=True,
-)
-@click.option(
- "-s",
- "--system",
- "system_name",
- callback=validate_system,
- required=False,
-)
-@click.pass_context
-@middleware_signal_handler
-@middleware_exception_handler
-def details_cmd(ctx: click.Context, tool_name: str, system_name: Optional[str]) -> None:
- """Details of a specific tool."""
- tools = get_tool(tool_name, system_name)
- if get_format(ctx) == "json":
- tools_details = [s.get_details() for s in tools]
- print(json.dumps(tools_details))
- else:
- for tool in tools:
- tool_details = tool.get_details()
- tool_details_template = 'Tool "{name}" details\nDescription: {description}'
-
- print(
- tool_details_template.format(
- name=tool_details["name"],
- description=tool_details["description"],
- )
- )
-
- print(
- "\nSupported systems: {}".format(
- ", ".join(tool_details["supported_systems"])
- )
- )
-
- command_details = tool_details["commands"]
-
- for command, details in command_details.items():
- print("\n{} commands:".format(command))
- print_command_details(details)
-
-
-# pylint: disable=too-many-arguments
-@tool_cmd.command(name="execute")
-@click.option(
- "-n",
- "--name",
- "tool_name",
- type=click.Choice(get_unique_tool_names()),
- required=True,
-)
-@click.option("-p", "--param", "tool_params", multiple=True)
-@click.option(
- "-s",
- "--system",
- "system_name",
- callback=validate_system,
- required=False,
-)
-@middleware_signal_handler
-@middleware_exception_handler
-def execute_cmd(
- tool_name: str, tool_params: List[str], system_name: Optional[str]
-) -> None:
- """Execute tool commands."""
- execute_tool_command(tool_name, tool_params, system_name)
diff --git a/src/aiet/main.py b/src/aiet/main.py
deleted file mode 100644
index 6898ad9..0000000
--- a/src/aiet/main.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Entry point module of AIET."""
-from aiet.cli import cli
-
-
-def main() -> None:
- """Entry point of aiet application."""
- cli() # pylint: disable=no-value-for-parameter
-
-
-if __name__ == "__main__":
- main()
diff --git a/src/aiet/resources/applications/.gitignore b/src/aiet/resources/applications/.gitignore
deleted file mode 100644
index 0226166..0000000
--- a/src/aiet/resources/applications/.gitignore
+++ /dev/null
@@ -1,6 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-# Ignore everything in this directory
-*
-# Except this file
-!.gitignore
diff --git a/src/aiet/resources/systems/.gitignore b/src/aiet/resources/systems/.gitignore
deleted file mode 100644
index 0226166..0000000
--- a/src/aiet/resources/systems/.gitignore
+++ /dev/null
@@ -1,6 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-# Ignore everything in this directory
-*
-# Except this file
-!.gitignore
diff --git a/src/aiet/resources/tools/vela/aiet-config.json b/src/aiet/resources/tools/vela/aiet-config.json
deleted file mode 100644
index c12f291..0000000
--- a/src/aiet/resources/tools/vela/aiet-config.json
+++ /dev/null
@@ -1,73 +0,0 @@
-[
- {
- "name": "vela",
- "description": "Neural network model compiler for Arm Ethos-U NPUs",
- "supported_systems": [
- {
- "name": "Corstone-300: Cortex-M55+Ethos-U55"
- },
- {
- "name": "Corstone-310: Cortex-M85+Ethos-U55"
- },
- {
- "name": "Corstone-300: Cortex-M55+Ethos-U65",
- "variables": {
- "accelerator_config_prefix": "ethos-u65",
- "system_config": "Ethos_U65_High_End",
- "shared_sram": "U65_Shared_Sram"
- },
- "user_params": {
- "run": [
- {
- "description": "MACs per cycle",
- "values": [
- "256",
- "512"
- ],
- "default_value": "512",
- "alias": "mac"
- }
- ]
- }
- }
- ],
- "variables": {
- "accelerator_config_prefix": "ethos-u55",
- "system_config": "Ethos_U55_High_End_Embedded",
- "shared_sram": "U55_Shared_Sram"
- },
- "commands": {
- "run": [
- "run_vela {user_params:input} {user_params:output} --config {tool.config_dir}/vela.ini --accelerator-config {variables:accelerator_config_prefix}-{user_params:mac} --system-config {variables:system_config} --memory-mode {variables:shared_sram} --optimise Performance"
- ]
- },
- "user_params": {
- "run": [
- {
- "description": "MACs per cycle",
- "values": [
- "32",
- "64",
- "128",
- "256"
- ],
- "default_value": "128",
- "alias": "mac"
- },
- {
- "name": "--input-model",
- "description": "Path to the TFLite model",
- "values": [],
- "alias": "input"
- },
- {
- "name": "--output-model",
- "description": "Path to the output model file of the vela-optimisation step. The vela output is saved in the parent directory.",
- "values": [],
- "default_value": "output_model.tflite",
- "alias": "output"
- }
- ]
- }
- }
-]
diff --git a/src/aiet/resources/tools/vela/aiet-config.json.license b/src/aiet/resources/tools/vela/aiet-config.json.license
deleted file mode 100644
index 9b83bfc..0000000
--- a/src/aiet/resources/tools/vela/aiet-config.json.license
+++ /dev/null
@@ -1,3 +0,0 @@
-SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-
-SPDX-License-Identifier: Apache-2.0
diff --git a/src/aiet/resources/tools/vela/check_model.py b/src/aiet/resources/tools/vela/check_model.py
deleted file mode 100644
index 7c700b1..0000000
--- a/src/aiet/resources/tools/vela/check_model.py
+++ /dev/null
@@ -1,75 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2020, 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Check if a TFLite model file is Vela-optimised."""
-import struct
-from pathlib import Path
-
-from ethosu.vela.tflite.Model import Model
-
-from aiet.cli.common import InvalidTFLiteFileError
-from aiet.cli.common import ModelOptimisedException
-from aiet.utils.fs import read_file_as_bytearray
-
-
-def get_model_from_file(input_model_file: Path) -> Model:
- """Generate Model instance from TFLite file using flatc generated code."""
- buffer = read_file_as_bytearray(input_model_file)
- try:
- model = Model.GetRootAsModel(buffer, 0)
- except (TypeError, RuntimeError, struct.error) as tflite_error:
- raise InvalidTFLiteFileError(
- f"Error reading in model from {input_model_file}."
- ) from tflite_error
- return model
-
-
-def is_vela_optimised(tflite_model: Model) -> bool:
- """Return True if 'ethos-u' custom operator found in the Model."""
- operators = get_operators_from_model(tflite_model)
-
- custom_codes = get_custom_codes_from_operators(operators)
-
- return check_custom_codes_for_ethosu(custom_codes)
-
-
-def get_operators_from_model(tflite_model: Model) -> list:
- """Return list of the unique operator codes used in the Model."""
- return [
- tflite_model.OperatorCodes(index)
- for index in range(tflite_model.OperatorCodesLength())
- ]
-
-
-def get_custom_codes_from_operators(operators: list) -> list:
- """Return list of each operator's CustomCode() strings, if they exist."""
- return [
- operator.CustomCode()
- for operator in operators
- if operator.CustomCode() is not None
- ]
-
-
-def check_custom_codes_for_ethosu(custom_codes: list) -> bool:
- """Check for existence of ethos-u string in the custom codes."""
- return any(
- custom_code_name.decode("utf-8") == "ethos-u"
- for custom_code_name in custom_codes
- )
-
-
-def check_model(tflite_file_name: str) -> None:
- """Raise an exception if model in given file is Vela optimised."""
- tflite_path = Path(tflite_file_name)
-
- tflite_model = get_model_from_file(tflite_path)
-
- if is_vela_optimised(tflite_model):
- raise ModelOptimisedException(
- f"TFLite model in {tflite_file_name} is already "
- f"vela optimised ('ethos-u' custom op detected)."
- )
-
- print(
- f"TFLite model in {tflite_file_name} is not vela optimised "
- f"('ethos-u' custom op not detected)."
- )
diff --git a/src/aiet/resources/tools/vela/run_vela.py b/src/aiet/resources/tools/vela/run_vela.py
deleted file mode 100644
index 2c1b0be..0000000
--- a/src/aiet/resources/tools/vela/run_vela.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Wrapper to only run Vela when the input is not already optimised."""
-import shutil
-import subprocess
-from pathlib import Path
-from typing import Tuple
-
-import click
-
-from aiet.cli.common import ModelOptimisedException
-from aiet.resources.tools.vela.check_model import check_model
-
-
-def vela_output_model_path(input_model: str, output_dir: str) -> Path:
- """Construct the path to the Vela output file."""
- in_path = Path(input_model)
- tflite_vela = Path(output_dir) / f"{in_path.stem}_vela{in_path.suffix}"
- return tflite_vela
-
-
-def execute_vela(vela_args: Tuple, output_dir: Path, input_model: str) -> None:
- """Execute vela as external call."""
- cmd = ["vela"] + list(vela_args)
- cmd += ["--output-dir", str(output_dir)] # Re-add parsed out_dir to arguments
- cmd += [input_model]
- subprocess.run(cmd, check=True)
-
-
-@click.command(context_settings=dict(ignore_unknown_options=True))
-@click.option(
- "--input-model",
- "-i",
- type=click.Path(exists=True, file_okay=True, readable=True),
- required=True,
-)
-@click.option("--output-model", "-o", type=click.Path(), required=True)
-# Collect the remaining arguments to be directly forwarded to Vela
-@click.argument("vela-args", nargs=-1, type=click.UNPROCESSED)
-def run_vela(input_model: str, output_model: str, vela_args: Tuple) -> None:
- """Check input, run Vela (if needed) and copy optimised file to destination."""
- output_dir = Path(output_model).parent
- try:
- check_model(input_model) # raises an exception if already Vela-optimised
- execute_vela(vela_args, output_dir, input_model)
- print("Vela optimisation complete.")
- src_model = vela_output_model_path(input_model, str(output_dir))
- except ModelOptimisedException as ex:
- # Input already optimized: copy input file to destination path and return
- print(f"Input already vela-optimised.\n{ex}")
- src_model = Path(input_model)
- except subprocess.CalledProcessError as ex:
- print(ex)
- raise SystemExit(ex.returncode) from ex
-
- try:
- shutil.copyfile(src_model, output_model)
- except (shutil.SameFileError, OSError) as ex:
- print(ex)
- raise SystemExit(ex.errno) from ex
-
-
-def main() -> None:
- """Entry point of check_model application."""
- run_vela() # pylint: disable=no-value-for-parameter
diff --git a/src/aiet/resources/tools/vela/vela.ini b/src/aiet/resources/tools/vela/vela.ini
deleted file mode 100644
index 5996553..0000000
--- a/src/aiet/resources/tools/vela/vela.ini
+++ /dev/null
@@ -1,53 +0,0 @@
-; SPDX-FileCopyrightText: Copyright 2021-2022, Arm Limited and/or its affiliates.
-; SPDX-License-Identifier: Apache-2.0
-
-; -----------------------------------------------------------------------------
-; Vela configuration file
-
-; -----------------------------------------------------------------------------
-; System Configuration
-
-; Ethos-U55 High-End Embedded: SRAM (4 GB/s) and Flash (0.5 GB/s)
-[System_Config.Ethos_U55_High_End_Embedded]
-core_clock=500e6
-axi0_port=Sram
-axi1_port=OffChipFlash
-Sram_clock_scale=1.0
-Sram_burst_length=32
-Sram_read_latency=32
-Sram_write_latency=32
-OffChipFlash_clock_scale=0.125
-OffChipFlash_burst_length=128
-OffChipFlash_read_latency=64
-OffChipFlash_write_latency=64
-
-; Ethos-U65 High-End: SRAM (16 GB/s) and DRAM (3.75 GB/s)
-[System_Config.Ethos_U65_High_End]
-core_clock=1e9
-axi0_port=Sram
-axi1_port=Dram
-Sram_clock_scale=1.0
-Sram_burst_length=32
-Sram_read_latency=32
-Sram_write_latency=32
-Dram_clock_scale=0.234375
-Dram_burst_length=128
-Dram_read_latency=500
-Dram_write_latency=250
-
-; -----------------------------------------------------------------------------
-; Memory Mode
-
-; Shared SRAM: the SRAM is shared between the Ethos-U and the Cortex-M software
-; The non-SRAM memory is assumed to be read-only
-[Memory_Mode.U55_Shared_Sram]
-const_mem_area=Axi1
-arena_mem_area=Axi0
-cache_mem_area=Axi0
-arena_cache_size=4194304
-
-[Memory_Mode.U65_Shared_Sram]
-const_mem_area=Axi1
-arena_mem_area=Axi0
-cache_mem_area=Axi0
-arena_cache_size=2097152
diff --git a/src/aiet/utils/__init__.py b/src/aiet/utils/__init__.py
deleted file mode 100644
index fc7ef7c..0000000
--- a/src/aiet/utils/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""This module contains all utils shared across aiet project."""
diff --git a/src/aiet/utils/fs.py b/src/aiet/utils/fs.py
deleted file mode 100644
index ea99a69..0000000
--- a/src/aiet/utils/fs.py
+++ /dev/null
@@ -1,116 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Module to host all file system related functions."""
-import importlib.resources as pkg_resources
-import re
-import shutil
-from pathlib import Path
-from typing import Any
-from typing import Literal
-from typing import Optional
-
-ResourceType = Literal["applications", "systems", "tools"]
-
-
-def get_aiet_resources() -> Path:
- """Get resources folder path."""
- with pkg_resources.path("aiet", "__init__.py") as init_path:
- project_root = init_path.parent
- return project_root / "resources"
-
-
-def get_resources(name: ResourceType) -> Path:
- """Return the absolute path of the specified resource.
-
- It uses importlib to return resources packaged with MANIFEST.in.
- """
- if not name:
- raise ResourceWarning("Resource name is not provided")
-
- resource_path = get_aiet_resources() / name
- if resource_path.is_dir():
- return resource_path
-
- raise ResourceWarning("Resource '{}' not found.".format(name))
-
-
-def copy_directory_content(source: Path, destination: Path) -> None:
- """Copy content of the source directory into destination directory."""
- for item in source.iterdir():
- src = source / item.name
- dest = destination / item.name
-
- if src.is_dir():
- shutil.copytree(src, dest)
- else:
- shutil.copy2(src, dest)
-
-
-def remove_resource(resource_directory: str, resource_type: ResourceType) -> None:
- """Remove resource data."""
- resources = get_resources(resource_type)
-
- resource_location = resources / resource_directory
- if not resource_location.exists():
- raise Exception("Resource {} does not exist".format(resource_directory))
-
- if not resource_location.is_dir():
- raise Exception("Wrong resource {}".format(resource_directory))
-
- shutil.rmtree(resource_location)
-
-
-def remove_directory(directory_path: Optional[Path]) -> None:
- """Remove directory."""
- if not directory_path or not directory_path.is_dir():
- raise Exception("No directory path provided")
-
- shutil.rmtree(directory_path)
-
-
-def recreate_directory(directory_path: Optional[Path]) -> None:
- """Recreate directory."""
- if not directory_path:
- raise Exception("No directory path provided")
-
- if directory_path.exists() and not directory_path.is_dir():
- raise Exception(
- "Path {} does exist and it is not a directory".format(str(directory_path))
- )
-
- if directory_path.is_dir():
- remove_directory(directory_path)
-
- directory_path.mkdir()
-
-
-def read_file(file_path: Path, mode: Optional[str] = None) -> Any:
- """Read file as string or bytearray."""
- if file_path.is_file():
- if mode is not None:
- # Ignore pylint warning because mode can be 'binary' as well which
- # is not compatible with specifying encodings.
- with open(file_path, mode) as file: # pylint: disable=unspecified-encoding
- return file.read()
- else:
- with open(file_path, encoding="utf-8") as file:
- return file.read()
-
- if mode == "rb":
- return b""
- return ""
-
-
-def read_file_as_string(file_path: Path) -> str:
- """Read file as string."""
- return str(read_file(file_path))
-
-
-def read_file_as_bytearray(file_path: Path) -> bytearray:
- """Read a file as bytearray."""
- return bytearray(read_file(file_path, mode="rb"))
-
-
-def valid_for_filename(value: str, replacement: str = "") -> str:
- """Replace non alpha numeric characters."""
- return re.sub(r"[^\w.]", replacement, value, flags=re.ASCII)
diff --git a/src/aiet/utils/helpers.py b/src/aiet/utils/helpers.py
deleted file mode 100644
index 6d3cd22..0000000
--- a/src/aiet/utils/helpers.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Helpers functions."""
-import logging
-from typing import Any
-
-
-def set_verbosity(
- ctx: Any, option: Any, verbosity: Any # pylint: disable=unused-argument
-) -> None:
- """Set the logging level according to the verbosity."""
- # Unused arguments must be present here in definition as these are required in
- # function definition when set as a callback
- if verbosity == 1:
- logging.getLogger().setLevel(logging.INFO)
- elif verbosity > 1:
- logging.getLogger().setLevel(logging.DEBUG)
diff --git a/src/aiet/utils/proc.py b/src/aiet/utils/proc.py
deleted file mode 100644
index b6f4357..0000000
--- a/src/aiet/utils/proc.py
+++ /dev/null
@@ -1,283 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Processes module.
-
-This module contains all classes and functions for dealing with Linux
-processes.
-"""
-import csv
-import datetime
-import logging
-import shlex
-import signal
-import time
-from pathlib import Path
-from typing import Any
-from typing import List
-from typing import NamedTuple
-from typing import Optional
-from typing import Tuple
-
-import psutil
-from sh import Command
-from sh import CommandNotFound
-from sh import ErrorReturnCode
-from sh import RunningCommand
-
-from aiet.utils.fs import valid_for_filename
-
-
-class CommandFailedException(Exception):
- """Exception for failed command execution."""
-
-
-class ShellCommand:
- """Wrapper class for shell commands."""
-
- def __init__(self, base_log_path: str = "/tmp") -> None:
- """Initialise the class.
-
- base_log_path: it is the base directory where logs will be stored
- """
- self.base_log_path = base_log_path
-
- def run(
- self,
- cmd: str,
- *args: str,
- _cwd: Optional[Path] = None,
- _tee: bool = True,
- _bg: bool = True,
- _out: Any = None,
- _err: Any = None,
- _search_paths: Optional[List[Path]] = None
- ) -> RunningCommand:
- """Run the shell command with the given arguments.
-
- There are special arguments that modify the behaviour of the process.
- _cwd: current working directory
- _tee: it redirects the stdout both to console and file
- _bg: if True, it runs the process in background and the command is not
- blocking.
- _out: use this object for stdout redirect,
- _err: use this object for stderr redirect,
- _search_paths: If presented used for searching executable
- """
- try:
- kwargs = {}
- if _cwd:
- kwargs["_cwd"] = str(_cwd)
- command = Command(cmd, _search_paths).bake(args, **kwargs)
- except CommandNotFound as error:
- logging.error("Command '%s' not found", error.args[0])
- raise error
-
- out, err = _out, _err
- if not _out and not _err:
- out, err = [
- str(item)
- for item in self.get_stdout_stderr_paths(self.base_log_path, cmd)
- ]
-
- return command(_out=out, _err=err, _tee=_tee, _bg=_bg, _bg_exc=False)
-
- @classmethod
- def get_stdout_stderr_paths(cls, base_log_path: str, cmd: str) -> Tuple[Path, Path]:
- """Construct and returns the paths of stdout/stderr files."""
- timestamp = datetime.datetime.now().timestamp()
- base_path = Path(base_log_path)
- base = base_path / "{}_{}".format(valid_for_filename(cmd, "_"), timestamp)
- stdout = base.with_suffix(".out")
- stderr = base.with_suffix(".err")
- try:
- stdout.touch()
- stderr.touch()
- except FileNotFoundError as error:
- logging.error("File not found: %s", error.filename)
- raise error
- return stdout, stderr
-
-
-def parse_command(command: str, shell: str = "bash") -> List[str]:
- """Parse command."""
- cmd, *args = shlex.split(command, posix=True)
-
- if is_shell_script(cmd):
- args = [cmd] + args
- cmd = shell
-
- return [cmd] + args
-
-
-def get_stdout_stderr_paths(
- command: str, base_log_path: str = "/tmp"
-) -> Tuple[Path, Path]:
- """Construct and returns the paths of stdout/stderr files."""
- cmd, *_ = parse_command(command)
-
- return ShellCommand.get_stdout_stderr_paths(base_log_path, cmd)
-
-
-def execute_command( # pylint: disable=invalid-name
- command: str,
- cwd: Path,
- bg: bool = False,
- shell: str = "bash",
- out: Any = None,
- err: Any = None,
-) -> RunningCommand:
- """Execute shell command."""
- cmd, *args = parse_command(command, shell)
-
- search_paths = None
- if cmd != shell and (cwd / cmd).is_file():
- search_paths = [cwd]
-
- return ShellCommand().run(
- cmd, *args, _cwd=cwd, _bg=bg, _search_paths=search_paths, _out=out, _err=err
- )
-
-
-def is_shell_script(cmd: str) -> bool:
- """Check if command is shell script."""
- return cmd.endswith(".sh")
-
-
-def run_and_wait(
- command: str,
- cwd: Path,
- terminate_on_error: bool = False,
- out: Any = None,
- err: Any = None,
-) -> Tuple[int, bytearray, bytearray]:
- """
- Run command and wait while it is executing.
-
- Returns a tuple: (exit_code, stdout, stderr)
- """
- running_cmd: Optional[RunningCommand] = None
- try:
- running_cmd = execute_command(command, cwd, bg=True, out=out, err=err)
- return running_cmd.exit_code, running_cmd.stdout, running_cmd.stderr
- except ErrorReturnCode as cmd_failed:
- raise CommandFailedException() from cmd_failed
- except Exception as error:
- is_running = running_cmd is not None and running_cmd.is_alive()
- if terminate_on_error and is_running:
- print("Terminating ...")
- terminate_command(running_cmd)
-
- raise error
-
-
-def terminate_command(
- running_cmd: RunningCommand,
- wait: bool = True,
- wait_period: float = 0.5,
- number_of_attempts: int = 20,
-) -> None:
- """Terminate running command."""
- try:
- running_cmd.process.signal_group(signal.SIGINT)
- if wait:
- for _ in range(number_of_attempts):
- time.sleep(wait_period)
- if not running_cmd.is_alive():
- return
- print(
- "Unable to terminate process {}. Sending SIGTERM...".format(
- running_cmd.process.pid
- )
- )
- running_cmd.process.signal_group(signal.SIGTERM)
- except ProcessLookupError:
- pass
-
-
-def terminate_external_process(
- process: psutil.Process,
- wait_period: float = 0.5,
- number_of_attempts: int = 20,
- wait_for_termination: float = 5.0,
-) -> None:
- """Terminate external process."""
- try:
- process.terminate()
- for _ in range(number_of_attempts):
- if not process.is_running():
- return
- time.sleep(wait_period)
-
- if process.is_running():
- process.terminate()
- time.sleep(wait_for_termination)
- except psutil.Error:
- print("Unable to terminate process")
-
-
-class ProcessInfo(NamedTuple):
- """Process information."""
-
- name: str
- executable: str
- cwd: str
- pid: int
-
-
-def save_process_info(pid: int, pid_file: Path) -> None:
- """Save process information to file."""
- try:
- parent = psutil.Process(pid)
- children = parent.children(recursive=True)
- family = [parent] + children
-
- with open(pid_file, "w", encoding="utf-8") as file:
- csv_writer = csv.writer(file)
- for member in family:
- process_info = ProcessInfo(
- name=member.name(),
- executable=member.exe(),
- cwd=member.cwd(),
- pid=member.pid,
- )
- csv_writer.writerow(process_info)
- except psutil.NoSuchProcess:
- # if process does not exist or finishes before
- # function call then nothing could be saved
- # just ignore this exception and exit
- pass
-
-
-def read_process_info(pid_file: Path) -> List[ProcessInfo]:
- """Read information about previous system processes."""
- if not pid_file.is_file():
- return []
-
- result = []
- with open(pid_file, encoding="utf-8") as file:
- csv_reader = csv.reader(file)
- for row in csv_reader:
- name, executable, cwd, pid = row
- result.append(
- ProcessInfo(name=name, executable=executable, cwd=cwd, pid=int(pid))
- )
-
- return result
-
-
-def print_command_stdout(command: RunningCommand) -> None:
- """Print the stdout of a command.
-
- The command has 2 states: running and done.
- If the command is running, the output is taken by the running process.
- If the command has ended its execution, the stdout is taken from stdout
- property
- """
- if command.is_alive():
- while True:
- try:
- print(command.next(), end="")
- except StopIteration:
- break
- else:
- print(command.stdout)