aboutsummaryrefslogtreecommitdiff
path: root/src/aiet
diff options
context:
space:
mode:
Diffstat (limited to 'src/aiet')
-rw-r--r--src/aiet/__init__.py7
-rw-r--r--src/aiet/backend/__init__.py3
-rw-r--r--src/aiet/backend/application.py187
-rw-r--r--src/aiet/backend/common.py532
-rw-r--r--src/aiet/backend/config.py107
-rw-r--r--src/aiet/backend/controller.py134
-rw-r--r--src/aiet/backend/execution.py859
-rw-r--r--src/aiet/backend/output_parser.py176
-rw-r--r--src/aiet/backend/protocol.py325
-rw-r--r--src/aiet/backend/source.py209
-rw-r--r--src/aiet/backend/system.py289
-rw-r--r--src/aiet/backend/tool.py109
-rw-r--r--src/aiet/cli/__init__.py28
-rw-r--r--src/aiet/cli/application.py362
-rw-r--r--src/aiet/cli/common.py173
-rw-r--r--src/aiet/cli/completion.py72
-rw-r--r--src/aiet/cli/system.py122
-rw-r--r--src/aiet/cli/tool.py143
-rw-r--r--src/aiet/main.py13
-rw-r--r--src/aiet/resources/applications/.gitignore6
-rw-r--r--src/aiet/resources/systems/.gitignore6
-rw-r--r--src/aiet/resources/tools/vela/aiet-config.json73
-rw-r--r--src/aiet/resources/tools/vela/aiet-config.json.license3
-rw-r--r--src/aiet/resources/tools/vela/check_model.py75
-rw-r--r--src/aiet/resources/tools/vela/run_vela.py65
-rw-r--r--src/aiet/resources/tools/vela/vela.ini53
-rw-r--r--src/aiet/utils/__init__.py3
-rw-r--r--src/aiet/utils/fs.py116
-rw-r--r--src/aiet/utils/helpers.py17
-rw-r--r--src/aiet/utils/proc.py283
30 files changed, 4550 insertions, 0 deletions
diff --git a/src/aiet/__init__.py b/src/aiet/__init__.py
new file mode 100644
index 0000000..49304b1
--- /dev/null
+++ b/src/aiet/__init__.py
@@ -0,0 +1,7 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Init of aiet."""
from importlib.metadata import version

# aiet ships as part of the "mlia" distribution, so the version is looked up
# from that distribution's metadata. importlib.metadata replaces the
# deprecated pkg_resources API and avoids its costly import-time scan of all
# installed distributions.
__version__ = version("mlia")
diff --git a/src/aiet/backend/__init__.py b/src/aiet/backend/__init__.py
new file mode 100644
index 0000000..3d60372
--- /dev/null
+++ b/src/aiet/backend/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Backend module."""
diff --git a/src/aiet/backend/application.py b/src/aiet/backend/application.py
new file mode 100644
index 0000000..f6ef815
--- /dev/null
+++ b/src/aiet/backend/application.py
@@ -0,0 +1,187 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Application backend module."""
+import re
+from pathlib import Path
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import Optional
+
+from aiet.backend.common import Backend
+from aiet.backend.common import ConfigurationException
+from aiet.backend.common import DataPaths
+from aiet.backend.common import get_backend_configs
+from aiet.backend.common import get_backend_directories
+from aiet.backend.common import load_application_or_tool_configs
+from aiet.backend.common import load_config
+from aiet.backend.common import remove_backend
+from aiet.backend.config import ApplicationConfig
+from aiet.backend.config import ExtendedApplicationConfig
+from aiet.backend.source import create_destination_and_install
+from aiet.backend.source import get_source
+from aiet.utils.fs import get_resources
+
+
def get_available_application_directory_names() -> List[str]:
    """List the directory names of every installed application backend."""
    names = []
    for directory in get_backend_directories("applications"):
        names.append(directory.name)
    return names
+
+
def get_available_applications() -> List["Application"]:
    """Return every application declared by the installed backend configs."""
    applications: List["Application"] = []
    for config_path in get_backend_configs("applications"):
        entries = cast(List[ExtendedApplicationConfig], load_config(config_path))
        for entry in entries:
            # Remember where the config lives so relative paths can be
            # resolved later (e.g. for deploy data).
            entry["config_location"] = config_path.parent.absolute()
            applications.extend(load_applications(entry))

    applications.sort(key=lambda app: app.name)
    return applications
+
+
def get_application(
    application_name: str, system_name: Optional[str] = None
) -> List["Application"]:
    """Return all applications matching the name (and, optionally, system)."""

    def matches(app: "Application") -> bool:
        # System filter only applies when a system name was provided.
        if app.name != application_name:
            return False
        return not system_name or app.can_run_on(system_name)

    return [app for app in get_available_applications() if matches(app)]
+
+
def install_application(source_path: Path) -> None:
    """Install an application from the given source path.

    :raises ConfigurationException: if the source cannot be read, contains no
        application definition, or the applications are already installed
    """
    try:
        source = get_source(source_path)
        config = cast(List[ExtendedApplicationConfig], source.config())
        applications_to_install = []
        for entry in config:
            applications_to_install.extend(load_applications(entry))
    except Exception as error:
        raise ConfigurationException("Unable to read application definition") from error

    if not applications_to_install:
        raise ConfigurationException("No application definition found")

    available_applications = get_available_applications()
    already_installed = [
        app for app in applications_to_install if app in available_applications
    ]
    if already_installed:
        names = {app.name for app in already_installed}
        raise ConfigurationException(
            "Applications [{}] are already installed".format(",".join(names))
        )

    create_destination_and_install(source, get_resources("applications"))
+
+
def remove_application(directory_name: str) -> None:
    """Delete the directory of the named application backend."""
    remove_backend(directory_name, "applications")
+
+
def get_unique_application_names(system_name: Optional[str] = None) -> List[str]:
    """Return the de-duplicated names of all available applications."""
    unique_names = {
        app.name
        for app in get_available_applications()
        if not system_name or app.can_run_on(system_name)
    }
    return list(unique_names)
+
+
class Application(Backend):
    """Class for representing a single application component."""

    def __init__(self, config: ApplicationConfig) -> None:
        """Construct an Application instance from a config dict."""
        super().__init__(config)

        # Names of the systems this application may be executed on.
        self.supported_systems = config.get("supported_systems", [])
        # (source, destination) pairs copied to the system before execution.
        self.deploy_data = config.get("deploy_data", [])

    def __eq__(self, other: object) -> bool:
        """Overload operator ==."""
        if not isinstance(other, Application):
            return False

        return (
            super().__eq__(other)
            and self.name == other.name
            and set(self.supported_systems) == set(other.supported_systems)
        )

    def can_run_on(self, system_name: str) -> bool:
        """Check if the application can run on the system passed as argument."""
        return system_name in self.supported_systems

    def get_deploy_data(self) -> List[DataPaths]:
        """Validate and return data specified in the config file.

        :raises ConfigurationException: if the config location is unknown or
            a deploy source path does not exist
        """
        if self.config_location is None:
            raise ConfigurationException(
                "Unable to get application {} config location".format(self.name)
            )

        deploy_data = []
        for src, dst in self.deploy_data:
            src_full_path = self.config_location / src
            # Use a real exception instead of `assert`: assertions are
            # stripped when Python runs with optimizations enabled (-O),
            # which would silently skip this validation.
            if not src_full_path.exists():
                raise ConfigurationException(
                    "{} does not exist".format(src_full_path)
                )
            deploy_data.append(DataPaths(src_full_path, dst))
        return deploy_data

    def get_details(self) -> Dict[str, Any]:
        """Return dictionary with information about the Application instance."""
        output = {
            "type": "application",
            "name": self.name,
            "description": self.description,
            "supported_systems": self.supported_systems,
            "commands": self._get_command_details(),
        }

        return output

    def remove_unused_params(self) -> None:
        """Remove unused params in commands.

        After merging default and system related configuration application
        could have parameters that are not being used in commands. They
        should be removed.
        """
        for command in self.commands.values():
            indexes_or_aliases = [
                m
                for cmd_str in command.command_strings
                for m in re.findall(r"{user_params:(?P<index_or_alias>\w+)}", cmd_str)
            ]

            # Only filter when every reference is an alias: numeric
            # references are presumably positional, so dropping any param
            # would shift the indexes of the remaining ones.
            only_aliases = all(not item.isnumeric() for item in indexes_or_aliases)
            if only_aliases:
                command.params = [
                    param
                    for param in command.params
                    if param.alias in indexes_or_aliases
                ]
+
+
def load_applications(config: ExtendedApplicationConfig) -> List[Application]:
    """Load application.

    The configuration may define different parameters/commands per supported
    system, so one Application instance is produced for each system.
    """
    merged_configs = load_application_or_tool_configs(config, ApplicationConfig)

    result = []
    for merged in merged_configs:
        application = Application(merged)
        application.remove_unused_params()
        result.append(application)
    return result
diff --git a/src/aiet/backend/common.py b/src/aiet/backend/common.py
new file mode 100644
index 0000000..b887ee7
--- /dev/null
+++ b/src/aiet/backend/common.py
@@ -0,0 +1,532 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Contain all common functions for the backends."""
+import json
+import logging
+import re
+from abc import ABC
+from collections import Counter
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import Final
+from typing import IO
+from typing import Iterable
+from typing import List
+from typing import Match
+from typing import NamedTuple
+from typing import Optional
+from typing import Pattern
+from typing import Tuple
+from typing import Type
+from typing import Union
+
+from aiet.backend.config import BackendConfig
+from aiet.backend.config import BaseBackendConfig
+from aiet.backend.config import NamedExecutionConfig
+from aiet.backend.config import UserParamConfig
+from aiet.backend.config import UserParamsConfig
+from aiet.utils.fs import get_resources
+from aiet.utils.fs import remove_resource
+from aiet.utils.fs import ResourceType
+
+
# Name of the configuration file every backend directory must contain.
AIET_CONFIG_FILE: Final[str] = "aiet-config.json"


class ConfigurationException(Exception):
    """Raised when backend configuration data is invalid or missing."""


def get_backend_config(dir_path: Path) -> Path:
    """Return the path of the backend configuration file inside dir_path."""
    return dir_path.joinpath(AIET_CONFIG_FILE)
+
+
def get_backend_configs(resource_type: ResourceType) -> Iterable[Path]:
    """Yield the config file path of every backend of the given type."""
    for directory in get_backend_directories(resource_type):
        yield get_backend_config(directory)
+
+
def get_backend_directories(resource_type: ResourceType) -> Iterable[Path]:
    """Yield every backend directory of the given resource type."""
    for entry in get_resources(resource_type).iterdir():
        if is_backend_directory(entry):
            yield entry
+
+
def is_backend_directory(dir_path: Path) -> bool:
    """Check whether dir_path is a backend configuration directory."""
    if not dir_path.is_dir():
        return False
    return get_backend_config(dir_path).is_file()
+
+
def remove_backend(directory_name: str, resource_type: ResourceType) -> None:
    """Remove the backend with the provided type and directory name."""
    # An empty name would otherwise remove the whole resource directory.
    if not directory_name:
        raise Exception("No directory name provided")

    remove_resource(directory_name, resource_type)
+
+
def load_config(config: Union[None, Path, IO[bytes]]) -> BackendConfig:
    """Return the parsed JSON content of the given config source."""
    if config is None:
        raise Exception("Unable to read config")

    if isinstance(config, Path):
        with config.open() as json_file:
            data = json.load(json_file)
    else:
        data = json.load(config)

    return cast(BackendConfig, data)
+
+
def parse_raw_parameter(parameter: str) -> Tuple[str, Optional[str]]:
    """Split the parameter string in name and optional value.

    It manages the following cases:
        --param=1 -> --param, 1
        --param 1 -> --param, 1
        --flag    -> --flag, None
    """
    tokens = re.split("[ =]", parameter)
    if len(tokens) == 1:
        # A flag: no value part present.
        return tokens[0], None
    # Everything but the last token forms the name; the last is the value.
    return " ".join(tokens[:-1]), tokens[-1]
+
+
class DataPaths(NamedTuple):
    """Source/destination pair describing data to deploy."""

    # Local path of the file or directory to deploy.
    src: Path
    # Destination (on the target system) the data is copied to.
    dst: str
+
+
class Backend(ABC):
    """Base class for backend components (applications, systems, tools)."""

    # pylint: disable=too-many-instance-attributes

    def __init__(self, config: BaseBackendConfig):
        """Initialize backend from the provided configuration.

        :raises ConfigurationException: if the mandatory name is missing/empty
        """
        name = config.get("name")
        if not name:
            raise ConfigurationException("Name is empty")

        self.name = name
        self.description = config.get("description", "")
        self.config_location = config.get("config_location")
        self.variables = config.get("variables", {})
        self.build_dir = config.get("build_dir")
        self.lock = config.get("lock", False)
        if self.build_dir:
            # Variables are resolved eagerly here; user params are resolved
            # later, at command build time.
            self.build_dir = self._substitute_variables(self.build_dir)
        self.annotations = config.get("annotations", {})

        self._parse_commands_and_params(config)

    def validate_parameter(self, command_name: str, parameter: str) -> bool:
        """Validate the parameter string against the application configuration.

        We take the parameter string, extract the parameter name/value and
        check them against the current configuration.

        :raises AttributeError: if command_name is not a known command
        """
        param_name, param_value = parse_raw_parameter(parameter)
        valid_param_name = valid_param_value = False

        command = self.commands.get(command_name)
        if not command:
            raise AttributeError("Unknown command: '{}'".format(command_name))

        # Iterate over all available parameters until we have a match.
        for param in command.params:
            if self._same_parameter(param_name, param):
                valid_param_name = True
                # This is a non-empty list
                if param.values:
                    # We check if the value is allowed in the configuration
                    valid_param_value = param_value in param.values
                else:
                    # In this case we don't validate the value and accept
                    # whatever we have set.
                    valid_param_value = True
                break

        return valid_param_name and valid_param_value

    def __eq__(self, other: object) -> bool:
        """Overload operator ==."""
        if not isinstance(other, Backend):
            return False

        return (
            self.name == other.name
            and self.description == other.description
            and self.commands == other.commands
        )

    def __repr__(self) -> str:
        """Represent the Backend instance by its name."""
        return self.name

    def _parse_commands_and_params(self, config: BaseBackendConfig) -> None:
        """Parse commands and user parameters."""
        self.commands: Dict[str, Command] = {}

        commands = config.get("commands")
        if commands:
            params = config.get("user_params")

            for command_name in commands.keys():
                command_params = self._parse_params(params, command_name)
                command_strings = [
                    self._substitute_variables(cmd)
                    for cmd in commands.get(command_name, [])
                ]
                self.commands[command_name] = Command(command_strings, command_params)

    def _substitute_variables(self, str_val: str) -> str:
        """Substitute variables in string.

        Variables is being substituted at backend's creation stage because
        they could contain references to other params which will be
        resolved later.

        :raises ConfigurationException: on reference to an unknown variable
        """
        if not str_val:
            return str_val

        var_pattern: Final[Pattern] = re.compile(r"{variables:(?P<var_name>\w+)}")

        def var_value(match: Match) -> str:
            var_name = match["var_name"]
            if var_name not in self.variables:
                raise ConfigurationException("Unknown variable {}".format(var_name))

            return self.variables[var_name]

        return var_pattern.sub(var_value, str_val)  # type: ignore

    @classmethod
    def _parse_params(
        cls, params: Optional[UserParamsConfig], command: str
    ) -> List["Param"]:
        """Parse the user parameters declared for the given command."""
        if not params:
            return []

        return [cls._parse_param(p) for p in params.get(command, [])]

    @classmethod
    def _parse_param(cls, param: UserParamConfig) -> "Param":
        """Parse a single parameter.

        :raises ConfigurationException: if 'name' is present but empty
        """
        name = param.get("name")
        if name is not None and not name:
            raise ConfigurationException("Parameter has an empty 'name' attribute.")
        values = param.get("values", None)
        default_value = param.get("default_value", None)
        description = param.get("description", "")
        alias = param.get("alias")

        return Param(
            name=name,
            description=description,
            values=values,
            default_value=default_value,
            alias=alias,
        )

    def _get_command_details(self) -> Dict:
        """Return a details dictionary for every configured command."""
        command_details = {
            command_name: command.get_details()
            for command_name, command in self.commands.items()
        }
        return command_details

    def _get_user_param_value(
        self, user_params: List[str], param: "Param"
    ) -> Optional[str]:
        """Get the user-specified value of a parameter."""
        for user_param in user_params:
            user_param_name, user_param_value = parse_raw_parameter(user_param)
            if user_param_name == param.name:
                warn_message = (
                    "The direct use of parameter name is deprecated"
                    " and might be removed in the future.\n"
                    f"Please use alias '{param.alias}' instead of "
                    # Bug fix: this fragment was a plain string literal, so
                    # the placeholder text '{user_param_name}' was logged
                    # verbatim instead of the actual parameter name.
                    f"'{user_param_name}' to provide the parameter."
                )
                logging.warning(warn_message)

            if self._same_parameter(user_param_name, param):
                return user_param_value

        return None

    @staticmethod
    def _same_parameter(user_param_name_or_alias: str, param: "Param") -> bool:
        """Compare user parameter name with param name or alias."""
        # Strip the "=" sign in the param_name. This is needed just for
        # comparison with the parameters passed by the user.
        # The equal sign needs to be honoured when re-building the
        # parameter back.
        param_name = None if not param.name else param.name.rstrip("=")
        return user_param_name_or_alias in [param_name, param.alias]

    def resolved_parameters(
        self, command_name: str, user_params: List[str]
    ) -> List[Tuple[Optional[str], "Param"]]:
        """Return list of parameters with values.

        User-provided values take precedence; otherwise the configured
        default value (possibly None) is used.
        """
        result: List[Tuple[Optional[str], "Param"]] = []
        command = self.commands.get(command_name)
        if not command:
            return result

        for param in command.params:
            value = self._get_user_param_value(user_params, param)
            if not value:
                value = param.default_value
            result.append((value, param))

        return result

    def build_command(
        self,
        command_name: str,
        user_params: List[str],
        param_resolver: Callable[[str, str, List[Tuple[Optional[str], "Param"]]], str],
    ) -> List[str]:
        """
        Return a list of executable command strings.

        Given a command and associated parameters, returns a list of executable command
        strings.

        :raises ConfigurationException: if command_name is not configured
        """
        command = self.commands.get(command_name)
        if not command:
            raise ConfigurationException(
                "Command '{}' could not be found.".format(command_name)
            )

        commands_to_run = []

        params_values = self.resolved_parameters(command_name, user_params)
        for cmd_str in command.command_strings:
            cmd_str = resolve_all_parameters(
                cmd_str, param_resolver, command_name, params_values
            )
            commands_to_run.append(cmd_str)

        return commands_to_run
+
+
class Param:
    """Generic description of a single application parameter."""

    def __init__(  # pylint: disable=too-many-arguments
        self,
        name: Optional[str],
        description: str,
        values: Optional[List[str]] = None,
        default_value: Optional[str] = None,
        alias: Optional[str] = None,
    ) -> None:
        """Construct a Param instance.

        :raises ConfigurationException: if neither name nor alias is set
        """
        if not name and not alias:
            raise ConfigurationException(
                "Either name, alias or both must be set to identify a parameter."
            )
        self.name = name
        self.values = values
        self.description = description
        self.default_value = default_value
        self.alias = alias

    def get_details(self) -> Dict:
        """Return a dictionary of all attributes that have a truthy value."""
        details = {}
        for key, value in self.__dict__.items():
            if value:
                details[key] = value
        return details

    def __eq__(self, other: object) -> bool:
        """Overload operator ==."""
        if not isinstance(other, Param):
            return False

        # The alias does not take part in the comparison.
        same_name = self.name == other.name
        same_values = self.values == other.values
        same_default = self.default_value == other.default_value
        same_description = self.description == other.description
        return same_name and same_values and same_default and same_description
+
+
class Command:
    """A backend command together with its user parameters."""

    def __init__(
        self, command_strings: List[str], params: Optional[List[Param]] = None
    ) -> None:
        """Construct a Command instance."""
        self.command_strings = command_strings
        self.params = params or []

        self._validate()

    def _validate(self) -> None:
        """Ensure aliases are unique and do not clash with parameter names.

        :raises ConfigurationException: on duplicate aliases or name clashes
        """
        if not self.params:
            return

        aliases = [param.alias for param in self.params if param.alias is not None]
        duplicates = [
            alias for alias, count in Counter(aliases).items() if count > 1
        ]
        if duplicates:
            raise ConfigurationException(
                "Non unique aliases {}".format(", ".join(duplicates))
            )

        name_clashes = [
            param.name
            for param in self.params
            if param.name in aliases and param.name != param.alias
        ]
        if name_clashes:
            raise ConfigurationException(
                "Aliases {} could not be used as parameter name".format(
                    ", ".join(name_clashes)
                )
            )

    def get_details(self) -> Dict:
        """Return a dictionary with all relevant information of a Command."""
        return {
            "command_strings": self.command_strings,
            "user_params": [param.get_details() for param in self.params],
        }

    def __eq__(self, other: object) -> bool:
        """Overload operator ==."""
        return (
            isinstance(other, Command)
            and self.command_strings == other.command_strings
            and self.params == other.params
        )
+
+
def resolve_all_parameters(
    str_val: str,
    param_resolver: Callable[[str, str, List[Tuple[Optional[str], Param]]], str],
    command_name: Optional[str] = None,
    params_values: Optional[List[Tuple[Optional[str], Param]]] = None,
) -> str:
    """Substitute every {param} placeholder in the string."""
    if not str_val:
        return str_val

    pattern: Final[Pattern] = re.compile(r"{(?P<param_name>[\w.:]+)}")

    def substitute(match: Match) -> str:
        return param_resolver(
            match["param_name"], command_name or "", params_values or []
        )

    # A resolved value may itself contain placeholders, so keep
    # substituting until none is left.
    while pattern.findall(str_val):
        str_val = pattern.sub(substitute, str_val)
    return str_val
+
+
def load_application_or_tool_configs(
    config: Any,
    config_type: Type[Any],
    is_system_required: bool = True,
) -> Any:
    """Get one config for each system supported by the application/tool.

    The configuration could contain different parameters/commands for different
    supported systems. For each supported system this function will return a
    separate config with the appropriate, merged configuration.

    :param config: raw application/tool configuration dictionary
    :param config_type: concrete config class to build (a TypedDict, which
        behaves like dict at runtime)
    :param is_system_required: if True, named "supported_systems" entries
        are mandatory
    :raises ConfigurationException: on missing/invalid system definitions, or
        when parameters defined on both sides lack the aliases needed to
        merge them
    """
    supported_systems: Optional[List["NamedExecutionConfig"]] = config.get(
        "supported_systems"
    )
    if not supported_systems:
        if is_system_required:
            raise ConfigurationException("No supported systems definition provided")
        # Create an empty system to be used in the parsing below
        supported_systems = [cast("NamedExecutionConfig", {})]

    default_user_params = config.get("user_params", {})

    def merge_config(system: "NamedExecutionConfig") -> Any:
        system_name = system.get("name")
        if not system_name and is_system_required:
            raise ConfigurationException(
                "Unable to read supported system definition, name is missed"
            )

        merged_config = config_type(**config)
        merged_config["supported_systems"] = [system_name] if system_name else []
        # merge default configuration and specific to the system
        merged_config["commands"] = {
            **config.get("commands", {}),
            **system.get("commands", {}),
        }

        params = {}
        tool_user_params = system.get("user_params", {})
        command_names = tool_user_params.keys() | default_user_params.keys()
        for command_name in command_names:
            # Drop parameters of commands that do not exist after merging.
            if command_name not in merged_config["commands"]:
                continue

            params_default = default_user_params.get(command_name, [])
            params_tool = tool_user_params.get(command_name, [])
            if not params_default or not params_tool:
                # Only one side defines parameters: take it verbatim.
                params[command_name] = params_tool or params_default
            else:
                # Both sides define parameters. Merging is keyed by alias,
                # so every parameter on both sides must carry one.
                if any(not p.get("alias") for p in params_default):
                    raise ConfigurationException(
                        "Default parameters for command {} should have aliases".format(
                            command_name
                        )
                    )
                if any(not p.get("alias") for p in params_tool):
                    raise ConfigurationException(
                        "{} parameters for command {} should have aliases".format(
                            system_name, command_name
                        )
                    )

                # Fix: the merge must only run in this branch. Running it
                # unconditionally overwrote the single-source result above
                # and collapsed alias-less params onto the same (None) key,
                # keeping only the last one.
                merged_by_alias = {
                    **{p.get("alias"): p for p in params_default},
                    **{p.get("alias"): p for p in params_tool},
                }
                params[command_name] = list(merged_by_alias.values())

        merged_config["user_params"] = params
        # System-level settings take precedence over the shared defaults.
        merged_config["build_dir"] = system.get("build_dir", config.get("build_dir"))
        merged_config["lock"] = system.get("lock", config.get("lock", False))
        merged_config["variables"] = {
            **config.get("variables", {}),
            **system.get("variables", {}),
        }
        return merged_config

    return [merge_config(system) for system in supported_systems]
diff --git a/src/aiet/backend/config.py b/src/aiet/backend/config.py
new file mode 100644
index 0000000..dd42012
--- /dev/null
+++ b/src/aiet/backend/config.py
@@ -0,0 +1,107 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Contain definition of backend configuration."""
+from pathlib import Path
+from typing import Dict
+from typing import List
+from typing import Literal
+from typing import Optional
+from typing import Tuple
+from typing import TypedDict
+from typing import Union
+
+
class UserParamConfig(TypedDict, total=False):
    """User parameter configuration."""

    # Optional because a parameter may be identified solely by its alias.
    name: Optional[str]
    default_value: str
    # Allowed values; when absent/empty any value is accepted.
    values: List[str]
    description: str
    alias: str


# Maps a command name to the list of its user parameters.
UserParamsConfig = Dict[str, List[UserParamConfig]]


class ExecutionConfig(TypedDict, total=False):
    """Execution configuration."""

    # Maps a command name to the list of command strings to execute.
    commands: Dict[str, List[str]]
    user_params: UserParamsConfig
    build_dir: str
    # Values referenced from command strings via {variables:<name>}.
    variables: Dict[str, str]
    # NOTE(review): presumably guards concurrent execution via a lock file
    # (FileLock is used elsewhere) — confirm against execution module.
    lock: bool


class NamedExecutionConfig(ExecutionConfig):
    """Execution configuration with name."""

    # Required key: this subclass is declared without total=False.
    name: str


class BaseBackendConfig(ExecutionConfig, total=False):
    """Base backend configuration."""

    name: str
    description: str
    # Directory the configuration was loaded from; used to resolve
    # relative paths (e.g. deploy data sources).
    config_location: Path
    annotations: Dict[str, Union[str, List[str]]]


class ApplicationConfig(BaseBackendConfig, total=False):
    """Application configuration."""

    # Names of the systems the application can run on.
    supported_systems: List[str]
    # (source, destination) pairs of data deployed before execution.
    deploy_data: List[Tuple[str, str]]


class ExtendedApplicationConfig(BaseBackendConfig, total=False):
    """Extended application configuration."""

    # Full per-system execution configs instead of bare system names.
    supported_systems: List[NamedExecutionConfig]
    deploy_data: List[Tuple[str, str]]


class ProtocolConfig(TypedDict, total=False):
    """Protocol config."""

    protocol: Literal["local", "ssh"]


class SSHConfig(ProtocolConfig, total=False):
    """SSH configuration."""

    username: str
    password: str
    hostname: str
    # Declared as str, not int — NOTE(review): presumably taken verbatim
    # from the JSON config; confirm before changing.
    port: str


class LocalProtocolConfig(ProtocolConfig, total=False):
    """Local protocol config."""


class SystemConfig(BaseBackendConfig, total=False):
    """System configuration."""

    data_transfer: Union[SSHConfig, LocalProtocolConfig]
    reporting: Dict[str, Dict]


class ToolConfig(BaseBackendConfig, total=False):
    """Tool configuration."""

    supported_systems: List[str]


class ExtendedToolConfig(BaseBackendConfig, total=False):
    """Extended tool configuration."""

    supported_systems: List[NamedExecutionConfig]


# Any single backend entry.
BackendItemConfig = Union[ApplicationConfig, SystemConfig, ToolConfig]
# Shape of a whole aiet-config.json file: a list of backend entries.
BackendConfig = Union[
    List[ExtendedApplicationConfig], List[SystemConfig], List[ToolConfig]
]
diff --git a/src/aiet/backend/controller.py b/src/aiet/backend/controller.py
new file mode 100644
index 0000000..2650902
--- /dev/null
+++ b/src/aiet/backend/controller.py
@@ -0,0 +1,134 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Controller backend module."""
+import time
+from pathlib import Path
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import psutil
+import sh
+
+from aiet.backend.common import ConfigurationException
+from aiet.utils.fs import read_file_as_string
+from aiet.utils.proc import execute_command
+from aiet.utils.proc import get_stdout_stderr_paths
+from aiet.utils.proc import read_process_info
+from aiet.utils.proc import save_process_info
+from aiet.utils.proc import terminate_command
+from aiet.utils.proc import terminate_external_process
+
+
class SystemController:
    """Controls the lifecycle of a system's underlying process."""

    def __init__(self) -> None:
        """Create new instance of service controller."""
        # Handle of the running background command, once started.
        self.cmd: Optional[sh.RunningCommand] = None
        self.out_path: Optional[Path] = None
        self.err_path: Optional[Path] = None

    def before_start(self) -> None:
        """Run actions before system start."""

    def after_start(self) -> None:
        """Run actions after system start."""

    def start(self, commands: List[str], cwd: Path) -> None:
        """Start the system in the given working directory.

        :raises ConfigurationException: on invalid cwd or command list
        """
        if not isinstance(cwd, Path) or not cwd.is_dir():
            raise ConfigurationException("Wrong working directory {}".format(cwd))

        if len(commands) != 1:
            raise ConfigurationException("System should have only one command to run")

        startup_command = commands[0]
        if not startup_command:
            raise ConfigurationException("No startup command provided")

        self.before_start()

        self.out_path, self.err_path = get_stdout_stderr_paths(startup_command)

        self.cmd = execute_command(
            startup_command,
            cwd,
            bg=True,
            out=str(self.out_path),
            err=str(self.err_path),
        )

        self.after_start()

    def stop(
        self, wait: bool = False, wait_period: float = 0.5, number_of_attempts: int = 20
    ) -> None:
        """Stop the system if it is currently running."""
        if self.cmd is None or not self.is_running():
            return
        terminate_command(self.cmd, wait, wait_period, number_of_attempts)

    def is_running(self) -> bool:
        """Check if underlying process is running."""
        if self.cmd is None:
            return False
        return self.cmd.is_alive()

    def get_output(self) -> Tuple[str, str]:
        """Return the captured stdout/stderr of the system process."""
        if self.cmd is None or self.out_path is None or self.err_path is None:
            return ("", "")

        return (read_file_as_string(self.out_path), read_file_as_string(self.err_path))
+
+
class SystemControllerSingleInstance(SystemController):
    """System controller that allows only one running system instance."""

    def __init__(self, pid_file_path: Optional[Path] = None) -> None:
        """Create new instance of the service controller."""
        super().__init__()
        self.pid_file_path = pid_file_path

    def before_start(self) -> None:
        """Stop a previously started instance before starting a new one."""
        self._check_if_previous_instance_is_running()

    def after_start(self) -> None:
        """Persist process information once the system has started."""
        self._save_process_info()

    def _check_if_previous_instance_is_running(self) -> None:
        """Terminate a still-running instance recorded in the pid file."""
        for item in read_process_info(self._pid_file()):
            try:
                process = psutil.Process(item.pid)
                # Only treat it as the same system if name, executable and
                # working directory all match the recorded info — the pid
                # alone may have been recycled by the OS.
                same_process = (
                    process.name() == item.name
                    and process.exe() == item.executable
                    and process.cwd() == item.cwd
                )
                if same_process:
                    print(
                        "Stopping previous instance of the system [{}]".format(item.pid)
                    )
                    terminate_external_process(process)
            except psutil.NoSuchProcess:
                pass

    def _save_process_info(self, wait_period: float = 2) -> None:
        """Save information about system's processes."""
        if self.cmd is None or not self.is_running():
            return

        # give some time for the system to start
        time.sleep(wait_period)

        save_process_info(self.cmd.process.pid, self._pid_file())

    def _pid_file(self) -> Path:
        """Return path to file which is used for saving process info."""
        if not self.pid_file_path:
            raise Exception("No pid file path presented")

        return self.pid_file_path
diff --git a/src/aiet/backend/execution.py b/src/aiet/backend/execution.py
new file mode 100644
index 0000000..1653ee2
--- /dev/null
+++ b/src/aiet/backend/execution.py
@@ -0,0 +1,859 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Application execution module."""
+import itertools
+import json
+import random
+import re
+import string
+import sys
+import time
+import warnings
+from collections import defaultdict
+from contextlib import contextmanager
+from contextlib import ExitStack
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import ContextManager
+from typing import Dict
+from typing import Generator
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import TypedDict
+from typing import Union
+
+from filelock import FileLock
+from filelock import Timeout
+
+from aiet.backend.application import Application
+from aiet.backend.application import get_application
+from aiet.backend.common import Backend
+from aiet.backend.common import ConfigurationException
+from aiet.backend.common import DataPaths
+from aiet.backend.common import Param
+from aiet.backend.common import parse_raw_parameter
+from aiet.backend.common import resolve_all_parameters
+from aiet.backend.output_parser import Base64OutputParser
+from aiet.backend.output_parser import OutputParser
+from aiet.backend.output_parser import RegexOutputParser
+from aiet.backend.system import ControlledSystem
+from aiet.backend.system import get_system
+from aiet.backend.system import StandaloneSystem
+from aiet.backend.system import System
+from aiet.backend.tool import get_tool
+from aiet.backend.tool import Tool
+from aiet.utils.fs import recreate_directory
+from aiet.utils.fs import remove_directory
+from aiet.utils.fs import valid_for_filename
+from aiet.utils.proc import run_and_wait
+
+
class AnotherInstanceIsRunningException(Exception):
    """Concurrent execution error: the execution lock is already held by another running instance."""
+
+
class ConnectionException(Exception):
    """Connection exception: raised when a connection to the system cannot be established."""
+
+
class ExecutionParams(TypedDict, total=False):
    """Execution parameters (all keys optional, see ``total=False``)."""

    # Skip file-based locking for this execution (used for standalone
    # systems, which may run several instances concurrently).
    disable_locking: bool
    # Build into a randomly-suffixed directory so parallel runs don't clash.
    unique_build_dir: bool
+
+
class ExecutionContext:
    """Command execution context.

    Bundles everything needed to execute a command: the backend
    (application or tool), the target system, the user-supplied parameter
    strings, optional custom data to deploy and optional reporting setup.
    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes
    def __init__(
        self,
        app: Union[Application, Tool],
        app_params: List[str],
        system: Optional[System],
        system_params: List[str],
        custom_deploy_data: Optional[List[DataPaths]] = None,
        execution_params: Optional[ExecutionParams] = None,
        report_file: Optional[Path] = None,
    ):
        """Init execution context."""
        self.app = app
        self.app_params = app_params
        self.custom_deploy_data = custom_deploy_data or []
        self.system = system
        self.system_params = system_params
        self.execution_params = execution_params or ExecutionParams()
        self.report_file = report_file

        self.reporter: Optional[Reporter] = self._create_reporter()

        self.param_resolver = ParamResolver(self)
        self._resolved_build_dir: Optional[Path] = None

    def _create_reporter(self) -> Optional["Reporter"]:
        """Build a reporter with the configured output parsers, or None."""
        if not self.report_file:
            return None  # No reporter needed.

        parsers: List[OutputParser] = []
        if self.system and self.system.reporting:
            # Add RegexOutputParser, if it is configured in the system
            parsers.append(
                RegexOutputParser("system", self.system.reporting["regex"])
            )
        # Add Base64 parser for applications
        parsers.append(Base64OutputParser("application"))
        return Reporter(parsers=parsers)

    @property
    def is_deploy_needed(self) -> bool:
        """Check if application requires data deployment."""
        if not isinstance(self.app, Application):
            return False  # only applications (not tools) deploy data
        return bool(self.app.get_deploy_data() or self.custom_deploy_data)

    @property
    def is_locking_required(self) -> bool:
        """Return true if any form of locking required."""
        if self._disable_locking():
            return False
        system_needs_lock = self.system is not None and self.system.lock
        return bool(self.app.lock or system_needs_lock)

    @property
    def is_build_required(self) -> bool:
        """Return true if application build required."""
        return "build" in self.app.commands

    @property
    def is_unique_build_dir_required(self) -> bool:
        """Return true if unique build dir required."""
        return self.execution_params.get("unique_build_dir", False)

    def build_dir(self) -> Path:
        """Return resolved application build dir (cached after first call)."""
        if self._resolved_build_dir is None:
            self._resolved_build_dir = self._resolve_build_dir()
        return self._resolved_build_dir

    def _resolve_build_dir(self) -> Path:
        """Compute the build directory from the application configuration."""
        location = self.app.config_location
        if not isinstance(location, Path) or not location.is_dir():
            raise ConfigurationException(
                "Application {} has wrong config location".format(self.app.name)
            )

        raw_build_dir = self.app.build_dir
        if raw_build_dir:
            raw_build_dir = resolve_all_parameters(raw_build_dir, self.param_resolver)
        if not raw_build_dir:
            raise ConfigurationException(
                "No build directory defined for the app {}".format(self.app.name)
            )

        if self.is_unique_build_dir_required:
            # Random suffix keeps concurrent runs from clashing on one dir.
            suffix = "".join(
                random.choices(string.ascii_lowercase + string.digits, k=7)
            )
            raw_build_dir = "{}_{}".format(raw_build_dir, suffix)

        return location / raw_build_dir

    def _disable_locking(self) -> bool:
        """Return true if locking should be disabled."""
        return self.execution_params.get("disable_locking", False)
+
+
class ParamResolver:
    """Parameter resolver.

    Resolves placeholder strings such as "application.commands.run:0",
    "system.variables:myvar" or "user_params:alias" against the current
    execution context. Instances are callable and are passed around as the
    resolver argument of ``build_command``/``resolve_all_parameters``.
    """

    def __init__(self, context: ExecutionContext):
        """Init parameter resolver."""
        self.ctx = context

    @staticmethod
    def resolve_user_params(
        cmd_name: Optional[str],
        index_or_alias: str,
        resolved_params: Optional[List[Tuple[Optional[str], Param]]],
    ) -> str:
        """Resolve user params.

        ``index_or_alias`` selects an entry of ``resolved_params`` either by
        numeric index or by the parameter's alias. The selected parameter is
        rendered as "name value", "name=value", value only (positional) or
        name only (flag), depending on its definition.
        """
        if not cmd_name or resolved_params is None:
            raise ConfigurationException("Unable to resolve user params")

        param_value: Optional[str] = None
        param: Optional[Param] = None

        # Numeric selector: positional index into the resolved params.
        if index_or_alias.isnumeric():
            i = int(index_or_alias)
            if i not in range(len(resolved_params)):
                raise ConfigurationException(
                    "Invalid index {} for user params of command {}".format(i, cmd_name)
                )
            param_value, param = resolved_params[i]
        else:
            # Alias selector: first parameter whose alias matches wins.
            for val, par in resolved_params:
                if par.alias == index_or_alias:
                    param_value, param = val, par
                    break

        if param is None:
            raise ConfigurationException(
                "No user parameter for command '{}' with alias '{}'.".format(
                    cmd_name, index_or_alias
                )
            )

        if param_value:
            # We need to handle to cases of parameters here:
            # 1) Optional parameters (non-positional with a name and value)
            # 2) Positional parameters (value only, no name needed)
            # Default to empty strings for positional arguments
            param_name = ""
            separator = ""
            if param.name is not None:
                # A valid param name means we have an optional/non-positional argument:
                # The separator is an empty string in case the param_name
                # has an equal sign as we have to honour it.
                # If the parameter doesn't end with an equal sign then a
                # space character is injected to split the parameter name
                # and its value
                param_name = param.name
                separator = "" if param.name.endswith("=") else " "

            return "{param_name}{separator}{param_value}".format(
                param_name=param_name,
                separator=separator,
                param_value=param_value,
            )

        # No value present: the parameter must at least have a name so it
        # can act as a standalone flag.
        if param.name is None:
            raise ConfigurationException(
                "Missing user parameter with alias '{}' for command '{}'.".format(
                    index_or_alias, cmd_name
                )
            )

        return param.name  # flag: just return the parameter name

    def resolve_commands_and_params(
        self, backend_type: str, cmd_name: str, return_params: bool, index_or_alias: str
    ) -> str:
        """Resolve command or command's param value."""
        if backend_type == "system":
            backend = cast(Backend, self.ctx.system)
            backend_params = self.ctx.system_params
        else:  # Application or Tool backend
            backend = cast(Backend, self.ctx.app)
            backend_params = self.ctx.app_params

        if cmd_name not in backend.commands:
            raise ConfigurationException("Command {} not found".format(cmd_name))

        # "...commands.<cmd>.params:<sel>" form: return one parameter value.
        if return_params:
            params = backend.resolved_parameters(cmd_name, backend_params)
            if index_or_alias.isnumeric():
                i = int(index_or_alias)
                if i not in range(len(params)):
                    raise ConfigurationException(
                        "Invalid parameter index {} for command {}".format(i, cmd_name)
                    )

                param_value = params[i][0]
            else:
                param_value = None
                for value, param in params:
                    if param.alias == index_or_alias:
                        param_value = value
                        break

            # NOTE(review): an empty-string value is rejected here as well as
            # a missing one — confirm empty values are never legitimate.
            if not param_value:
                raise ConfigurationException(
                    (
                        "No value for parameter with index or alias {} of command {}"
                    ).format(index_or_alias, cmd_name)
                )
            return param_value

        # "...commands.<cmd>:<index>" form: return the built command line.
        if not index_or_alias.isnumeric():
            raise ConfigurationException("Bad command index {}".format(index_or_alias))

        i = int(index_or_alias)
        commands = backend.build_command(cmd_name, backend_params, self.param_resolver)
        if i not in range(len(commands)):
            raise ConfigurationException(
                "Invalid index {} for command {}".format(i, cmd_name)
            )

        return commands[i]

    def resolve_variables(self, backend_type: str, var_name: str) -> str:
        """Resolve variable value."""
        if backend_type == "system":
            backend = cast(Backend, self.ctx.system)
        else:  # Application or Tool backend
            backend = cast(Backend, self.ctx.app)

        if var_name not in backend.variables:
            raise ConfigurationException("Unknown variable {}".format(var_name))

        return backend.variables[var_name]

    def param_matcher(
        self,
        param_name: str,
        cmd_name: Optional[str],
        resolved_params: Optional[List[Tuple[Optional[str], Param]]],
    ) -> str:
        """Regexp to resolve a param from the param_name.

        Tries, in order: command/param placeholders, variable placeholders,
        then user-param placeholders; raises if nothing matches.
        """
        # this pattern supports parameter names like "application.commands.run:0" and
        # "system.commands.run.params:0"
        # Note: 'software' is included for backward compatibility.
        commands_and_params_match = re.match(
            r"(?P<type>application|software|tool|system)[.]commands[.]"
            r"(?P<name>\w+)"
            r"(?P<params>[.]params|)[:]"
            r"(?P<index_or_alias>\w+)",
            param_name,
        )

        if commands_and_params_match:
            backend_type, cmd_name, return_params, index_or_alias = (
                commands_and_params_match["type"],
                commands_and_params_match["name"],
                commands_and_params_match["params"],
                commands_and_params_match["index_or_alias"],
            )
            return self.resolve_commands_and_params(
                backend_type, cmd_name, bool(return_params), index_or_alias
            )

        # Note: 'software' is included for backward compatibility.
        variables_match = re.match(
            r"(?P<type>application|software|tool|system)[.]variables:(?P<var_name>\w+)",
            param_name,
        )
        if variables_match:
            backend_type, var_name = (
                variables_match["type"],
                variables_match["var_name"],
            )
            return self.resolve_variables(backend_type, var_name)

        user_params_match = re.match(r"user_params:(?P<index_or_alias>\w+)", param_name)
        if user_params_match:
            index_or_alias = user_params_match["index_or_alias"]
            return self.resolve_user_params(cmd_name, index_or_alias, resolved_params)

        raise ConfigurationException(
            "Unable to resolve parameter {}".format(param_name)
        )

    def param_resolver(
        self,
        param_name: str,
        cmd_name: Optional[str] = None,
        resolved_params: Optional[List[Tuple[Optional[str], Param]]] = None,
    ) -> str:
        """Resolve parameter value based on current execution context.

        Well-known names (name/description/config_dir/build_dir of the app
        and system) are resolved directly; anything else is delegated to
        :meth:`param_matcher`.
        """
        # Note: 'software.*' is included for backward compatibility.
        resolved_param = None
        if param_name in ["application.name", "tool.name", "software.name"]:
            resolved_param = self.ctx.app.name
        elif param_name in [
            "application.description",
            "tool.description",
            "software.description",
        ]:
            resolved_param = self.ctx.app.description
        elif self.ctx.app.config_location and (
            param_name
            in ["application.config_dir", "tool.config_dir", "software.config_dir"]
        ):
            resolved_param = str(self.ctx.app.config_location.absolute())
        elif self.ctx.app.build_dir and (
            param_name
            in ["application.build_dir", "tool.build_dir", "software.build_dir"]
        ):
            resolved_param = str(self.ctx.build_dir().absolute())
        elif self.ctx.system is not None:
            if param_name == "system.name":
                resolved_param = self.ctx.system.name
            elif param_name == "system.description":
                resolved_param = self.ctx.system.description
            elif param_name == "system.config_dir" and self.ctx.system.config_location:
                resolved_param = str(self.ctx.system.config_location.absolute())

        # NOTE(review): an empty-string direct resolution also falls through
        # to the matcher — confirm that is the intended behavior.
        if not resolved_param:
            resolved_param = self.param_matcher(param_name, cmd_name, resolved_params)
        return resolved_param

    def __call__(
        self,
        param_name: str,
        cmd_name: Optional[str] = None,
        resolved_params: Optional[List[Tuple[Optional[str], Param]]] = None,
    ) -> str:
        """Resolve provided parameter."""
        return self.param_resolver(param_name, cmd_name, resolved_params)
+
+
class Reporter:
    """Collect metrics parsed from simulation output and build a report."""

    def __init__(self, parsers: Optional[List[OutputParser]] = None) -> None:
        """Create a reporter with the given output parsers (default: none)."""
        self.parsers: List[OutputParser] = [] if parsers is None else parsers
        # sections keyed by parser name, each holding a "metrics" dict
        self._report: Dict[str, Any] = defaultdict(lambda: defaultdict(dict))

    def parse(self, output: bytearray) -> None:
        """Run every parser over *output*, merging metrics into the report."""
        for parser in self.parsers:
            # update() merges so metrics from earlier parse() calls survive
            self._report[parser.name]["metrics"].update(parser(output))

    def get_filtered_output(self, output: bytearray) -> bytearray:
        """Return *output* with each parser's parsed content filtered out."""
        for parser in self.parsers:
            output = parser.filter_out_parsed_content(output)
        return output

    def report(self, ctx: ExecutionContext) -> Dict[str, Any]:
        """Combine static simulation info with parsed metrics and return it."""
        combined: Dict[str, Any] = defaultdict(dict)
        # Static info first, then parsed metrics merged on top per section.
        combined.update(self._static_info(ctx))
        for section, content in self._report.items():
            combined[section].update(content)
        return combined

    @staticmethod
    def save(report: Dict[str, Any], report_file: Path) -> None:
        """Serialise *report* as indented JSON into *report_file*."""
        with open(report_file, "w", encoding="utf-8") as file:
            json.dump(report, file, indent=4)

    @staticmethod
    def _compute_all_params(cli_params: List[str], backend: Backend) -> Dict[str, str]:
        """
        Build a dict of all parameters, {name:value}.

        Values come from the command line when given, defaults otherwise.
        """
        # "p1=v1" strings from the command line -> {"p1": "v1"}
        overrides = dict(parse_raw_parameter(expr) for expr in cli_params)

        all_params = {
            (param.alias or param.name): overrides.get(
                cast(str, param.name), cast(str, param.default_value)
            )
            for command in backend.commands.values()
            for param in command.params
        }
        return cast(Dict[str, str], all_params)

    @staticmethod
    def _static_info(ctx: ExecutionContext) -> Dict[str, Any]:
        """Extract static simulation information from the context."""
        if ctx.system is None:
            raise ValueError("No system available to report.")

        return {
            "system": {
                "name": ctx.system.name,
                "params": Reporter._compute_all_params(ctx.system_params, ctx.system),
            },
            "application": {
                "name": ctx.app.name,
                "params": Reporter._compute_all_params(ctx.app_params, ctx.app),
            },
        }
+
+
def validate_parameters(
    backend: Backend, command_names: List[str], params: List[str]
) -> None:
    """Check parameters passed to backend.

    Every parameter must be accepted by at least one of the named commands
    that the backend actually defines.

    :raises ValueError: for the first parameter no command accepts
    """
    for param in params:
        accepted_by_any = any(
            backend.validate_parameter(name, param)
            for name in command_names
            if name in backend.commands
        )
        if accepted_by_any:
            continue

        backend_type = "System" if isinstance(backend, System) else "Application"
        raise ValueError(
            "{} parameter '{}' not valid for command '{}'".format(
                backend_type, param, " or ".join(command_names)
            )
        )
+
+
def get_application_by_name_and_system(
    application_name: str, system_name: str
) -> Application:
    """Look up the single application supporting the given system.

    :raises ValueError: if no application matches or the lookup is
        ambiguous (more than one match)
    """
    matches = get_application(application_name, system_name)
    if not matches:
        raise ValueError(
            "Application '{}' doesn't support the system '{}'".format(
                application_name, system_name
            )
        )
    if len(matches) != 1:
        raise ValueError(
            "Error during getting application {} for the system {}".format(
                application_name, system_name
            )
        )
    return matches[0]
+
+
def get_application_and_system(
    application_name: str, system_name: str
) -> Tuple[Application, System]:
    """Return application and system by provided names.

    :raises ValueError: if the system is unknown or the application does
        not support it
    """
    target_system = get_system(system_name)
    if not target_system:
        raise ValueError("System {} is not found".format(system_name))

    app = get_application_by_name_and_system(application_name, system_name)
    return app, target_system
+
+
def execute_application_command(  # pylint: disable=too-many-arguments
    command_name: str,
    application_name: str,
    application_params: List[str],
    system_name: str,
    system_params: List[str],
    custom_deploy_data: List[DataPaths],
) -> None:
    """Execute application command ('build' or 'run').

    .. deprecated:: 21.12
       Use :func:`run_application` instead.
    """
    warnings.warn(
        "Use 'run_application()' instead. Use of 'execute_application_command()' is "
        "deprecated and might be removed in a future release.",
        DeprecationWarning,
    )

    if command_name not in ("build", "run"):
        raise ConfigurationException("Unsupported command {}".format(command_name))

    application, system = get_application_and_system(application_name, system_name)
    validate_parameters(application, [command_name], application_params)
    validate_parameters(system, [command_name], system_params)

    ctx = ExecutionContext(
        app=application,
        app_params=application_params,
        system=system,
        system_params=system_params,
        custom_deploy_data=custom_deploy_data,
    )

    # Dispatch to the handler matching the requested command.
    handler = (
        execute_application_command_run
        if command_name == "run"
        else execute_application_command_build
    )
    handler(ctx)
+
+
# pylint: disable=too-many-arguments
def run_application(
    application_name: str,
    application_params: List[str],
    system_name: str,
    system_params: List[str],
    custom_deploy_data: List[DataPaths],
    report_file: Optional[Path] = None,
) -> None:
    """Run application on the provided system.

    Builds first when the application defines a 'build' command. Standalone
    systems run without locking and with a unique build directory so that
    several instances may execute concurrently.
    """
    application, system = get_application_and_system(application_name, system_name)
    validate_parameters(application, ["build", "run"], application_params)
    validate_parameters(system, ["build", "run"], system_params)

    execution_params = ExecutionParams()
    if isinstance(system, StandaloneSystem):
        execution_params["disable_locking"] = True
        execution_params["unique_build_dir"] = True

    ctx = ExecutionContext(
        app=application,
        app_params=application_params,
        system=system,
        system_params=system_params,
        custom_deploy_data=custom_deploy_data,
        execution_params=execution_params,
        report_file=report_file,
    )

    # Unique build dirs get cleaned up by the manager once the run finishes.
    with build_dir_manager(ctx):
        if ctx.is_build_required:
            execute_application_command_build(ctx)

        execute_application_command_run(ctx)
+
+
def execute_application_command_build(ctx: ExecutionContext) -> None:
    """Execute application command 'build' in a clean build directory."""
    with ExitStack() as stack:
        # Acquire locking (and any other) managers for the build phase.
        for make_manager in get_context_managers("build", ctx):
            stack.enter_context(make_manager(ctx))

        build_dir = ctx.build_dir()
        # Always start from a pristine directory.
        recreate_directory(build_dir)

        commands = ctx.app.build_command("build", ctx.app_params, ctx.param_resolver)
        execute_commands_locally(commands, build_dir)
+
+
def execute_commands_locally(commands: List[str], cwd: Path) -> None:
    """Run each command locally with *cwd* as the working directory.

    Output is streamed to the current stdout/stderr; execution aborts on
    the first failing command (``terminate_on_error=True``).
    """
    for cmd in commands:
        print("Running: {}".format(cmd))
        run_and_wait(cmd, cwd, terminate_on_error=True, out=sys.stdout, err=sys.stderr)
+
+
def execute_application_command_run(ctx: ExecutionContext) -> None:
    """Execute application command.

    Runs the application's (or standalone system's) 'run' commands on the
    configured system: establishes a connection when the system is
    connectable, deploys data when supported, executes the commands and —
    when a reporter is configured — parses/filters the output and saves
    the final report.
    """
    assert ctx.system is not None, "System must be provided."
    if ctx.is_deploy_needed and not ctx.system.supports_deploy:
        raise ConfigurationException(
            "System {} does not support data deploy".format(ctx.system.name)
        )

    with ExitStack() as context_stack:
        # Locking / system-startup managers; unwound in reverse order on exit.
        for manager in get_context_managers("run", ctx):
            context_stack.enter_context(manager(ctx))

        print("Generating commands to execute")
        commands_to_run = build_run_commands(ctx)

        if ctx.system.connectable:
            establish_connection(ctx)

        if ctx.system.supports_deploy:
            deploy_data(ctx)

        for command in commands_to_run:
            print("Running: {}".format(command))
            exit_code, std_output, std_err = ctx.system.run(command)

            # Non-zero exit is reported but does not abort remaining commands.
            if exit_code != 0:
                print("Application exited with exit code {}".format(exit_code))

            if ctx.reporter:
                ctx.reporter.parse(std_output)
                # Strip parsed (e.g. base64-encoded) content before echoing.
                std_output = ctx.reporter.get_filtered_output(std_output)

            print(std_output.decode("utf8"), end="")
            print(std_err.decode("utf8"), end="")

        if ctx.reporter:
            report = ctx.reporter.report(ctx)
            # A reporter only exists when report_file was set, hence the cast.
            ctx.reporter.save(report, cast(Path, ctx.report_file))
+
+
def establish_connection(
    ctx: ExecutionContext, retries: int = 90, interval: float = 15.0
) -> None:
    """Establish connection with the system.

    Polls up to `retries` times, sleeping `interval` seconds between
    attempts.

    :raises ConnectionException: when every attempt fails
    :raises Exception: when a controlled system stops while waiting
    """
    assert ctx.system is not None, "System is required."
    host, port = ctx.system.connection_details()
    print(
        "Trying to establish connection with '{}:{}' - "
        "{} retries every {} seconds ".format(host, port, retries, interval),
        end="",
    )

    try:
        for _ in range(retries):
            print(".", end="", flush=True)

            if ctx.system.establish_connection():
                break

            # A controlled system that has died will never become reachable:
            # dump its captured output and fail fast.
            if isinstance(ctx.system, ControlledSystem) and not ctx.system.is_running():
                print(
                    "\n\n---------- {} execution failed ----------".format(
                        ctx.system.name
                    )
                )
                stdout, stderr = ctx.system.get_output()
                print(stdout)
                print(stderr)

                raise Exception("System is not running")

            wait(interval)
        else:
            # for/else: the loop finished without break, all attempts failed.
            raise ConnectionException("Couldn't connect to '{}:{}'.".format(host, port))
    finally:
        # Terminate the progress-dots line whatever the outcome.
        print()
+
+
def wait(interval: float) -> None:
    """Block the current thread for *interval* seconds."""
    time.sleep(interval)
+
+
def deploy_data(ctx: ExecutionContext) -> None:
    """Deploy declared and custom application data onto the system."""
    if not isinstance(ctx.app, Application):
        return  # Only application can deploy data (tools can not)

    assert ctx.system is not None, "System is required."
    for entry in itertools.chain(ctx.app.get_deploy_data(), ctx.custom_deploy_data):
        print("Deploying {} onto {}".format(entry.src, entry.dst))
        ctx.system.deploy(entry.src, entry.dst)
+
+
def build_run_commands(ctx: ExecutionContext) -> List[str]:
    """Build commands to run application.

    Standalone systems provide the run command themselves; for any other
    system the application's own 'run' command is used.
    """
    if isinstance(ctx.system, StandaloneSystem):
        return ctx.system.build_command("run", ctx.system_params, ctx.param_resolver)
    return ctx.app.build_command("run", ctx.app_params, ctx.param_resolver)
+
+
@contextmanager
def controlled_system_manager(ctx: ExecutionContext) -> Generator[None, None, None]:
    """Start the controlled system for the duration of the context.

    The system is stopped on exit even when the body raises.
    """
    system = cast(ControlledSystem, ctx.system)
    start_commands = system.build_command("run", ctx.system_params, ctx.param_resolver)

    pid_file: Optional[Path] = None
    if ctx.is_locking_required:
        # Keep the pid file beside the lock file, sharing its base name.
        lock_path = get_file_lock_path(ctx)
        pid_file = lock_path.parent / "{}.pid".format(lock_path.stem)

    system.start(start_commands, ctx.is_locking_required, pid_file)
    try:
        yield
    finally:
        print("Shutting down sequence...")
        print("Stopping {}... (It could take few seconds)".format(system.name))
        system.stop(wait=True)
        print("{} stopped successfully.".format(system.name))
+
+
@contextmanager
def lock_execution_manager(ctx: ExecutionContext) -> Generator[None, None, None]:
    """Hold an exclusive file lock for the duration of the context.

    :raises AnotherInstanceIsRunningException: if the lock is already held
    """
    lock = FileLock(str(get_file_lock_path(ctx)))

    try:
        lock.acquire(timeout=1)
    except Timeout as error:
        raise AnotherInstanceIsRunningException() from error

    try:
        yield
    finally:
        lock.release()
+
+
def get_file_lock_path(ctx: ExecutionContext, lock_dir: Path = Path("/tmp")) -> Path:
    """Return the lock-file path derived from app/system lock settings.

    The filename joins "middleware" with the names of every backend that
    requests locking, resolved and sanitised for filesystem use.
    """
    locked_names = [
        backend.name
        for backend in (ctx.app, ctx.system)
        if backend is not None and backend.lock
    ]

    lock_filename = ""
    if locked_names:
        lock_filename = "_".join(["middleware"] + locked_names) + ".lock"
        lock_filename = resolve_all_parameters(lock_filename, ctx.param_resolver)
        lock_filename = valid_for_filename(lock_filename)

    if not lock_filename:
        raise ConfigurationException("No filename for lock provided")

    if not isinstance(lock_dir, Path) or not lock_dir.is_dir():
        raise ConfigurationException(
            "Invalid directory {} for lock files provided".format(lock_dir)
        )

    return lock_dir / lock_filename
+
+
@contextmanager
def build_dir_manager(ctx: ExecutionContext) -> Generator[None, None, None]:
    """Remove a unique build directory once execution has finished."""
    try:
        yield
    finally:
        cleanup_needed = (
            ctx.is_build_required
            and ctx.is_unique_build_dir_required
            and ctx.build_dir().is_dir()
        )
        if cleanup_needed:
            remove_directory(ctx.build_dir())
+
+
def get_context_managers(
    command_name: str, ctx: ExecutionContext
) -> Sequence[Callable[[ExecutionContext], ContextManager[None]]]:
    """Return the context managers needed to execute the given command."""
    managers: List[Callable[[ExecutionContext], ContextManager[None]]] = []

    if ctx.is_locking_required:
        managers.append(lock_execution_manager)

    # Controlled systems must be started/stopped around a 'run' command.
    if command_name == "run" and isinstance(ctx.system, ControlledSystem):
        managers.append(controlled_system_manager)

    return managers
+
+
def get_tool_by_system(tool_name: str, system_name: Optional[str]) -> Tool:
    """Return the tool, optionally narrowed down by system name.

    :raises ConfigurationException: if no tool matches, or the lookup is
        ambiguous because several systems support the tool
    """
    candidates = get_tool(tool_name, system_name)
    if not candidates:
        raise ConfigurationException(
            "Tool '{}' not found or doesn't support the system '{}'".format(
                tool_name, system_name
            )
        )
    if len(candidates) != 1:
        raise ConfigurationException(
            "Please specify the system for tool {}.".format(tool_name)
        )
    return candidates[0]
+
+
def execute_tool_command(
    tool_name: str,
    tool_params: List[str],
    system_name: Optional[str] = None,
) -> None:
    """Execute the tool command locally calling the 'run' command."""
    tool = get_tool_by_system(tool_name, system_name)

    # Tools run without a system; the context only supplies param resolution.
    ctx = ExecutionContext(
        app=tool, app_params=tool_params, system=None, system_params=[]
    )
    run_commands = tool.build_command("run", tool_params, ctx.param_resolver)

    execute_commands_locally(run_commands, Path.cwd())
diff --git a/src/aiet/backend/output_parser.py b/src/aiet/backend/output_parser.py
new file mode 100644
index 0000000..111772a
--- /dev/null
+++ b/src/aiet/backend/output_parser.py
@@ -0,0 +1,176 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Definition of output parsers (including base class OutputParser)."""
+import base64
+import json
+import re
+from abc import ABC
+from abc import abstractmethod
+from typing import Any
+from typing import Dict
+from typing import Union
+
+
class OutputParser(ABC):
    """Abstract base class for output parsers.

    A parser extracts named metrics from raw output bytes and may
    optionally strip the parsed content out of that output.
    """

    def __init__(self, name: str) -> None:
        """Store the parser name, used as the section key in reports."""
        super().__init__()
        self.name = name

    @abstractmethod
    def __call__(self, output: bytearray) -> Dict[str, Any]:
        """Parse the output and return a map of names to metrics."""
        return {}

    # pylint: disable=no-self-use
    def filter_out_parsed_content(self, output: bytearray) -> bytearray:
        """Return the output unchanged; subclasses may remove parsed parts."""
        return output
+
+
class RegexOutputParser(OutputParser):
    """Parser of standard output data using regular expressions."""

    # Supported conversions for captured metric values.
    _TYPE_MAP = {"str": str, "float": float, "int": int}

    def __init__(
        self,
        name: str,
        regex_config: Dict[str, Dict[str, str]],
    ) -> None:
        """
        Set up the parser with the regular expressions.

        The regex_config is mapping from a name to a dict with keys 'pattern'
        and 'type':
        - The 'pattern' holds the regular expression that must contain exactly
          one capturing parenthesis
        - The 'type' can be one of ['str', 'float', 'int'].

        Example:
        ```
        {"Metric1": {"pattern": ".*= *(.*)", "type": "str"}}
        ```

        The different regular expressions from the config are combined using
        non-capturing parenthesis, i.e. regular expressions must not overlap
        if more than one match per line is expected.

        :raises ValueError: if a pattern has not exactly one capturing group
        :raises TypeError: if a configured type is unsupported
        """
        super().__init__(name)

        self._verify_config(regex_config)
        self._regex_cfg = regex_config

        # Combine all patterns into one alternation. Each entry contributes
        # exactly one capturing group, so group i maps to config entry i.
        self._regex = re.compile(
            "|".join("(?:{0})".format(x["pattern"]) for x in self._regex_cfg.values())
        )

    def __call__(self, output: bytearray) -> Dict[str, Union[str, float, int]]:
        """
        Parse the output and return a map of names to metrics.

        Example:
            Assuming a regex_config as used as example in `__init__()` and the
            following output:
            ```
            Simulation finished:
            SIMULATION_STATUS = SUCCESS
            Simulation DONE
            ```
            Then calling the parser should return the following dict:
            ```
            {
                "Metric1": "SUCCESS"
            }
            ```
        """
        metrics = {}
        output_str = output.decode("utf-8")
        results = self._regex.findall(output_str)
        for line_result in results:
            for idx, (name, cfg) in enumerate(self._regex_cfg.items()):
                # findall() yields a plain string when the combined regex has
                # a single group, otherwise a tuple with one slot per group.
                result = (
                    line_result if isinstance(line_result, str) else line_result[idx]
                )
                if result:
                    # Convert the captured text to the configured type.
                    mapped_result = self._TYPE_MAP[cfg["type"]](result)
                    metrics[name] = mapped_result
        return metrics

    def _verify_config(self, regex_config: Dict[str, Dict[str, str]]) -> None:
        """Make sure we have a valid regex_config.

        I.e.
        - Exactly one capturing parenthesis per pattern
        - Correct types
        """
        for name, cfg in regex_config.items():
            # Check that there is one capturing group defined in the pattern.
            regex = re.compile(cfg["pattern"])
            if regex.groups != 1:
                raise ValueError(
                    f"Pattern for metric '{name}' must have exactly one "
                    f"capturing parenthesis, but it has {regex.groups}."
                )
            # Check if type is supported. (Fixed the unidiomatic
            # 'not cfg["type"] in ...' test; behavior is unchanged.)
            if cfg["type"] not in self._TYPE_MAP:
                raise TypeError(
                    f"Type '{cfg['type']}' for metric '{name}' is not "
                    f"supported. Choose from: {list(self._TYPE_MAP.keys())}."
                )
+
+
class Base64OutputParser(OutputParser):
    """
    Parser to extract base64-encoded JSON from tagged standard output.

    Example of the tagged output:
    ```
    # Encoded JSON: {"test": 1234}
    <metrics>eyJ0ZXN0IjogMTIzNH0</metrics>
    ```
    """

    TAG_NAME = "metrics"

    def __init__(self, name: str) -> None:
        """Set up the regular expression to extract tagged strings."""
        super().__init__(name)
        self._regex = re.compile(rf"<{self.TAG_NAME}>(.*)</{self.TAG_NAME}>")

    def __call__(self, output: bytearray) -> Dict[str, Any]:
        """
        Parse the output and return a map of index (as string) to decoded JSON.

        Example:
            Using the tagged output from the class docs the parser should return
            the following dict:
            ```
            {
                "0": {"test": 1234}
            }
            ```
        """
        tagged_blobs = self._regex.findall(output.decode("utf-8"))
        # Each tagged blob is base64-encoded JSON; key results by position.
        return {
            str(idx): json.loads(base64.b64decode(blob, validate=True))
            for idx, blob in enumerate(tagged_blobs)
        }

    def filter_out_parsed_content(self, output: bytearray) -> bytearray:
        """Remove all tagged base64 sections from the output."""
        cleaned = self._regex.sub("", output.decode("utf-8"))
        return bytearray(cleaned.encode("utf-8"))
diff --git a/src/aiet/backend/protocol.py b/src/aiet/backend/protocol.py
new file mode 100644
index 0000000..c621436
--- /dev/null
+++ b/src/aiet/backend/protocol.py
@@ -0,0 +1,325 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Contain protocol related classes and functions."""
+from abc import ABC
+from abc import abstractmethod
+from contextlib import closing
+from pathlib import Path
+from typing import Any
+from typing import cast
+from typing import Iterable
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import paramiko
+
+from aiet.backend.common import ConfigurationException
+from aiet.backend.config import LocalProtocolConfig
+from aiet.backend.config import SSHConfig
+from aiet.utils.proc import run_and_wait
+
+
+# Redirect all paramiko thread exceptions to a file otherwise these will be
+# printed to stderr.
+paramiko.util.log_to_file("/tmp/main_paramiko_log.txt", level=paramiko.common.INFO)
+
+
+class SSHConnectionException(Exception):
+    """Raised when an SSH connection to the target system cannot be established."""
+
+
+class SupportsClose(ABC):
+    """Class indicates support of close operation.
+
+    Mixin interface for protocols that hold a session which must be
+    explicitly released.
+    """
+
+    @abstractmethod
+    def close(self) -> None:
+        """Close protocol session."""
+
+
+class SupportsDeploy(ABC):
+    """Class indicates support of deploy operation.
+
+    Mixin interface for protocols that can copy files/directories to the
+    target system.
+    """
+
+    @abstractmethod
+    def deploy(self, src: Path, dst: str, retry: bool = True) -> None:
+        """Abstract method for deploy data."""
+
+
+class SupportsConnection(ABC):
+    """Class indicates that protocol uses network connections."""
+
+    @abstractmethod
+    def establish_connection(self) -> bool:
+        """Establish connection with underlying system."""
+
+    @abstractmethod
+    def connection_details(self) -> Tuple[str, int]:
+        """Return connection details (host, port)."""
+
+
+class Protocol(ABC):
+    """Abstract class for representing the protocol."""
+
+    def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None:
+        """Initialize the class using a dict."""
+        # Copy the configuration entries straight into the instance's
+        # attribute dictionary, then let the subclass validate the result
+        # (each subclass asserts the attributes it requires).
+        self.__dict__.update(iterable, **kwargs)
+        self._validate()
+
+    @abstractmethod
+    def _validate(self) -> None:
+        """Abstract method for config data validation."""
+
+    @abstractmethod
+    def run(
+        self, command: str, retry: bool = False
+    ) -> Tuple[int, bytearray, bytearray]:
+        """
+        Abstract method for running commands.
+
+        Returns a tuple: (exit_code, stdout, stderr)
+        """
+
+
+class CustomSFTPClient(paramiko.SFTPClient):
+ """Class for creating a custom sftp client."""
+
+ def put_dir(self, source: Path, target: str) -> None:
+ """Upload the source directory to the target path.
+
+ The target directory needs to exists and the last directory of the
+ source will be created under the target with all its content.
+ """
+ # Create the target directory
+ self._mkdir(target, ignore_existing=True)
+ # Create the last directory in the source on the target
+ self._mkdir("{}/{}".format(target, source.name), ignore_existing=True)
+ # Go through the whole content of source
+ for item in sorted(source.glob("**/*")):
+ relative_path = item.relative_to(source.parent)
+ remote_target = target / relative_path
+ if item.is_file():
+ self.put(str(item), str(remote_target))
+ else:
+ self._mkdir(str(remote_target), ignore_existing=True)
+
+ def _mkdir(self, path: str, mode: int = 511, ignore_existing: bool = False) -> None:
+ """Extend mkdir functionality.
+
+ This version adds an option to not fail if the folder exists.
+ """
+ try:
+ super().mkdir(path, mode)
+ except IOError as error:
+ if ignore_existing:
+ pass
+ else:
+ raise error
+
+
+class LocalProtocol(Protocol):
+ """Class for local protocol."""
+
+ protocol: str
+ cwd: Path
+
+ def run(
+ self, command: str, retry: bool = False
+ ) -> Tuple[int, bytearray, bytearray]:
+ """
+ Run command locally.
+
+ Returns a tuple: (exit_code, stdout, stderr)
+ """
+ if not isinstance(self.cwd, Path) or not self.cwd.is_dir():
+ raise ConfigurationException("Wrong working directory {}".format(self.cwd))
+
+ stdout = bytearray()
+ stderr = bytearray()
+
+ return run_and_wait(
+ command, self.cwd, terminate_on_error=True, out=stdout, err=stderr
+ )
+
+ def _validate(self) -> None:
+ """Validate protocol configuration."""
+ assert hasattr(self, "protocol") and self.protocol == "local"
+ assert hasattr(self, "cwd")
+
+
+class SSHProtocol(Protocol, SupportsClose, SupportsDeploy, SupportsConnection):
+    """Class for SSH protocol.
+
+    Runs commands and deploys files on a remote system over SSH/SFTP
+    using paramiko.
+    """
+
+    # Attributes below are populated from the config dict by Protocol.__init__.
+    protocol: str
+    username: str
+    password: str
+    hostname: str
+    port: int
+
+    def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None:
+        """Initialize the class using a dict."""
+        super().__init__(iterable, **kwargs)
+        # Internal state to store if the system is connectable. It will be set
+        # to true at the first connection instance
+        self.client: Optional[paramiko.client.SSHClient] = None
+        # The port may arrive from the config as a string; normalise to int.
+        self.port = int(self.port)
+
+    def run(self, command: str, retry: bool = True) -> Tuple[int, bytearray, bytearray]:
+        """
+        Run command over SSH.
+
+        Returns a tuple: (exit_code, stdout, stderr)
+
+        NOTE(review): the 'retry' argument is accepted but not used here.
+        """
+        transport = self._get_transport()
+        with closing(transport.open_session()) as channel:
+            # Enable shell's .profile settings and execute command
+            channel.exec_command("bash -l -c '{}'".format(command))
+            exit_status = -1
+            stdout = bytearray()
+            stderr = bytearray()
+            # Poll until the remote command exits, draining stdout/stderr as
+            # data becomes available so the channel buffers never fill up.
+            while True:
+                if channel.exit_status_ready():
+                    exit_status = channel.recv_exit_status()
+                    # Call it one last time to read any leftover in the channel
+                    self._recv_stdout_err(channel, stdout, stderr)
+                    break
+                self._recv_stdout_err(channel, stdout, stderr)
+
+        return exit_status, stdout, stderr
+
+    def deploy(self, src: Path, dst: str, retry: bool = True) -> None:
+        """Deploy src to remote dst over SSH.
+
+        src and dst should be path to a file or directory.
+
+        NOTE(review): the 'retry' argument is accepted but not used here.
+        """
+        transport = self._get_transport()
+        sftp = cast(CustomSFTPClient, CustomSFTPClient.from_transport(transport))
+
+        with closing(sftp):
+            if src.is_dir():
+                sftp.put_dir(src, dst)
+            elif src.is_file():
+                sftp.put(str(src), dst)
+            else:
+                raise Exception("Deploy error: file type not supported")
+
+        # After the deployment of files, sync the remote filesystem to flush
+        # buffers to hard disk
+        self.run("sync")
+
+    def close(self) -> None:
+        """Close protocol session."""
+        if self.client is not None:
+            print("Try syncing remote file system...")
+            # Before stopping the system, we try to run sync to make sure all
+            # data are flushed on disk.
+            self.run("sync", retry=False)
+            self._close_client(self.client)
+
+    def establish_connection(self) -> bool:
+        """Establish connection with underlying system.
+
+        The established client is cached; subsequent calls are no-ops.
+        """
+        if self.client is not None:
+            return True
+
+        self.client = self._connect()
+        return self.client is not None
+
+    def _get_transport(self) -> paramiko.transport.Transport:
+        """Get transport, connecting first if necessary."""
+        self.establish_connection()
+
+        if self.client is None:
+            raise SSHConnectionException(
+                "Couldn't connect to '{}:{}'.".format(self.hostname, self.port)
+            )
+
+        transport = self.client.get_transport()
+        if not transport:
+            raise Exception("Unable to get transport")
+
+        return transport
+
+    def connection_details(self) -> Tuple[str, int]:
+        """Return connection details of underlying system."""
+        return (self.hostname, self.port)
+
+    def _connect(self) -> Optional[paramiko.client.SSHClient]:
+        """Try to establish connection; return None on failure."""
+        client: Optional[paramiko.client.SSHClient] = None
+        try:
+            client = paramiko.client.SSHClient()
+            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+            client.connect(
+                self.hostname,
+                self.port,
+                self.username,
+                self.password,
+                # next parameters should be set to False to disable authentication
+                # using ssh keys
+                allow_agent=False,
+                look_for_keys=False,
+            )
+            return client
+        except (
+            # OSError raised on first attempt to connect when running inside Docker
+            OSError,
+            paramiko.ssh_exception.NoValidConnectionsError,
+            paramiko.ssh_exception.SSHException,
+        ):
+            # even if connection is not established socket could be still
+            # open, it should be closed
+            self._close_client(client)
+
+        return None
+
+    @staticmethod
+    def _close_client(client: Optional[paramiko.client.SSHClient]) -> None:
+        """Close ssh client, ignoring any error raised while closing."""
+        try:
+            if client is not None:
+                client.close()
+        except Exception:  # pylint: disable=broad-except
+            pass
+
+    @classmethod
+    def _recv_stdout_err(
+        cls, channel: paramiko.channel.Channel, stdout: bytearray, stderr: bytearray
+    ) -> None:
+        """Read from channel to stdout/stder."""
+        # Read in small chunks; recv_* returns only what is currently buffered.
+        chunk_size = 512
+        if channel.recv_ready():
+            stdout_chunk = channel.recv(chunk_size)
+            stdout.extend(stdout_chunk)
+        if channel.recv_stderr_ready():
+            stderr_chunk = channel.recv_stderr(chunk_size)
+            stderr.extend(stderr_chunk)
+
+    def _validate(self) -> None:
+        """Check if there are all the info for establishing the connection."""
+        assert hasattr(self, "protocol") and self.protocol == "ssh"
+        assert hasattr(self, "username")
+        assert hasattr(self, "password")
+        assert hasattr(self, "hostname")
+        assert hasattr(self, "port")
+
+
+class ProtocolFactory:
+ """Factory class to return the appropriate Protocol class."""
+
+ @staticmethod
+ def get_protocol(
+ config: Optional[Union[SSHConfig, LocalProtocolConfig]],
+ **kwargs: Union[str, Path, None]
+ ) -> Union[SSHProtocol, LocalProtocol]:
+ """Return the right protocol instance based on the config."""
+ if not config:
+ raise ValueError("No protocol config provided")
+
+ protocol = config["protocol"]
+ if protocol == "ssh":
+ return SSHProtocol(config)
+
+ if protocol == "local":
+ cwd = kwargs.get("cwd")
+ return LocalProtocol(config, cwd=cwd)
+
+ raise ValueError("Protocol not supported: '{}'".format(protocol))
diff --git a/src/aiet/backend/source.py b/src/aiet/backend/source.py
new file mode 100644
index 0000000..dec175a
--- /dev/null
+++ b/src/aiet/backend/source.py
@@ -0,0 +1,209 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Contain source related classes and functions."""
+import os
+import shutil
+import tarfile
+from abc import ABC
+from abc import abstractmethod
+from pathlib import Path
+from tarfile import TarFile
+from typing import Optional
+from typing import Union
+
+from aiet.backend.common import AIET_CONFIG_FILE
+from aiet.backend.common import ConfigurationException
+from aiet.backend.common import get_backend_config
+from aiet.backend.common import is_backend_directory
+from aiet.backend.common import load_config
+from aiet.backend.config import BackendConfig
+from aiet.utils.fs import copy_directory_content
+
+
+class Source(ABC):
+    """Source class.
+
+    Interface for locations a backend can be installed from (e.g. a
+    directory or a tar archive).
+    """
+
+    @abstractmethod
+    def name(self) -> Optional[str]:
+        """Get source name."""
+
+    @abstractmethod
+    def config(self) -> Optional[BackendConfig]:
+        """Get configuration file content."""
+
+    @abstractmethod
+    def install_into(self, destination: Path) -> None:
+        """Install source into destination directory."""
+
+    @abstractmethod
+    def create_destination(self) -> bool:
+        """Return True if destination folder should be created before installation."""
+
+
+class DirectorySource(Source):
+ """DirectorySource class."""
+
+ def __init__(self, directory_path: Path) -> None:
+ """Create the DirectorySource instance."""
+ assert isinstance(directory_path, Path)
+ self.directory_path = directory_path
+
+ def name(self) -> str:
+ """Return name of source."""
+ return self.directory_path.name
+
+ def config(self) -> Optional[BackendConfig]:
+ """Return configuration file content."""
+ if not is_backend_directory(self.directory_path):
+ raise ConfigurationException("No configuration file found")
+
+ config_file = get_backend_config(self.directory_path)
+ return load_config(config_file)
+
+ def install_into(self, destination: Path) -> None:
+ """Install source into destination directory."""
+ if not destination.is_dir():
+ raise ConfigurationException("Wrong destination {}".format(destination))
+
+ if not self.directory_path.is_dir():
+ raise ConfigurationException(
+ "Directory {} does not exist".format(self.directory_path)
+ )
+
+ copy_directory_content(self.directory_path, destination)
+
+ def create_destination(self) -> bool:
+ """Return True if destination folder should be created before installation."""
+ return True
+
+
+class TarArchiveSource(Source):
+ """TarArchiveSource class."""
+
+ def __init__(self, archive_path: Path) -> None:
+ """Create the TarArchiveSource class."""
+ assert isinstance(archive_path, Path)
+ self.archive_path = archive_path
+ self._config: Optional[BackendConfig] = None
+ self._has_top_level_folder: Optional[bool] = None
+ self._name: Optional[str] = None
+
+ def _read_archive_content(self) -> None:
+ """Read various information about archive."""
+ # get source name from archive name (everything without extensions)
+ extensions = "".join(self.archive_path.suffixes)
+ self._name = self.archive_path.name.rstrip(extensions)
+
+ if not self.archive_path.exists():
+ return
+
+ with self._open(self.archive_path) as archive:
+ try:
+ config_entry = archive.getmember(AIET_CONFIG_FILE)
+ self._has_top_level_folder = False
+ except KeyError as error_no_config:
+ try:
+ archive_entries = archive.getnames()
+ entries_common_prefix = os.path.commonprefix(archive_entries)
+ top_level_dir = entries_common_prefix.rstrip("/")
+
+ if not top_level_dir:
+ raise RuntimeError(
+ "Archive has no top level directory"
+ ) from error_no_config
+
+ config_path = "{}/{}".format(top_level_dir, AIET_CONFIG_FILE)
+
+ config_entry = archive.getmember(config_path)
+ self._has_top_level_folder = True
+ self._name = top_level_dir
+ except (KeyError, RuntimeError) as error_no_root_dir_or_config:
+ raise ConfigurationException(
+ "No configuration file found"
+ ) from error_no_root_dir_or_config
+
+ content = archive.extractfile(config_entry)
+ self._config = load_config(content)
+
+ def config(self) -> Optional[BackendConfig]:
+ """Return configuration file content."""
+ if self._config is None:
+ self._read_archive_content()
+
+ return self._config
+
+ def name(self) -> Optional[str]:
+ """Return name of the source."""
+ if self._name is None:
+ self._read_archive_content()
+
+ return self._name
+
+ def create_destination(self) -> bool:
+ """Return True if destination folder must be created before installation."""
+ if self._has_top_level_folder is None:
+ self._read_archive_content()
+
+ return not self._has_top_level_folder
+
+ def install_into(self, destination: Path) -> None:
+ """Install source into destination directory."""
+ if not destination.is_dir():
+ raise ConfigurationException("Wrong destination {}".format(destination))
+
+ with self._open(self.archive_path) as archive:
+ archive.extractall(destination)
+
+ def _open(self, archive_path: Path) -> TarFile:
+ """Open archive file."""
+ if not archive_path.is_file():
+ raise ConfigurationException("File {} does not exist".format(archive_path))
+
+ if archive_path.name.endswith("tar.gz") or archive_path.name.endswith("tgz"):
+ mode = "r:gz"
+ else:
+ raise ConfigurationException(
+ "Unsupported archive type {}".format(archive_path)
+ )
+
+ # The returned TarFile object can be used as a context manager (using
+ # 'with') by the calling instance.
+ return tarfile.open( # pylint: disable=consider-using-with
+ self.archive_path, mode=mode
+ )
+
+
+def get_source(source_path: Path) -> Union[TarArchiveSource, DirectorySource]:
+ """Return appropriate source instance based on provided source path."""
+ if source_path.is_file():
+ return TarArchiveSource(source_path)
+
+ if source_path.is_dir():
+ return DirectorySource(source_path)
+
+ raise ConfigurationException("Unable to read {}".format(source_path))
+
+
+def create_destination_and_install(source: Source, resource_path: Path) -> None:
+ """Create destination directory and install source.
+
+ This function is used for actual installation of system/backend New
+ directory will be created inside :resource_path: if needed If for example
+ archive contains top level folder then no need to create new directory
+ """
+ destination = resource_path
+ create_destination = source.create_destination()
+
+ if create_destination:
+ name = source.name()
+ if not name:
+ raise ConfigurationException("Unable to get source name")
+
+ destination = resource_path / name
+ destination.mkdir()
+ try:
+ source.install_into(destination)
+ except Exception as error:
+ if create_destination:
+ shutil.rmtree(destination)
+ raise error
diff --git a/src/aiet/backend/system.py b/src/aiet/backend/system.py
new file mode 100644
index 0000000..48f1bb1
--- /dev/null
+++ b/src/aiet/backend/system.py
@@ -0,0 +1,289 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""System backend module."""
+from pathlib import Path
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+from aiet.backend.common import Backend
+from aiet.backend.common import ConfigurationException
+from aiet.backend.common import get_backend_configs
+from aiet.backend.common import get_backend_directories
+from aiet.backend.common import load_config
+from aiet.backend.common import remove_backend
+from aiet.backend.config import SystemConfig
+from aiet.backend.controller import SystemController
+from aiet.backend.controller import SystemControllerSingleInstance
+from aiet.backend.protocol import ProtocolFactory
+from aiet.backend.protocol import SupportsClose
+from aiet.backend.protocol import SupportsConnection
+from aiet.backend.protocol import SupportsDeploy
+from aiet.backend.source import create_destination_and_install
+from aiet.backend.source import get_source
+from aiet.utils.fs import get_resources
+
+
+def get_available_systems_directory_names() -> List[str]:
+    """Return a list of directory names for all available systems."""
+    return [entry.name for entry in get_backend_directories("systems")]
+
+
+def get_available_systems() -> List["System"]:
+ """Return a list with all available systems."""
+ available_systems = []
+ for config_json in get_backend_configs("systems"):
+ config_entries = cast(List[SystemConfig], (load_config(config_json)))
+ for config_entry in config_entries:
+ config_entry["config_location"] = config_json.parent.absolute()
+ system = load_system(config_entry)
+ available_systems.append(system)
+
+ return sorted(available_systems, key=lambda system: system.name)
+
+
+def get_system(system_name: str) -> Optional["System"]:
+ """Return a system instance with the same name passed as argument."""
+ available_systems = get_available_systems()
+ for system in available_systems:
+ if system_name == system.name:
+ return system
+ return None
+
+
+def install_system(source_path: Path) -> None:
+ """Install new system."""
+ try:
+ source = get_source(source_path)
+ config = cast(List[SystemConfig], source.config())
+ systems_to_install = [load_system(entry) for entry in config]
+ except Exception as error:
+ raise ConfigurationException("Unable to read system definition") from error
+
+ if not systems_to_install:
+ raise ConfigurationException("No system definition found")
+
+ available_systems = get_available_systems()
+ already_installed = [s for s in systems_to_install if s in available_systems]
+ if already_installed:
+ names = [system.name for system in already_installed]
+ raise ConfigurationException(
+ "Systems [{}] are already installed".format(",".join(names))
+ )
+
+ create_destination_and_install(source, get_resources("systems"))
+
+
+def remove_system(directory_name: str) -> None:
+    """Remove system.
+
+    Delete the backend directory :directory_name: from the "systems"
+    resources.
+    """
+    remove_backend(directory_name, "systems")
+
+
+class System(Backend):
+    """System class.
+
+    A target system that commands can be run on (and, protocol permitting,
+    files deployed to) via a local or SSH protocol instance.
+    """
+
+    def __init__(self, config: SystemConfig) -> None:
+        """Construct the System class using the dictionary passed."""
+        super().__init__(config)
+
+        self._setup_data_transfer(config)
+        self._setup_reporting(config)
+
+    def _setup_data_transfer(self, config: SystemConfig) -> None:
+        # Build the concrete protocol (local/ssh) from the config section;
+        # config_location is passed as cwd for the local protocol.
+        data_transfer_config = config.get("data_transfer")
+        protocol = ProtocolFactory().get_protocol(
+            data_transfer_config, cwd=self.config_location
+        )
+        self.protocol = protocol
+
+    def _setup_reporting(self, config: SystemConfig) -> None:
+        # Optional "reporting" section; None when absent.
+        self.reporting = config.get("reporting")
+
+    def run(self, command: str, retry: bool = True) -> Tuple[int, bytearray, bytearray]:
+        """
+        Run command on the system.
+
+        Returns a tuple: (exit_code, stdout, stderr)
+        """
+        return self.protocol.run(command, retry)
+
+    def deploy(self, src: Path, dst: str, retry: bool = True) -> None:
+        """Deploy files to the system.
+
+        Silently does nothing when the protocol does not support deployment.
+        """
+        if isinstance(self.protocol, SupportsDeploy):
+            self.protocol.deploy(src, dst, retry)
+
+    @property
+    def supports_deploy(self) -> bool:
+        """Check if protocol supports deploy operation."""
+        return isinstance(self.protocol, SupportsDeploy)
+
+    @property
+    def connectable(self) -> bool:
+        """Check if protocol supports connection."""
+        return isinstance(self.protocol, SupportsConnection)
+
+    def establish_connection(self) -> bool:
+        """Establish connection with the system.
+
+        Raises ConfigurationException for protocols without connections.
+        """
+        if not isinstance(self.protocol, SupportsConnection):
+            raise ConfigurationException(
+                "System {} does not support connections".format(self.name)
+            )
+
+        return self.protocol.establish_connection()
+
+    def connection_details(self) -> Tuple[str, int]:
+        """Return connection details.
+
+        Raises ConfigurationException for protocols without connections.
+        """
+        if not isinstance(self.protocol, SupportsConnection):
+            raise ConfigurationException(
+                "System {} does not support connections".format(self.name)
+            )
+
+        return self.protocol.connection_details()
+
+    def __eq__(self, other: object) -> bool:
+        """Overload operator ==."""
+        # NOTE(review): defining __eq__ without __hash__ makes instances
+        # unhashable; fine as long as systems are never put in sets/dict keys.
+        if not isinstance(other, System):
+            return False
+
+        return super().__eq__(other) and self.name == other.name
+
+    def get_details(self) -> Dict[str, Any]:
+        """Return a dictionary with all relevant information of a System."""
+        output = {
+            "type": "system",
+            "name": self.name,
+            "description": self.description,
+            "data_transfer_protocol": self.protocol.protocol,
+            "commands": self._get_command_details(),
+            "annotations": self.annotations,
+        }
+
+        return output
+
+
+class StandaloneSystem(System):
+    """StandaloneSystem class.
+
+    System configured with the "local" data transfer protocol
+    (see load_system); it needs no external lifecycle management.
+    """
+
+
+def get_controller(
+ single_instance: bool, pid_file_path: Optional[Path] = None
+) -> SystemController:
+ """Get system controller."""
+ if single_instance:
+ return SystemControllerSingleInstance(pid_file_path)
+
+ return SystemController()
+
+
+class ControlledSystem(System):
+    """ControlledSystem class.
+
+    System whose lifecycle (start/stop) is managed through a
+    SystemController; used for the "ssh" protocol (see load_system).
+    """
+
+    def __init__(self, config: SystemConfig):
+        """Construct the ControlledSystem class using the dictionary passed."""
+        super().__init__(config)
+        # Set by start(); None until the system has been launched.
+        self.controller: Optional[SystemController] = None
+
+    def start(
+        self,
+        commands: List[str],
+        single_instance: bool = True,
+        pid_file_path: Optional[Path] = None,
+    ) -> None:
+        """Launch the system."""
+        if (
+            not isinstance(self.config_location, Path)
+            or not self.config_location.is_dir()
+        ):
+            raise ConfigurationException(
+                "System {} has wrong config location".format(self.name)
+            )
+
+        self.controller = get_controller(single_instance, pid_file_path)
+        self.controller.start(commands, self.config_location)
+
+    def is_running(self) -> bool:
+        """Check if system is running."""
+        if not self.controller:
+            return False
+
+        return self.controller.is_running()
+
+    def get_output(self) -> Tuple[str, str]:
+        """Return system output; empty strings if the system was not started."""
+        if not self.controller:
+            return "", ""
+
+        return self.controller.get_output()
+
+    def stop(self, wait: bool = False) -> None:
+        """Stop the system.
+
+        Raises if the system has never been started.
+        """
+        if not self.controller:
+            raise Exception("System has not been started")
+
+        # Close the protocol session first (e.g. flush the remote filesystem
+        # over SSH) before asking the controller to shut the system down.
+        if isinstance(self.protocol, SupportsClose):
+            try:
+                self.protocol.close()
+            except Exception as error:  # pylint: disable=broad-except
+                print(error)
+        self.controller.stop(wait)
+
+
+def load_system(config: SystemConfig) -> Union[StandaloneSystem, ControlledSystem]:
+ """Load system based on it's execution type."""
+ data_transfer = config.get("data_transfer", {})
+ protocol = data_transfer.get("protocol")
+ populate_shared_params(config)
+
+ if protocol == "ssh":
+ return ControlledSystem(config)
+
+ if protocol == "local":
+ return StandaloneSystem(config)
+
+ raise ConfigurationException(
+ "Unsupported execution type for protocol {}".format(protocol)
+ )
+
+
+def populate_shared_params(config: SystemConfig) -> None:
+    """Populate command parameters with shared parameters.
+
+    Merge the "shared" user parameters into the per-command ("build"/"run")
+    user parameters, command-specific entries overriding shared ones with
+    the same alias. The "shared" section is removed afterwards. The config
+    is modified in place.
+    """
+    user_params = config.get("user_params")
+    if not user_params or "shared" not in user_params:
+        return
+
+    shared_user_params = user_params["shared"]
+    if not shared_user_params:
+        return
+
+    # Aliases are the merge keys, so every shared parameter must have one.
+    only_aliases = all(p.get("alias") for p in shared_user_params)
+    if not only_aliases:
+        raise ConfigurationException("All shared parameters should have aliases")
+
+    commands = config.get("commands", {})
+    for cmd_name in ["build", "run"]:
+        command = commands.get(cmd_name)
+        if command is None:
+            # Ensure the command exists so the parameters can be attached.
+            commands[cmd_name] = []
+        cmd_user_params = user_params.get(cmd_name)
+        if not cmd_user_params:
+            # No command-specific parameters: take the shared ones as-is.
+            cmd_user_params = shared_user_params
+        else:
+            only_aliases = all(p.get("alias") for p in cmd_user_params)
+            if not only_aliases:
+                raise ConfigurationException(
+                    "All parameters for command {} should have aliases".format(cmd_name)
+                )
+            # Dict-merge by alias: later (command-specific) entries override
+            # shared entries carrying the same alias.
+            merged_by_alias = {
+                **{p.get("alias"): p for p in shared_user_params},
+                **{p.get("alias"): p for p in cmd_user_params},
+            }
+            cmd_user_params = list(merged_by_alias.values())
+
+        user_params[cmd_name] = cmd_user_params
+
+    config["commands"] = commands
+    del user_params["shared"]
diff --git a/src/aiet/backend/tool.py b/src/aiet/backend/tool.py
new file mode 100644
index 0000000..d643665
--- /dev/null
+++ b/src/aiet/backend/tool.py
@@ -0,0 +1,109 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tool backend module."""
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import Optional
+
+from aiet.backend.common import Backend
+from aiet.backend.common import ConfigurationException
+from aiet.backend.common import get_backend_configs
+from aiet.backend.common import get_backend_directories
+from aiet.backend.common import load_application_or_tool_configs
+from aiet.backend.common import load_config
+from aiet.backend.config import ExtendedToolConfig
+from aiet.backend.config import ToolConfig
+
+
+def get_available_tool_directory_names() -> List[str]:
+ """Return a list of directory names for all available tools."""
+ return [entry.name for entry in get_backend_directories("tools")]
+
+
+def get_available_tools() -> List["Tool"]:
+ """Return a list with all available tools."""
+ available_tools = []
+ for config_json in get_backend_configs("tools"):
+ config_entries = cast(List[ExtendedToolConfig], load_config(config_json))
+ for config_entry in config_entries:
+ config_entry["config_location"] = config_json.parent.absolute()
+ tools = load_tools(config_entry)
+ available_tools += tools
+
+ return sorted(available_tools, key=lambda tool: tool.name)
+
+
+def get_tool(tool_name: str, system_name: Optional[str] = None) -> List["Tool"]:
+ """Return a tool instance with the same name passed as argument."""
+ return [
+ tool
+ for tool in get_available_tools()
+ if tool.name == tool_name and (not system_name or tool.can_run_on(system_name))
+ ]
+
+
+def get_unique_tool_names(system_name: Optional[str] = None) -> List[str]:
+ """Extract a list of unique tool names of all tools available."""
+ return list(
+ set(
+ tool.name
+ for tool in get_available_tools()
+ if not system_name or tool.can_run_on(system_name)
+ )
+ )
+
+
+class Tool(Backend):
+ """Class for representing a single tool component."""
+
+ def __init__(self, config: ToolConfig) -> None:
+ """Construct a Tool instance from a dict."""
+ super().__init__(config)
+
+ self.supported_systems = config.get("supported_systems", [])
+
+ if "run" not in self.commands:
+ raise ConfigurationException("A Tool must have a 'run' command.")
+
+ def __eq__(self, other: object) -> bool:
+ """Overload operator ==."""
+ if not isinstance(other, Tool):
+ return False
+
+ return (
+ super().__eq__(other)
+ and self.name == other.name
+ and set(self.supported_systems) == set(other.supported_systems)
+ )
+
+ def can_run_on(self, system_name: str) -> bool:
+ """Check if the tool can run on the system passed as argument."""
+ return system_name in self.supported_systems
+
+ def get_details(self) -> Dict[str, Any]:
+ """Return dictionary with all relevant information of the Tool instance."""
+ output = {
+ "type": "tool",
+ "name": self.name,
+ "description": self.description,
+ "supported_systems": self.supported_systems,
+ "commands": self._get_command_details(),
+ }
+
+ return output
+
+
+def load_tools(config: ExtendedToolConfig) -> List[Tool]:
+ """Load tool.
+
+ Tool configuration could contain different parameters/commands for different
+ supported systems. For each supported system this function will return separate
+ Tool instance with appropriate configuration.
+ """
+ configs = load_application_or_tool_configs(
+ config, ToolConfig, is_system_required=False
+ )
+ tools = [Tool(cfg) for cfg in configs]
+ return tools
diff --git a/src/aiet/cli/__init__.py b/src/aiet/cli/__init__.py
new file mode 100644
index 0000000..bcd17c3
--- /dev/null
+++ b/src/aiet/cli/__init__.py
@@ -0,0 +1,28 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module to mange the CLI interface."""
+import click
+
+from aiet import __version__
+from aiet.cli.application import application_cmd
+from aiet.cli.completion import completion_cmd
+from aiet.cli.system import system_cmd
+from aiet.cli.tool import tool_cmd
+from aiet.utils.helpers import set_verbosity
+
+
+# Top-level click group: handles --version and the repeatable -v/--verbose
+# flag (consumed by the set_verbosity callback, not passed to commands).
+@click.group()
+@click.version_option(__version__)
+@click.option(
+    "-v", "--verbose", default=0, count=True, callback=set_verbosity, expose_value=False
+)
+@click.pass_context
+def cli(ctx: click.Context) -> None:  # pylint: disable=unused-argument
+    """AIET: AI Evaluation Toolkit."""
+    # Unused arguments must be present here in definition to pass click context.
+
+
+# Register all sub-command groups on the top-level CLI group.
+cli.add_command(application_cmd)
+cli.add_command(system_cmd)
+cli.add_command(tool_cmd)
+cli.add_command(completion_cmd)
diff --git a/src/aiet/cli/application.py b/src/aiet/cli/application.py
new file mode 100644
index 0000000..59b652d
--- /dev/null
+++ b/src/aiet/cli/application.py
@@ -0,0 +1,362 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright (c) 2021, Gianluca Gippetto. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause
+"""Module to manage the CLI interface of applications."""
+import json
+import logging
+import re
+from pathlib import Path
+from typing import Any
+from typing import IO
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import click
+import cloup
+
+from aiet.backend.application import get_application
+from aiet.backend.application import get_available_application_directory_names
+from aiet.backend.application import get_unique_application_names
+from aiet.backend.application import install_application
+from aiet.backend.application import remove_application
+from aiet.backend.common import DataPaths
+from aiet.backend.execution import execute_application_command
+from aiet.backend.execution import run_application
+from aiet.backend.system import get_available_systems
+from aiet.cli.common import get_format
+from aiet.cli.common import middleware_exception_handler
+from aiet.cli.common import middleware_signal_handler
+from aiet.cli.common import print_command_details
+from aiet.cli.common import set_format
+
+
@click.group(name="application")
@click.option(
    "-f",
    "--format",
    "format_",
    type=click.Choice(["cli", "json"]),
    default="cli",
    show_default=True,
)
@click.pass_context
def application_cmd(ctx: click.Context, format_: str) -> None:
    """Sub command to manage applications."""
    # Persist the chosen output format ("cli" or "json") in the click context
    # so sub-commands can retrieve it via get_format().
    set_format(ctx, format_)
+
+
@application_cmd.command(name="list")
@click.pass_context
@click.option(
    "-s",
    "--system",
    "system_name",
    type=click.Choice([s.name for s in get_available_systems()]),
    required=False,
)
def list_cmd(ctx: click.Context, system_name: str) -> None:
    """List all available applications."""
    # Optionally filtered by system; printed sorted for stable output.
    app_names = sorted(get_unique_application_names(system_name))
    if get_format(ctx) == "json":
        print(json.dumps({"type": "application", "available": app_names}))
    else:
        print("Available applications:\n")
        print(*app_names, sep="\n")
+
+
@application_cmd.command(name="details")
@click.option(
    "-n",
    "--name",
    "application_name",
    type=click.Choice(get_unique_application_names()),
    required=True,
)
@click.option(
    "-s",
    "--system",
    "system_name",
    type=click.Choice([s.name for s in get_available_systems()]),
    required=False,
)
@click.pass_context
def details_cmd(ctx: click.Context, application_name: str, system_name: str) -> None:
    """Details of a specific application."""
    # There may be one Application instance per supported system for a name;
    # an empty result means the name/system combination is not supported.
    applications = get_application(application_name, system_name)
    if not applications:
        raise click.UsageError(
            "Application '{}' doesn't support the system '{}'".format(
                application_name, system_name
            )
        )

    if get_format(ctx) == "json":
        applications_details = [s.get_details() for s in applications]
        print(json.dumps(applications_details))
    else:
        for application in applications:
            application_details = application.get_details()
            application_details_template = (
                'Application "{name}" details\nDescription: {description}'
            )

            print(
                application_details_template.format(
                    name=application_details["name"],
                    description=application_details["description"],
                )
            )

            print(
                "\nSupported systems: {}".format(
                    ", ".join(application_details["supported_systems"])
                )
            )

            command_details = application_details["commands"]

            # One section per command group (e.g. build/run) with parameters.
            for command, details in command_details.items():
                print("\n{} commands:".format(command))
                print_command_details(details)
+
+
# pylint: disable=too-many-arguments
@application_cmd.command(name="execute")
@click.option(
    "-n",
    "--name",
    "application_name",
    type=click.Choice(get_unique_application_names()),
    required=True,
)
@click.option(
    "-s",
    "--system",
    "system_name",
    type=click.Choice([s.name for s in get_available_systems()]),
    required=True,
)
@click.option(
    "-c",
    "--command",
    "command_name",
    type=click.Choice(["build", "run"]),
    required=True,
)
@click.option("-p", "--param", "application_params", multiple=True)
@click.option("--system-param", "system_params", multiple=True)
@click.option("-d", "--deploy", "deploy_params", multiple=True)
@middleware_signal_handler
@middleware_exception_handler
def execute_cmd(
    application_name: str,
    system_name: str,
    command_name: str,
    application_params: List[str],
    system_params: List[str],
    deploy_params: List[str],
) -> None:
    """Execute application commands. DEPRECATED! Use 'aiet application run' instead."""
    # Kept for backwards compatibility only; steer users towards 'run'.
    logging.warning(
        "Please use 'aiet application run' instead. Use of 'aiet application "
        "execute' is deprecated and might be removed in a future release."
    )

    # Validate and convert "src:dst" deploy specifications into DataPaths.
    custom_deploy_data = get_custom_deploy_data(command_name, deploy_params)

    execute_application_command(
        command_name,
        application_name,
        application_params,
        system_name,
        system_params,
        custom_deploy_data,
    )
+
+
@cloup.command(name="run")
@cloup.option(
    "-n",
    "--name",
    "application_name",
    type=click.Choice(get_unique_application_names()),
)
@cloup.option(
    "-s",
    "--system",
    "system_name",
    type=click.Choice([s.name for s in get_available_systems()]),
)
@cloup.option("-p", "--param", "application_params", multiple=True)
@cloup.option("--system-param", "system_params", multiple=True)
@cloup.option("-d", "--deploy", "deploy_params", multiple=True)
@click.option(
    "-r",
    "--report",
    "report_file",
    type=Path,
    help="Create a report file in JSON format containing metrics parsed from "
    "the simulation output as specified in the aiet-config.json.",
)
@cloup.option(
    "--config",
    "config_file",
    type=click.File("r"),
    help="Read options from a config file rather than from the command line. "
    "The config file is a json file.",
)
# Without --config, both --system and --name become mandatory.
@cloup.constraint(
    cloup.constraints.If(
        cloup.constraints.conditions.Not(
            cloup.constraints.conditions.IsSet("config_file")
        ),
        then=cloup.constraints.require_all,
    ),
    ["system_name", "application_name"],
)
# With --config, the equivalent command-line options must not be given.
@cloup.constraint(
    cloup.constraints.If("config_file", then=cloup.constraints.accept_none),
    [
        "system_name",
        "application_name",
        "application_params",
        "system_params",
        "deploy_params",
    ],
)
@middleware_signal_handler
@middleware_exception_handler
def run_cmd(
    application_name: str,
    system_name: str,
    application_params: List[str],
    system_params: List[str],
    deploy_params: List[str],
    report_file: Optional[Path],
    config_file: Optional[IO[str]],
) -> None:
    """Execute application commands."""
    # When a config file is supplied, all settings come from its JSON payload
    # and override the (constraint-enforced empty) command-line options.
    if config_file:
        payload_data = json.load(config_file)
        (
            system_name,
            application_name,
            application_params,
            system_params,
            deploy_params,
            report_file,
        ) = parse_payload_run_config(payload_data)

    # Validate and convert "src:dst" deploy specifications into DataPaths.
    custom_deploy_data = get_custom_deploy_data("run", deploy_params)

    run_application(
        application_name,
        application_params,
        system_name,
        system_params,
        custom_deploy_data,
        report_file,
    )


# "run" is built with cloup (for the option constraints), so it is registered
# explicitly rather than via the @application_cmd.command decorator.
application_cmd.add_command(run_cmd)
+
+
def parse_payload_run_config(
    payload_data: dict,
) -> Tuple[str, str, List[str], List[str], List[str], Optional[Path]]:
    """Parse the payload into a tuple."""
    system_id = payload_data.get("id")
    arguments: Optional[Any] = payload_data.get("arguments")

    if not isinstance(system_id, str):
        raise click.ClickException("invalid payload json: no system 'id'")
    if not isinstance(arguments, dict):
        raise click.ClickException("invalid payload json: no arguments object")

    application_name = arguments.pop("application", None)
    if not isinstance(application_name, str):
        raise click.ClickException("invalid payload json: no application_id")

    report_path = arguments.pop("report_path", None)

    # Remaining keys are prefixed parameters; route each to a bucket by
    # stripping the recognised prefix.
    application_params: List[str] = []
    system_params: List[str] = []
    deploy_params: List[str] = []

    for param_key, value in arguments.items():
        stripped, _ = re.subn("^application/", "", param_key)
        stripped, is_system = re.subn("^system/", "", stripped)
        stripped, is_deploy = re.subn("^deploy/", "", stripped)

        if is_system:
            system_params.append(stripped + "=" + value)
        elif is_deploy:
            # NOTE(review): deploy entries keep only the key; the value is
            # dropped -- confirm this matches the expected "src:dst" format
            # that get_custom_deploy_data later parses.
            deploy_params.append(stripped)
        else:
            application_params.append(stripped + "=" + value)

    return (
        system_id,
        application_name,
        application_params,
        system_params,
        deploy_params,
        report_path,
    )
+
+
def get_custom_deploy_data(
    command_name: str, deploy_params: List[str]
) -> List[DataPaths]:
    """Get custom deploy data information."""
    # Each parameter must have the exact form "src:dst" with non-blank parts.
    custom_deploy_data: List[DataPaths] = []

    for param in deploy_params:
        parts = param.split(":")
        well_formed = len(parts) == 2 and all(part.strip() for part in parts)
        if not well_formed:
            raise click.ClickException(
                "Invalid deploy parameter '{}' for command {}".format(
                    param, command_name
                )
            )
        data_path = DataPaths(Path(parts[0]), parts[1])
        # The source must exist locally before it can be deployed.
        if not data_path.src.exists():
            raise click.ClickException("Path {} does not exist".format(data_path.src))
        custom_deploy_data.append(data_path)

    return custom_deploy_data
+
+
@application_cmd.command(name="install")
@click.option(
    "-s",
    "--source",
    "source",
    required=True,
    help="Path to the directory or archive with application definition",
)
def install_cmd(source: str) -> None:
    """Install new application."""
    # Delegate to the backend; it handles both directories and archives.
    install_application(Path(source))
+
+
@application_cmd.command(name="remove")
@click.option(
    "-d",
    "--directory_name",
    "directory_name",
    type=click.Choice(get_available_application_directory_names()),
    required=True,
    help="Name of the directory with application",
)
def remove_cmd(directory_name: str) -> None:
    """Remove application."""
    # Removal is keyed on the resource directory name, not the application name.
    remove_application(directory_name)
diff --git a/src/aiet/cli/common.py b/src/aiet/cli/common.py
new file mode 100644
index 0000000..1d157b6
--- /dev/null
+++ b/src/aiet/cli/common.py
@@ -0,0 +1,173 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Common functions for cli module."""
+import enum
+import logging
+from functools import wraps
+from signal import SIG_IGN
+from signal import SIGINT
+from signal import signal as signal_handler
+from signal import SIGTERM
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+
+from click import ClickException
+from click import Context
+from click import UsageError
+
+from aiet.backend.common import ConfigurationException
+from aiet.backend.execution import AnotherInstanceIsRunningException
+from aiet.backend.execution import ConnectionException
+from aiet.backend.protocol import SSHConnectionException
+from aiet.utils.proc import CommandFailedException
+
+
class MiddlewareExitCode(enum.IntEnum):
    """Middleware exit codes."""

    # Process exit codes surfaced to the shell by the exception classes below.
    SUCCESS = 0
    # exit codes 1 and 2 are used by click
    SHUTDOWN_REQUESTED = 3
    BACKEND_ERROR = 4
    CONCURRENT_ERROR = 5
    CONNECTION_ERROR = 6
    CONFIGURATION_ERROR = 7
    MODEL_OPTIMISED_ERROR = 8
    INVALID_TFLITE_FILE_ERROR = 9
+
+
class CustomClickException(ClickException):
    """Custom click exception.

    Base class for middleware errors: shows the message like a normal click
    exception and additionally records the traceback at debug level.
    """

    def show(self, file: Any = None) -> None:
        """Override show method."""
        super().show(file)

        # Keep the full traceback available for troubleshooting (-v).
        logging.debug("Execution failed with following exception: ", exc_info=self)
+
+
class MiddlewareShutdownException(CustomClickException):
    """Exception indicates that user requested middleware shutdown."""

    # Raised from the SIGINT/SIGTERM handler; surfaces as exit code 3.
    exit_code = int(MiddlewareExitCode.SHUTDOWN_REQUESTED)
+
+
class BackendException(CustomClickException):
    """Exception indicates that command failed."""

    # Surfaces as exit code 4.
    exit_code = int(MiddlewareExitCode.BACKEND_ERROR)
+
+
class ConcurrentErrorException(CustomClickException):
    """Exception indicates concurrent execution error."""

    # Surfaces as exit code 5.
    exit_code = int(MiddlewareExitCode.CONCURRENT_ERROR)
+
+
class BackendConnectionException(CustomClickException):
    """Exception indicates that connection could not be established."""

    # Surfaces as exit code 6.
    exit_code = int(MiddlewareExitCode.CONNECTION_ERROR)
+
+
class BackendConfigurationException(CustomClickException):
    """Exception indicates some configuration issue."""

    # Surfaces as exit code 7.
    exit_code = int(MiddlewareExitCode.CONFIGURATION_ERROR)
+
+
class ModelOptimisedException(CustomClickException):
    """Exception indicates input file has previously been Vela optimised."""

    # Surfaces as exit code 8.
    exit_code = int(MiddlewareExitCode.MODEL_OPTIMISED_ERROR)
+
+
class InvalidTFLiteFileError(CustomClickException):
    """Exception indicates input TFLite file is misformatted."""

    # Surfaces as exit code 9.
    exit_code = int(MiddlewareExitCode.INVALID_TFLITE_FILE_ERROR)
+
+
def print_command_details(command: Dict) -> None:
    """Print command details including parameters."""
    # Header with the raw command strings, then one numbered section per
    # user parameter; optional fields fall back to "-".
    print("Commands: {}".format(command["command_strings"]))
    for index, param in enumerate(command["user_params"], 1):
        print("User parameter #{}".format(index))
        print("\tName: {}".format(param.get("name", "-")))
        print("\tDescription: {}".format(param["description"]))
        print("\tPossible values: {}".format(param.get("values", "-")))
        print("\tDefault value: {}".format(param.get("default_value", "-")))
        print("\tAlias: {}".format(param.get("alias", "-")))
+
+
def raise_exception_at_signal(
    signum: int, frame: Any  # pylint: disable=unused-argument
) -> None:
    """Handle signals."""
    # Ignore any further SIGINT/SIGTERM so the graceful shutdown can finish;
    # both handler arguments are required by the signal API but unused here.
    for sig in (SIGINT, SIGTERM):
        signal_handler(sig, SIG_IGN)
    raise MiddlewareShutdownException("Middleware shutdown requested")
+
+
def middleware_exception_handler(func: Callable) -> Callable:
    """Handle backend exceptions decorator.

    Translates exceptions raised by the wrapped command into click
    exceptions carrying dedicated middleware exit codes.
    """

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        try:
            return func(*args, **kwargs)
        except (MiddlewareShutdownException, UsageError, ClickException) as error:
            # click should take care of these exceptions
            raise error
        except ValueError as error:
            raise ClickException(str(error)) from error
        except AnotherInstanceIsRunningException as error:
            raise ConcurrentErrorException(
                "Another instance of the system is running"
            ) from error
        except (SSHConnectionException, ConnectionException) as error:
            raise BackendConnectionException(str(error)) from error
        except ConfigurationException as error:
            raise BackendConfigurationException(str(error)) from error
        except (CommandFailedException, Exception) as error:
            # Catch-all boundary: anything not matched above is reported as a
            # generic backend failure. NOTE(review): naming
            # CommandFailedException here is redundant, as the bare Exception
            # already catches it.
            raise BackendException(
                "Execution failed. Please check output for the details."
            ) from error

    return wrapper
+
+
def middleware_signal_handler(func: Callable) -> Callable:
    """Handle signals decorator."""

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Install handlers for SIGINT (ctrl-c) and SIGTERM (kill command);
        # the handler ignores further signals and raises a shutdown exception.
        for sig in (SIGINT, SIGTERM):
            signal_handler(sig, raise_exception_at_signal)

        return func(*args, **kwargs)

    return wrapper
+
+
def set_format(ctx: Context, format_: str) -> None:
    """Save format in click context."""
    # ensure_object creates the context dict on first use.
    ctx.ensure_object(dict)["format"] = format_
+
+
def get_format(ctx: Context) -> str:
    """Get format from click context."""
    format_value = ctx.ensure_object(dict)["format"]
    return cast(str, format_value)
diff --git a/src/aiet/cli/completion.py b/src/aiet/cli/completion.py
new file mode 100644
index 0000000..71f054f
--- /dev/null
+++ b/src/aiet/cli/completion.py
@@ -0,0 +1,72 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""
+Add auto completion to different shells with these helpers.
+
+See: https://click.palletsprojects.com/en/8.0.x/shell-completion/
+"""
+import click
+
+
def _get_package_name() -> str:
    """Return the top-level package name, used in the completion env vars."""
    return __name__.partition(".")[0]
+
+
+# aiet completion bash
@click.group(name="completion")
def completion_cmd() -> None:
    """Enable auto completion for your shell."""
    # Group only hosts the per-shell sub-commands (bash/zsh/fish).
+
+
@completion_cmd.command(name="bash")
def bash_cmd() -> None:
    """
    Enable auto completion for bash.

    Use this command to activate completion in the current bash:

    eval "`aiet completion bash`"

    Use this command to add auto completion to bash globally, if you have aiet
    installed globally (requires starting a new shell afterwards):

    aiet completion bash >> ~/.bashrc
    """
    # Prints the click shell-completion bootstrap line for bash
    # (see the module docstring link for the _<PKG>_COMPLETE mechanism).
    package_name = _get_package_name()
    print(f'eval "$(_{package_name.upper()}_COMPLETE=bash_source {package_name})"')
+
+
@completion_cmd.command(name="zsh")
def zsh_cmd() -> None:
    """
    Enable auto completion for zsh.

    Use this command to activate completion in the current zsh:

    eval "`aiet completion zsh`"

    Use this command to add auto completion to zsh globally, if you have aiet
    installed globally (requires starting a new shell afterwards):

    aiet completion zsh >> ~/.zshrc
    """
    # Prints the click shell-completion bootstrap line for zsh.
    package_name = _get_package_name()
    print(f'eval "$(_{package_name.upper()}_COMPLETE=zsh_source {package_name})"')
+
+
@completion_cmd.command(name="fish")
def fish_cmd() -> None:
    """
    Enable auto completion for fish.

    Use this command to activate completion in the current fish:

    eval "`aiet completion fish`"

    Use this command to add auto completion to fish globally, if you have aiet
    installed globally (requires starting a new shell afterwards):

    aiet completion fish >> ~/.config/fish/completions/aiet.fish
    """
    # Prints the click shell-completion bootstrap line; fish uses its own
    # "(env ...)" substitution syntax instead of "$(...)".
    package_name = _get_package_name()
    print(f'eval "(env _{package_name.upper()}_COMPLETE=fish_source {package_name})"')
diff --git a/src/aiet/cli/system.py b/src/aiet/cli/system.py
new file mode 100644
index 0000000..f1f7637
--- /dev/null
+++ b/src/aiet/cli/system.py
@@ -0,0 +1,122 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module to manage the CLI interface of systems."""
+import json
+from pathlib import Path
+from typing import cast
+
+import click
+
+from aiet.backend.application import get_available_applications
+from aiet.backend.system import get_available_systems
+from aiet.backend.system import get_available_systems_directory_names
+from aiet.backend.system import get_system
+from aiet.backend.system import install_system
+from aiet.backend.system import remove_system
+from aiet.backend.system import System
+from aiet.cli.common import get_format
+from aiet.cli.common import print_command_details
+from aiet.cli.common import set_format
+
+
@click.group(name="system")
@click.option(
    "-f",
    "--format",
    "format_",
    type=click.Choice(["cli", "json"]),
    default="cli",
    show_default=True,
)
@click.pass_context
def system_cmd(ctx: click.Context, format_: str) -> None:
    """Sub command to manage systems."""
    # Persist the chosen output format in the click context for sub-commands.
    set_format(ctx, format_)
+
+
@system_cmd.command(name="list")
@click.pass_context
def list_cmd(ctx: click.Context) -> None:
    """List all available systems."""
    system_names = [system.name for system in get_available_systems()]
    if get_format(ctx) == "json":
        print(json.dumps({"type": "system", "available": system_names}))
    else:
        print("Available systems:\n")
        print(*system_names, sep="\n")
+
+
@system_cmd.command(name="details")
@click.option(
    "-n",
    "--name",
    "system_name",
    type=click.Choice([s.name for s in get_available_systems()]),
    required=True,
)
@click.pass_context
def details_cmd(ctx: click.Context, system_name: str) -> None:
    """Details of a specific system."""
    system = cast(System, get_system(system_name))
    # Applications that declare support for this system.
    applications = [
        s.name for s in get_available_applications() if s.can_run_on(system.name)
    ]
    system_details = system.get_details()
    if get_format(ctx) == "json":
        system_details["available_application"] = applications
        print(json.dumps(system_details))
    else:
        system_details_template = (
            'System "{name}" details\n'
            "Description: {description}\n"
            "Data Transfer Protocol: {protocol}\n"
            "Available Applications: {available_application}"
        )
        print(
            system_details_template.format(
                name=system_details["name"],
                description=system_details["description"],
                protocol=system_details["data_transfer_protocol"],
                available_application=", ".join(applications),
            )
        )

        # Optional free-form annotations from the system configuration.
        if system_details["annotations"]:
            print("Annotations:")
            for ann_name, ann_value in system_details["annotations"].items():
                print("\t{}: {}".format(ann_name, ann_value))

        command_details = system_details["commands"]
        for command, details in command_details.items():
            print("\n{} commands:".format(command))
            print_command_details(details)
+
+
@system_cmd.command(name="install")
@click.option(
    "-s",
    "--source",
    "source",
    required=True,
    help="Path to the directory or archive with system definition",
)
def install_cmd(source: str) -> None:
    """Install new system."""
    # Delegate to the backend; it handles both directories and archives.
    install_system(Path(source))
+
+
@system_cmd.command(name="remove")
@click.option(
    "-d",
    "--directory_name",
    "directory_name",
    type=click.Choice(get_available_systems_directory_names()),
    required=True,
    help="Name of the directory with system",
)
def remove_cmd(directory_name: str) -> None:
    """Remove system by given name."""
    # Removal is keyed on the resource directory name, not the system name.
    remove_system(directory_name)
diff --git a/src/aiet/cli/tool.py b/src/aiet/cli/tool.py
new file mode 100644
index 0000000..2c80821
--- /dev/null
+++ b/src/aiet/cli/tool.py
@@ -0,0 +1,143 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module to manage the CLI interface of tools."""
+import json
+from typing import Any
+from typing import List
+from typing import Optional
+
+import click
+
+from aiet.backend.execution import execute_tool_command
+from aiet.backend.tool import get_tool
+from aiet.backend.tool import get_unique_tool_names
+from aiet.cli.common import get_format
+from aiet.cli.common import middleware_exception_handler
+from aiet.cli.common import middleware_signal_handler
+from aiet.cli.common import print_command_details
+from aiet.cli.common import set_format
+
+
@click.group(name="tool")
@click.option(
    "-f",
    "--format",
    "format_",
    type=click.Choice(["cli", "json"]),
    default="cli",
    show_default=True,
)
@click.pass_context
def tool_cmd(ctx: click.Context, format_: str) -> None:
    """Sub command to manage tools."""
    # Persist the chosen output format in the click context for sub-commands.
    set_format(ctx, format_)
+
+
@tool_cmd.command(name="list")
@click.pass_context
def list_cmd(ctx: click.Context) -> None:
    """List all available tools."""
    # Removed dead commented-out code (raise NotImplementedError) left from
    # early development; the command is fully implemented.
    # Sorted for stable, reproducible output.
    tool_names = sorted(get_unique_tool_names())
    if get_format(ctx) == "json":
        print(json.dumps({"type": "tool", "available": tool_names}))
    else:
        print("Available tools:\n")
        print(*tool_names, sep="\n")
+
+
def validate_system(
    ctx: click.Context,
    _: click.Parameter,  # param is not used
    value: Any,
) -> Any:
    """Validate provided system name depending on the tool name.

    Click option callback: raises BadParameter if no tool variant of the
    already-parsed "tool_name" supports the given system; returns the value
    unchanged otherwise.
    """
    # Relies on "tool_name" having been parsed before "--system" on the
    # commands that use this callback.
    tool_name = ctx.params["tool_name"]
    tools = get_tool(tool_name, value)
    if not tools:
        # Build the error message from the first supported system of each
        # variant of the tool.
        supported_systems = [tool.supported_systems[0] for tool in get_tool(tool_name)]
        raise click.BadParameter(
            message="'{}' is not one of {}.".format(
                value,
                ", ".join("'{}'".format(system) for system in supported_systems),
            ),
            ctx=ctx,
        )
    return value
+
+
@tool_cmd.command(name="details")
@click.option(
    "-n",
    "--name",
    "tool_name",
    type=click.Choice(get_unique_tool_names()),
    required=True,
)
@click.option(
    "-s",
    "--system",
    "system_name",
    callback=validate_system,
    required=False,
)
@click.pass_context
@middleware_signal_handler
@middleware_exception_handler
def details_cmd(ctx: click.Context, tool_name: str, system_name: Optional[str]) -> None:
    """Details of a specific tool."""
    # A tool name may resolve to several Tool instances (one per supported
    # system); without --system all of them are shown.
    tools = get_tool(tool_name, system_name)
    if get_format(ctx) == "json":
        tools_details = [s.get_details() for s in tools]
        print(json.dumps(tools_details))
    else:
        for tool in tools:
            tool_details = tool.get_details()
            tool_details_template = 'Tool "{name}" details\nDescription: {description}'

            print(
                tool_details_template.format(
                    name=tool_details["name"],
                    description=tool_details["description"],
                )
            )

            print(
                "\nSupported systems: {}".format(
                    ", ".join(tool_details["supported_systems"])
                )
            )

            command_details = tool_details["commands"]

            # One section per command group with its user parameters.
            for command, details in command_details.items():
                print("\n{} commands:".format(command))
                print_command_details(details)
+
+
# pylint: disable=too-many-arguments
@tool_cmd.command(name="execute")
@click.option(
    "-n",
    "--name",
    "tool_name",
    type=click.Choice(get_unique_tool_names()),
    required=True,
)
@click.option("-p", "--param", "tool_params", multiple=True)
@click.option(
    "-s",
    "--system",
    "system_name",
    callback=validate_system,  # rejects systems the tool does not support
    required=False,
)
@middleware_signal_handler
@middleware_exception_handler
def execute_cmd(
    tool_name: str, tool_params: List[str], system_name: Optional[str]
) -> None:
    """Execute tool commands."""
    # Parameter resolution and execution are delegated to the backend.
    execute_tool_command(tool_name, tool_params, system_name)
diff --git a/src/aiet/main.py b/src/aiet/main.py
new file mode 100644
index 0000000..6898ad9
--- /dev/null
+++ b/src/aiet/main.py
@@ -0,0 +1,13 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Entry point module of AIET."""
+from aiet.cli import cli
+
+
def main() -> None:
    """Entry point of aiet application."""
    # Invoke the root click group; click reads the CLI arguments itself.
    cli()  # pylint: disable=no-value-for-parameter


if __name__ == "__main__":
    main()
diff --git a/src/aiet/resources/applications/.gitignore b/src/aiet/resources/applications/.gitignore
new file mode 100644
index 0000000..0226166
--- /dev/null
+++ b/src/aiet/resources/applications/.gitignore
@@ -0,0 +1,6 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# Ignore everything in this directory
+*
+# Except this file
+!.gitignore
diff --git a/src/aiet/resources/systems/.gitignore b/src/aiet/resources/systems/.gitignore
new file mode 100644
index 0000000..0226166
--- /dev/null
+++ b/src/aiet/resources/systems/.gitignore
@@ -0,0 +1,6 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# Ignore everything in this directory
+*
+# Except this file
+!.gitignore
diff --git a/src/aiet/resources/tools/vela/aiet-config.json b/src/aiet/resources/tools/vela/aiet-config.json
new file mode 100644
index 0000000..c12f291
--- /dev/null
+++ b/src/aiet/resources/tools/vela/aiet-config.json
@@ -0,0 +1,73 @@
+[
+ {
+ "name": "vela",
+ "description": "Neural network model compiler for Arm Ethos-U NPUs",
+ "supported_systems": [
+ {
+ "name": "Corstone-300: Cortex-M55+Ethos-U55"
+ },
+ {
+ "name": "Corstone-310: Cortex-M85+Ethos-U55"
+ },
+ {
+ "name": "Corstone-300: Cortex-M55+Ethos-U65",
+ "variables": {
+ "accelerator_config_prefix": "ethos-u65",
+ "system_config": "Ethos_U65_High_End",
+ "shared_sram": "U65_Shared_Sram"
+ },
+ "user_params": {
+ "run": [
+ {
+ "description": "MACs per cycle",
+ "values": [
+ "256",
+ "512"
+ ],
+ "default_value": "512",
+ "alias": "mac"
+ }
+ ]
+ }
+ }
+ ],
+ "variables": {
+ "accelerator_config_prefix": "ethos-u55",
+ "system_config": "Ethos_U55_High_End_Embedded",
+ "shared_sram": "U55_Shared_Sram"
+ },
+ "commands": {
+ "run": [
+ "run_vela {user_params:input} {user_params:output} --config {tool.config_dir}/vela.ini --accelerator-config {variables:accelerator_config_prefix}-{user_params:mac} --system-config {variables:system_config} --memory-mode {variables:shared_sram} --optimise Performance"
+ ]
+ },
+ "user_params": {
+ "run": [
+ {
+ "description": "MACs per cycle",
+ "values": [
+ "32",
+ "64",
+ "128",
+ "256"
+ ],
+ "default_value": "128",
+ "alias": "mac"
+ },
+ {
+ "name": "--input-model",
+ "description": "Path to the TFLite model",
+ "values": [],
+ "alias": "input"
+ },
+ {
+ "name": "--output-model",
+ "description": "Path to the output model file of the vela-optimisation step. The vela output is saved in the parent directory.",
+ "values": [],
+ "default_value": "output_model.tflite",
+ "alias": "output"
+ }
+ ]
+ }
+ }
+]
diff --git a/src/aiet/resources/tools/vela/aiet-config.json.license b/src/aiet/resources/tools/vela/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/src/aiet/resources/tools/vela/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/src/aiet/resources/tools/vela/check_model.py b/src/aiet/resources/tools/vela/check_model.py
new file mode 100644
index 0000000..7c700b1
--- /dev/null
+++ b/src/aiet/resources/tools/vela/check_model.py
@@ -0,0 +1,75 @@
+# SPDX-FileCopyrightText: Copyright 2020, 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Check if a TFLite model file is Vela-optimised."""
+import struct
+from pathlib import Path
+
+from ethosu.vela.tflite.Model import Model
+
+from aiet.cli.common import InvalidTFLiteFileError
+from aiet.cli.common import ModelOptimisedException
+from aiet.utils.fs import read_file_as_bytearray
+
+
def get_model_from_file(input_model_file: Path) -> Model:
    """Generate Model instance from TFLite file using flatc generated code.

    Raises InvalidTFLiteFileError if the file cannot be parsed as a
    TFLite flatbuffer.
    """
    buffer = read_file_as_bytearray(input_model_file)
    try:
        model = Model.GetRootAsModel(buffer, 0)
    except (TypeError, RuntimeError, struct.error) as tflite_error:
        # A malformed flatbuffer surfaces as one of these low-level errors;
        # re-raise as the middleware-specific exception.
        raise InvalidTFLiteFileError(
            f"Error reading in model from {input_model_file}."
        ) from tflite_error
    return model
+
+
def is_vela_optimised(tflite_model: Model) -> bool:
    """Return True if 'ethos-u' custom operator found in the Model."""
    # Collect the model's operator codes, keep the custom ones, and look
    # for the Ethos-U marker among them.
    custom_codes = get_custom_codes_from_operators(
        get_operators_from_model(tflite_model)
    )
    return check_custom_codes_for_ethosu(custom_codes)
+
+
def get_operators_from_model(tflite_model: Model) -> list:
    """Return list of the unique operator codes used in the Model."""
    # Flatc-generated accessors: length first, then indexed lookup.
    operator_count = tflite_model.OperatorCodesLength()
    return [tflite_model.OperatorCodes(index) for index in range(operator_count)]
+
+
def get_custom_codes_from_operators(operators: list) -> list:
    """Return list of each operator's CustomCode() strings, if they exist."""
    codes = []
    for operator in operators:
        custom_code = operator.CustomCode()
        if custom_code is not None:
            codes.append(custom_code)
    return codes
+
+
def check_custom_codes_for_ethosu(custom_codes: list) -> bool:
    """Check for existence of ethos-u string in the custom codes."""
    # Codes are byte strings from the flatbuffer; decode before comparing.
    for custom_code_name in custom_codes:
        if custom_code_name.decode("utf-8") == "ethos-u":
            return True
    return False
+
+
def check_model(tflite_file_name: str) -> None:
    """Raise an exception if model in given file is Vela optimised.

    Raises ModelOptimisedException when the 'ethos-u' custom op is present;
    otherwise prints a confirmation message.
    """
    tflite_path = Path(tflite_file_name)

    # Parse the flatbuffer; raises InvalidTFLiteFileError on malformed input.
    tflite_model = get_model_from_file(tflite_path)

    if is_vela_optimised(tflite_model):
        raise ModelOptimisedException(
            f"TFLite model in {tflite_file_name} is already "
            f"vela optimised ('ethos-u' custom op detected)."
        )

    print(
        f"TFLite model in {tflite_file_name} is not vela optimised "
        f"('ethos-u' custom op not detected)."
    )
diff --git a/src/aiet/resources/tools/vela/run_vela.py b/src/aiet/resources/tools/vela/run_vela.py
new file mode 100644
index 0000000..2c1b0be
--- /dev/null
+++ b/src/aiet/resources/tools/vela/run_vela.py
@@ -0,0 +1,65 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Wrapper to only run Vela when the input is not already optimised."""
+import shutil
+import subprocess
+from pathlib import Path
+from typing import Tuple
+
+import click
+
+from aiet.cli.common import ModelOptimisedException
+from aiet.resources.tools.vela.check_model import check_model
+
+
def vela_output_model_path(input_model: str, output_dir: str) -> Path:
    """Construct the path to the Vela output file."""
    # Vela writes '<stem>_vela<suffix>' into the output directory.
    source = Path(input_model)
    return Path(output_dir) / (source.stem + "_vela" + source.suffix)
+
+
def execute_vela(vela_args: Tuple, output_dir: Path, input_model: str) -> None:
    """Execute vela as external call."""
    # Forward the collected args, re-adding the parsed output dir and the
    # model file last; raise CalledProcessError on non-zero exit.
    cmd = ["vela", *vela_args, "--output-dir", str(output_dir), input_model]
    subprocess.run(cmd, check=True)
+
+
@click.command(context_settings=dict(ignore_unknown_options=True))
@click.option(
    "--input-model",
    "-i",
    type=click.Path(exists=True, file_okay=True, readable=True),
    required=True,
)
@click.option("--output-model", "-o", type=click.Path(), required=True)
# Collect the remaining arguments to be directly forwarded to Vela
@click.argument("vela-args", nargs=-1, type=click.UNPROCESSED)
def run_vela(input_model: str, output_model: str, vela_args: Tuple) -> None:
    """Check input, run Vela (if needed) and copy optimised file to destination."""
    # Vela is pointed at the destination's directory; the optimised file is
    # copied onto the exact destination path afterwards.
    output_dir = Path(output_model).parent
    try:
        check_model(input_model)  # raises an exception if already Vela-optimised
        execute_vela(vela_args, output_dir, input_model)
        print("Vela optimisation complete.")
        # Vela names its output '<stem>_vela<suffix>' inside output_dir.
        src_model = vela_output_model_path(input_model, str(output_dir))
    except ModelOptimisedException as ex:
        # Input already optimized: copy input file to destination path and return
        print(f"Input already vela-optimised.\n{ex}")
        src_model = Path(input_model)
    except subprocess.CalledProcessError as ex:
        # Vela itself failed: surface its exit code as our own.
        print(ex)
        raise SystemExit(ex.returncode) from ex

    try:
        # Place the optimised (or already-optimised input) model at the
        # requested destination.
        shutil.copyfile(src_model, output_model)
    except (shutil.SameFileError, OSError) as ex:
        print(ex)
        raise SystemExit(ex.errno) from ex
+
+
def main() -> None:
    """Entry point of the run_vela application."""
    # Click supplies the CLI parameters, hence no explicit arguments here.
    run_vela()  # pylint: disable=no-value-for-parameter
diff --git a/src/aiet/resources/tools/vela/vela.ini b/src/aiet/resources/tools/vela/vela.ini
new file mode 100644
index 0000000..5996553
--- /dev/null
+++ b/src/aiet/resources/tools/vela/vela.ini
@@ -0,0 +1,53 @@
+; SPDX-FileCopyrightText: Copyright 2021-2022, Arm Limited and/or its affiliates.
+; SPDX-License-Identifier: Apache-2.0
+
+; -----------------------------------------------------------------------------
+; Vela configuration file
+
+; -----------------------------------------------------------------------------
+; System Configuration
+
+; Ethos-U55 High-End Embedded: SRAM (4 GB/s) and Flash (0.5 GB/s)
+[System_Config.Ethos_U55_High_End_Embedded]
+core_clock=500e6
+axi0_port=Sram
+axi1_port=OffChipFlash
+Sram_clock_scale=1.0
+Sram_burst_length=32
+Sram_read_latency=32
+Sram_write_latency=32
+OffChipFlash_clock_scale=0.125
+OffChipFlash_burst_length=128
+OffChipFlash_read_latency=64
+OffChipFlash_write_latency=64
+
+; Ethos-U65 High-End: SRAM (16 GB/s) and DRAM (3.75 GB/s)
+[System_Config.Ethos_U65_High_End]
+core_clock=1e9
+axi0_port=Sram
+axi1_port=Dram
+Sram_clock_scale=1.0
+Sram_burst_length=32
+Sram_read_latency=32
+Sram_write_latency=32
+Dram_clock_scale=0.234375
+Dram_burst_length=128
+Dram_read_latency=500
+Dram_write_latency=250
+
+; -----------------------------------------------------------------------------
+; Memory Mode
+
+; Shared SRAM: the SRAM is shared between the Ethos-U and the Cortex-M software
+; The non-SRAM memory is assumed to be read-only
+[Memory_Mode.U55_Shared_Sram]
+const_mem_area=Axi1
+arena_mem_area=Axi0
+cache_mem_area=Axi0
+arena_cache_size=4194304
+
+[Memory_Mode.U65_Shared_Sram]
+const_mem_area=Axi1
+arena_mem_area=Axi0
+cache_mem_area=Axi0
+arena_cache_size=2097152
diff --git a/src/aiet/utils/__init__.py b/src/aiet/utils/__init__.py
new file mode 100644
index 0000000..fc7ef7c
--- /dev/null
+++ b/src/aiet/utils/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""This module contains all utils shared across aiet project."""
diff --git a/src/aiet/utils/fs.py b/src/aiet/utils/fs.py
new file mode 100644
index 0000000..ea99a69
--- /dev/null
+++ b/src/aiet/utils/fs.py
@@ -0,0 +1,116 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module to host all file system related functions."""
+import importlib.resources as pkg_resources
+import re
+import shutil
+from pathlib import Path
+from typing import Any
+from typing import Literal
+from typing import Optional
+
+ResourceType = Literal["applications", "systems", "tools"]
+
+
def get_aiet_resources() -> Path:
    """Get resources folder path."""
    # Locate the installed 'aiet' package and resolve 'resources' next to it.
    with pkg_resources.path("aiet", "__init__.py") as init_path:
        return init_path.parent / "resources"
+
+
def get_resources(name: ResourceType) -> Path:
    """Return the absolute path of the specified resource.

    It uses importlib to return resources packaged with MANIFEST.in.
    """
    if not name:
        raise ResourceWarning("Resource name is not provided")

    resource_path = get_aiet_resources() / name
    if not resource_path.is_dir():
        raise ResourceWarning("Resource '{}' not found.".format(name))

    return resource_path
+
+
def copy_directory_content(source: Path, destination: Path) -> None:
    """Copy content of the source directory into destination directory."""
    for entry in source.iterdir():
        target = destination / entry.name

        # Directories are copied recursively; files with their metadata.
        if entry.is_dir():
            shutil.copytree(entry, target)
        else:
            shutil.copy2(entry, target)
+
+
def remove_resource(resource_directory: str, resource_type: ResourceType) -> None:
    """Remove resource data."""
    resource_location = get_resources(resource_type) / resource_directory

    if not resource_location.exists():
        raise Exception("Resource {} does not exist".format(resource_directory))

    if not resource_location.is_dir():
        raise Exception("Wrong resource {}".format(resource_directory))

    shutil.rmtree(resource_location)
+
+
def remove_directory(directory_path: Optional[Path]) -> None:
    """Remove directory."""
    # Reject a missing path or anything that is not an existing directory.
    if not (directory_path and directory_path.is_dir()):
        raise Exception("No directory path provided")

    shutil.rmtree(directory_path)
+
+
def recreate_directory(directory_path: Optional[Path]) -> None:
    """Recreate directory."""
    if not directory_path:
        raise Exception("No directory path provided")

    if directory_path.is_dir():
        # Existing directory: wipe it so the caller starts from a clean state.
        remove_directory(directory_path)
    elif directory_path.exists():
        raise Exception(
            "Path {} does exist and it is not a directory".format(str(directory_path))
        )

    directory_path.mkdir()
+
+
def read_file(file_path: Path, mode: Optional[str] = None) -> Any:
    """Read file as string or bytearray."""
    if not file_path.is_file():
        # Missing file: return an empty value matching the requested mode.
        return b"" if mode == "rb" else ""

    if mode is None:
        with open(file_path, encoding="utf-8") as file:
            return file.read()

    # An explicit mode may be binary ('rb'), which is incompatible with
    # specifying an encoding.
    with open(file_path, mode) as file:  # pylint: disable=unspecified-encoding
        return file.read()
+
+
def read_file_as_string(file_path: Path) -> str:
    """Read file as string."""
    content = read_file(file_path)
    return str(content)
+
+
def read_file_as_bytearray(file_path: Path) -> bytearray:
    """Read a file as bytearray."""
    raw = read_file(file_path, mode="rb")
    return bytearray(raw)
+
+
def valid_for_filename(value: str, replacement: str = "") -> str:
    """Replace non alpha numeric characters."""
    # ASCII word characters and '.' are kept; everything else is replaced.
    invalid_chars = re.compile(r"[^\w.]", flags=re.ASCII)
    return invalid_chars.sub(replacement, value)
diff --git a/src/aiet/utils/helpers.py b/src/aiet/utils/helpers.py
new file mode 100644
index 0000000..6d3cd22
--- /dev/null
+++ b/src/aiet/utils/helpers.py
@@ -0,0 +1,17 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Helpers functions."""
+import logging
+from typing import Any
+
+
def set_verbosity(
    ctx: Any, option: Any, verbosity: Any  # pylint: disable=unused-argument
) -> None:
    """Set the logging level according to the verbosity."""
    # Unused arguments must be present here in definition as these are required in
    # function definition when set as a callback
    if verbosity >= 1:
        # One '-v' selects INFO, more select DEBUG; zero leaves the level alone.
        level = logging.INFO if verbosity == 1 else logging.DEBUG
        logging.getLogger().setLevel(level)
diff --git a/src/aiet/utils/proc.py b/src/aiet/utils/proc.py
new file mode 100644
index 0000000..b6f4357
--- /dev/null
+++ b/src/aiet/utils/proc.py
@@ -0,0 +1,283 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Processes module.
+
+This module contains all classes and functions for dealing with Linux
+processes.
+"""
+import csv
+import datetime
+import logging
+import shlex
+import signal
+import time
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import NamedTuple
+from typing import Optional
+from typing import Tuple
+
+import psutil
+from sh import Command
+from sh import CommandNotFound
+from sh import ErrorReturnCode
+from sh import RunningCommand
+
+from aiet.utils.fs import valid_for_filename
+
+
class CommandFailedException(Exception):
    """Exception for failed command execution.

    Raised by run_and_wait() when the executed command exits with a
    non-zero return code.
    """
+
+
class ShellCommand:
    """Wrapper class for shell commands."""

    def __init__(self, base_log_path: str = "/tmp") -> None:
        """Initialise the class.

        base_log_path: it is the base directory where logs will be stored
        """
        self.base_log_path = base_log_path

    def run(
        self,
        cmd: str,
        *args: str,
        _cwd: Optional[Path] = None,
        _tee: bool = True,
        _bg: bool = True,
        _out: Any = None,
        _err: Any = None,
        _search_paths: Optional[List[Path]] = None
    ) -> RunningCommand:
        """Run the shell command with the given arguments.

        There are special arguments that modify the behaviour of the process.
        _cwd: current working directory
        _tee: it redirects the stdout both to console and file
        _bg: if True, it runs the process in background and the command is not
        blocking.
        _out: use this object for stdout redirect,
        _err: use this object for stderr redirect,
        _search_paths: If presented used for searching executable
        """
        try:
            kwargs = {}
            if _cwd:
                kwargs["_cwd"] = str(_cwd)
            command = Command(cmd, _search_paths).bake(args, **kwargs)
        except CommandNotFound as error:
            logging.error("Command '%s' not found", error.args[0])
            raise error

        # When the caller gave no explicit redirect targets, default to
        # timestamped log files under base_log_path.
        out, err = _out, _err
        if not _out and not _err:
            out, err = [
                str(item)
                for item in self.get_stdout_stderr_paths(self.base_log_path, cmd)
            ]

        return command(_out=out, _err=err, _tee=_tee, _bg=_bg, _bg_exc=False)

    @classmethod
    def get_stdout_stderr_paths(cls, base_log_path: str, cmd: str) -> Tuple[Path, Path]:
        """Construct and returns the paths of stdout/stderr files."""
        timestamp = datetime.datetime.now().timestamp()
        base_path = Path(base_log_path)
        base = base_path / "{}_{}".format(valid_for_filename(cmd, "_"), timestamp)
        # Append the extensions rather than using Path.with_suffix(): the
        # timestamp contains a '.', so with_suffix() would *replace* the
        # fractional seconds instead of appending, losing precision and
        # making name collisions between runs more likely.
        stdout = base.parent / (base.name + ".out")
        stderr = base.parent / (base.name + ".err")
        try:
            stdout.touch()
            stderr.touch()
        except FileNotFoundError as error:
            logging.error("File not found: %s", error.filename)
            raise error
        return stdout, stderr
+
+
def parse_command(command: str, shell: str = "bash") -> List[str]:
    """Parse command."""
    executable, *cmd_args = shlex.split(command, posix=True)

    # Shell scripts are executed through the configured shell interpreter.
    if executable.endswith(".sh"):
        return [shell, executable, *cmd_args]

    return [executable, *cmd_args]
+
+
def get_stdout_stderr_paths(
    command: str, base_log_path: str = "/tmp"
) -> Tuple[Path, Path]:
    """Construct and returns the paths of stdout/stderr files."""
    # Only the executable (first token) is used in the log file names.
    executable = parse_command(command)[0]
    return ShellCommand.get_stdout_stderr_paths(base_log_path, executable)
+
+
def execute_command(  # pylint: disable=invalid-name
    command: str,
    cwd: Path,
    bg: bool = False,
    shell: str = "bash",
    out: Any = None,
    err: Any = None,
) -> RunningCommand:
    """Execute shell command."""
    executable, *cmd_args = parse_command(command, shell)

    # When the executable is a local file inside cwd (and not the shell
    # itself), restrict the lookup to that directory.
    local_candidate = cwd / executable
    search_paths = [cwd] if executable != shell and local_candidate.is_file() else None

    shell_command = ShellCommand()
    return shell_command.run(
        executable,
        *cmd_args,
        _cwd=cwd,
        _bg=bg,
        _search_paths=search_paths,
        _out=out,
        _err=err,
    )
+
+
def is_shell_script(cmd: str) -> bool:
    """Check if command is shell script."""
    # A '.sh' extension marks the command as a shell script.
    return cmd[-3:] == ".sh"
+
+
def run_and_wait(
    command: str,
    cwd: Path,
    terminate_on_error: bool = False,
    out: Any = None,
    err: Any = None,
) -> Tuple[int, bytearray, bytearray]:
    """
    Run command and wait while it is executing.

    Returns a tuple: (exit_code, stdout, stderr)
    """
    running_cmd: Optional[RunningCommand] = None
    try:
        # Start in background so a failure below can still terminate the
        # child; reading exit_code then blocks until the process finishes.
        running_cmd = execute_command(command, cwd, bg=True, out=out, err=err)
        return running_cmd.exit_code, running_cmd.stdout, running_cmd.stderr
    except ErrorReturnCode as cmd_failed:
        # Non-zero exit status is mapped onto the project's own exception.
        raise CommandFailedException() from cmd_failed
    except Exception as error:
        # Any other failure (including KeyboardInterrupt-like errors):
        # optionally make sure the child does not keep running in background.
        is_running = running_cmd is not None and running_cmd.is_alive()
        if terminate_on_error and is_running:
            print("Terminating ...")
            terminate_command(running_cmd)

        raise error
+
+
def terminate_command(
    running_cmd: RunningCommand,
    wait: bool = True,
    wait_period: float = 0.5,
    number_of_attempts: int = 20,
) -> None:
    """Terminate running command.

    Sends SIGINT to the command's process group and, when `wait` is set,
    polls for up to number_of_attempts * wait_period seconds before
    escalating to SIGTERM.
    """
    try:
        # Ask politely first: SIGINT to the whole process group.
        running_cmd.process.signal_group(signal.SIGINT)
        if wait:
            for _ in range(number_of_attempts):
                time.sleep(wait_period)
                if not running_cmd.is_alive():
                    return
            # Still alive after the grace period: escalate to SIGTERM.
            print(
                "Unable to terminate process {}. Sending SIGTERM...".format(
                    running_cmd.process.pid
                )
            )
            running_cmd.process.signal_group(signal.SIGTERM)
    except ProcessLookupError:
        # The process exited before (or while) being signalled - nothing to do.
        pass
+
+
def terminate_external_process(
    process: psutil.Process,
    wait_period: float = 0.5,
    number_of_attempts: int = 20,
    wait_for_termination: float = 5.0,
) -> None:
    """Terminate external process.

    Sends a terminate signal and polls for up to
    number_of_attempts * wait_period seconds for the process to exit.
    """
    try:
        process.terminate()
        for _ in range(number_of_attempts):
            if not process.is_running():
                return
            time.sleep(wait_period)

        if process.is_running():
            # NOTE(review): this re-sends terminate(); if the first signal was
            # ignored, process.kill() would be the usual escalation - confirm
            # whether the repeated SIGTERM is intentional.
            process.terminate()
            time.sleep(wait_for_termination)
    except psutil.Error:
        # Best-effort termination: report but do not propagate psutil errors.
        print("Unable to terminate process")
+
+
class ProcessInfo(NamedTuple):
    """Process information.

    One CSV row per process is written/read by save_process_info() and
    read_process_info() in this module.
    """

    # Process name as reported by psutil.
    name: str
    # Absolute path of the process executable.
    executable: str
    # Working directory of the process.
    cwd: str
    # Process identifier.
    pid: int
+
+
def save_process_info(pid: int, pid_file: Path) -> None:
    """Save process information to file."""
    try:
        parent = psutil.Process(pid)
        # Record the parent and all of its (recursive) children.
        family = [parent, *parent.children(recursive=True)]

        with open(pid_file, "w", encoding="utf-8") as file:
            csv_writer = csv.writer(file)
            for member in family:
                csv_writer.writerow(
                    ProcessInfo(
                        name=member.name(),
                        executable=member.exe(),
                        cwd=member.cwd(),
                        pid=member.pid,
                    )
                )
    except psutil.NoSuchProcess:
        # if process does not exist or finishes before
        # function call then nothing could be saved
        # just ignore this exception and exit
        pass
+
+
def read_process_info(pid_file: Path) -> List[ProcessInfo]:
    """Read information about previous system processes."""
    if not pid_file.is_file():
        return []

    with open(pid_file, encoding="utf-8") as file:
        # Each CSV row holds (name, executable, cwd, pid).
        return [
            ProcessInfo(name=name, executable=executable, cwd=cwd, pid=int(pid))
            for name, executable, cwd, pid in csv.reader(file)
        ]
+
+
def print_command_stdout(command: RunningCommand) -> None:
    """Print the stdout of a command.

    The command has 2 states: running and done.
    If the command is running, the output is taken by the running process.
    If the command has ended its execution, the stdout is taken from stdout
    property
    """
    if not command.is_alive():
        # Finished: the whole output is already buffered on the command.
        print(command.stdout)
        return

    # Still running: consume and echo chunks until the stream ends.
    while True:
        try:
            print(command.next(), end="")
        except StopIteration:
            return