From 0efca3cadbad5517a59884576ddb90cfe7ac30f8 Mon Sep 17 00:00:00 2001 From: Diego Russo Date: Mon, 30 May 2022 13:34:14 +0100 Subject: Add MLIA codebase Add MLIA codebase including sources and tests. Change-Id: Id41707559bd721edd114793618d12ccd188d8dbd --- src/aiet/backend/common.py | 532 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 532 insertions(+) create mode 100644 src/aiet/backend/common.py (limited to 'src/aiet/backend/common.py') diff --git a/src/aiet/backend/common.py b/src/aiet/backend/common.py new file mode 100644 index 0000000..b887ee7 --- /dev/null +++ b/src/aiet/backend/common.py @@ -0,0 +1,532 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Contain all common functions for the backends.""" +import json +import logging +import re +from abc import ABC +from collections import Counter +from pathlib import Path +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Final +from typing import IO +from typing import Iterable +from typing import List +from typing import Match +from typing import NamedTuple +from typing import Optional +from typing import Pattern +from typing import Tuple +from typing import Type +from typing import Union + +from aiet.backend.config import BackendConfig +from aiet.backend.config import BaseBackendConfig +from aiet.backend.config import NamedExecutionConfig +from aiet.backend.config import UserParamConfig +from aiet.backend.config import UserParamsConfig +from aiet.utils.fs import get_resources +from aiet.utils.fs import remove_resource +from aiet.utils.fs import ResourceType + + +AIET_CONFIG_FILE: Final[str] = "aiet-config.json" + + +class ConfigurationException(Exception): + """Configuration exception.""" + + +def get_backend_config(dir_path: Path) -> Path: + """Get path to backendir configuration file.""" + return dir_path / AIET_CONFIG_FILE + + +def get_backend_configs(resource_type: ResourceType) -> Iterable[Path]: + """Get path to the backend configs for provided resource_type.""" + return ( + get_backend_config(entry) for entry in get_backend_directories(resource_type) + ) + + +def get_backend_directories(resource_type: ResourceType) -> Iterable[Path]: + """Get path to the backend directories for provided resource_type.""" + return ( + entry + for entry in get_resources(resource_type).iterdir() + if is_backend_directory(entry) + ) + + +def is_backend_directory(dir_path: Path) -> bool: + """Check if path is backend's configuration directory.""" + return dir_path.is_dir() and get_backend_config(dir_path).is_file() + + +def remove_backend(directory_name: str, resource_type: ResourceType) -> None: + """Remove backend with provided type and directory_name.""" + if not directory_name: + raise Exception("No directory name provided") + + remove_resource(directory_name, resource_type) + + +def load_config(config: Union[None, Path, IO[bytes]]) -> BackendConfig: + """Return a loaded json file.""" + if config is None: + raise Exception("Unable to read config") + + if isinstance(config, Path): + with config.open() as json_file: + return cast(BackendConfig, json.load(json_file)) + + return cast(BackendConfig, json.load(config)) + + +def parse_raw_parameter(parameter: str) -> Tuple[str, Optional[str]]: + """Split the parameter string in name and optional value. + + It manages the following cases: + --param=1 -> --param, 1 + --param 1 -> --param, 1 + --flag -> --flag, None + """ + data = re.split(" |=", parameter) + if len(data) == 1: + param_name = data[0] + param_value = None + else: + param_name = " ".join(data[0:-1]) + param_value = data[-1] + return param_name, param_value + + +class DataPaths(NamedTuple): + """DataPaths class.""" + + src: Path + dst: str + + +class Backend(ABC): + """Backend class.""" + + # pylint: disable=too-many-instance-attributes + + def __init__(self, config: BaseBackendConfig): + """Initialize backend.""" + name = config.get("name") + if not name: + raise ConfigurationException("Name is empty") + + self.name = name + self.description = config.get("description", "") + self.config_location = config.get("config_location") + self.variables = config.get("variables", {}) + self.build_dir = config.get("build_dir") + self.lock = config.get("lock", False) + if self.build_dir: + self.build_dir = self._substitute_variables(self.build_dir) + self.annotations = config.get("annotations", {}) + + self._parse_commands_and_params(config) + + def validate_parameter(self, command_name: str, parameter: str) -> bool: + """Validate the parameter string against the application configuration. + + We take the parameter string, extract the parameter name/value and + check them against the current configuration. + """ + param_name, param_value = parse_raw_parameter(parameter) + valid_param_name = valid_param_value = False + + command = self.commands.get(command_name) + if not command: + raise AttributeError("Unknown command: '{}'".format(command_name)) + + # Iterate over all available parameters until we have a match. + for param in command.params: + if self._same_parameter(param_name, param): + valid_param_name = True + # This is a non-empty list + if param.values: + # We check if the value is allowed in the configuration + valid_param_value = param_value in param.values + else: + # In this case we don't validate the value and accept + # whatever we have set. + valid_param_value = True + break + + return valid_param_name and valid_param_value + + def __eq__(self, other: object) -> bool: + """Overload operator ==.""" + if not isinstance(other, Backend): + return False + + return ( + self.name == other.name + and self.description == other.description + and self.commands == other.commands + ) + + def __repr__(self) -> str: + """Represent the Backend instance by its name.""" + return self.name + + def _parse_commands_and_params(self, config: BaseBackendConfig) -> None: + """Parse commands and user parameters.""" + self.commands: Dict[str, Command] = {} + + commands = config.get("commands") + if commands: + params = config.get("user_params") + + for command_name in commands.keys(): + command_params = self._parse_params(params, command_name) + command_strings = [ + self._substitute_variables(cmd) + for cmd in commands.get(command_name, []) + ] + self.commands[command_name] = Command(command_strings, command_params) + + def _substitute_variables(self, str_val: str) -> str: + """Substitute variables in string. + + Variables is being substituted at backend's creation stage because + they could contain references to other params which will be + resolved later. + """ + if not str_val: + return str_val + + var_pattern: Final[Pattern] = re.compile(r"{variables:(?P\w+)}") + + def var_value(match: Match) -> str: + var_name = match["var_name"] + if var_name not in self.variables: + raise ConfigurationException("Unknown variable {}".format(var_name)) + + return self.variables[var_name] + + return var_pattern.sub(var_value, str_val) # type: ignore + + @classmethod + def _parse_params( + cls, params: Optional[UserParamsConfig], command: str + ) -> List["Param"]: + if not params: + return [] + + return [cls._parse_param(p) for p in params.get(command, [])] + + @classmethod + def _parse_param(cls, param: UserParamConfig) -> "Param": + """Parse a single parameter.""" + name = param.get("name") + if name is not None and not name: + raise ConfigurationException("Parameter has an empty 'name' attribute.") + values = param.get("values", None) + default_value = param.get("default_value", None) + description = param.get("description", "") + alias = param.get("alias") + + return Param( + name=name, + description=description, + values=values, + default_value=default_value, + alias=alias, + ) + + def _get_command_details(self) -> Dict: + command_details = { + command_name: command.get_details() + for command_name, command in self.commands.items() + } + return command_details + + def _get_user_param_value( + self, user_params: List[str], param: "Param" + ) -> Optional[str]: + """Get the user-specified value of a parameter.""" + for user_param in user_params: + user_param_name, user_param_value = parse_raw_parameter(user_param) + if user_param_name == param.name: + warn_message = ( + "The direct use of parameter name is deprecated" + " and might be removed in the future.\n" + f"Please use alias '{param.alias}' instead of " + "'{user_param_name}' to provide the parameter." + ) + logging.warning(warn_message) + + if self._same_parameter(user_param_name, param): + return user_param_value + + return None + + @staticmethod + def _same_parameter(user_param_name_or_alias: str, param: "Param") -> bool: + """Compare user parameter name with param name or alias.""" + # Strip the "=" sign in the param_name. This is needed just for + # comparison with the parameters passed by the user. + # The equal sign needs to be honoured when re-building the + # parameter back. + param_name = None if not param.name else param.name.rstrip("=") + return user_param_name_or_alias in [param_name, param.alias] + + def resolved_parameters( + self, command_name: str, user_params: List[str] + ) -> List[Tuple[Optional[str], "Param"]]: + """Return list of parameters with values.""" + result: List[Tuple[Optional[str], "Param"]] = [] + command = self.commands.get(command_name) + if not command: + return result + + for param in command.params: + value = self._get_user_param_value(user_params, param) + if not value: + value = param.default_value + result.append((value, param)) + + return result + + def build_command( + self, + command_name: str, + user_params: List[str], + param_resolver: Callable[[str, str, List[Tuple[Optional[str], "Param"]]], str], + ) -> List[str]: + """ + Return a list of executable command strings. + + Given a command and associated parameters, returns a list of executable command + strings. + """ + command = self.commands.get(command_name) + if not command: + raise ConfigurationException( + "Command '{}' could not be found.".format(command_name) + ) + + commands_to_run = [] + + params_values = self.resolved_parameters(command_name, user_params) + for cmd_str in command.command_strings: + cmd_str = resolve_all_parameters( + cmd_str, param_resolver, command_name, params_values + ) + commands_to_run.append(cmd_str) + + return commands_to_run + + +class Param: + """Class for representing a generic application parameter.""" + + def __init__( # pylint: disable=too-many-arguments + self, + name: Optional[str], + description: str, + values: Optional[List[str]] = None, + default_value: Optional[str] = None, + alias: Optional[str] = None, + ) -> None: + """Construct a Param instance.""" + if not name and not alias: + raise ConfigurationException( + "Either name, alias or both must be set to identify a parameter." + ) + self.name = name + self.values = values + self.description = description + self.default_value = default_value + self.alias = alias + + def get_details(self) -> Dict: + """Return a dictionary with all relevant information of a Param.""" + return {key: value for key, value in self.__dict__.items() if value} + + def __eq__(self, other: object) -> bool: + """Overload operator ==.""" + if not isinstance(other, Param): + return False + + return ( + self.name == other.name + and self.values == other.values + and self.default_value == other.default_value + and self.description == other.description + ) + + +class Command: + """Class for representing a command.""" + + def __init__( + self, command_strings: List[str], params: Optional[List[Param]] = None + ) -> None: + """Construct a Command instance.""" + self.command_strings = command_strings + + if params: + self.params = params + else: + self.params = [] + + self._validate() + + def _validate(self) -> None: + """Validate command.""" + if not self.params: + return + + aliases = [param.alias for param in self.params if param.alias is not None] + repeated_aliases = [ + alias for alias, count in Counter(aliases).items() if count > 1 + ] + + if repeated_aliases: + raise ConfigurationException( + "Non unique aliases {}".format(", ".join(repeated_aliases)) + ) + + both_name_and_alias = [ + param.name + for param in self.params + if param.name in aliases and param.name != param.alias + ] + if both_name_and_alias: + raise ConfigurationException( + "Aliases {} could not be used as parameter name".format( + ", ".join(both_name_and_alias) + ) + ) + + def get_details(self) -> Dict: + """Return a dictionary with all relevant information of a Command.""" + output = { + "command_strings": self.command_strings, + "user_params": [param.get_details() for param in self.params], + } + return output + + def __eq__(self, other: object) -> bool: + """Overload operator ==.""" + if not isinstance(other, Command): + return False + + return ( + self.command_strings == other.command_strings + and self.params == other.params + ) + + +def resolve_all_parameters( + str_val: str, + param_resolver: Callable[[str, str, List[Tuple[Optional[str], Param]]], str], + command_name: Optional[str] = None, + params_values: Optional[List[Tuple[Optional[str], Param]]] = None, +) -> str: + """Resolve all parameters in the string.""" + if not str_val: + return str_val + + param_pattern: Final[Pattern] = re.compile(r"{(?P[\w.:]+)}") + while param_pattern.findall(str_val): + str_val = param_pattern.sub( + lambda m: param_resolver( + m["param_name"], command_name or "", params_values or [] + ), + str_val, + ) + return str_val + + +def load_application_or_tool_configs( + config: Any, + config_type: Type[Any], + is_system_required: bool = True, +) -> Any: + """Get one config for each system supported by the application/tool. + + The configuration could contain different parameters/commands for different + supported systems. For each supported system this function will return separate + config with appropriate configuration. + """ + merged_configs = [] + supported_systems: Optional[List[NamedExecutionConfig]] = config.get( + "supported_systems" + ) + if not supported_systems: + if is_system_required: + raise ConfigurationException("No supported systems definition provided") + # Create an empty system to be used in the parsing below + supported_systems = [cast(NamedExecutionConfig, {})] + + default_user_params = config.get("user_params", {}) + + def merge_config(system: NamedExecutionConfig) -> Any: + system_name = system.get("name") + if not system_name and is_system_required: + raise ConfigurationException( + "Unable to read supported system definition, name is missed" + ) + + merged_config = config_type(**config) + merged_config["supported_systems"] = [system_name] if system_name else [] + # merge default configuration and specific to the system + merged_config["commands"] = { + **config.get("commands", {}), + **system.get("commands", {}), + } + + params = {} + tool_user_params = system.get("user_params", {}) + command_names = tool_user_params.keys() | default_user_params.keys() + for command_name in command_names: + if command_name not in merged_config["commands"]: + continue + + params_default = default_user_params.get(command_name, []) + params_tool = tool_user_params.get(command_name, []) + if not params_default or not params_tool: + params[command_name] = params_tool or params_default + if params_default and params_tool: + if any(not p.get("alias") for p in params_default): + raise ConfigurationException( + "Default parameters for command {} should have aliases".format( + command_name + ) + ) + if any(not p.get("alias") for p in params_tool): + raise ConfigurationException( + "{} parameters for command {} should have aliases".format( + system_name, command_name + ) + ) + + merged_by_alias = { + **{p.get("alias"): p for p in params_default}, + **{p.get("alias"): p for p in params_tool}, + } + params[command_name] = list(merged_by_alias.values()) + + merged_config["user_params"] = params + merged_config["build_dir"] = system.get("build_dir", config.get("build_dir")) + merged_config["lock"] = system.get("lock", config.get("lock", False)) + merged_config["variables"] = { + **config.get("variables", {}), + **system.get("variables", {}), + } + return merged_config + + merged_configs = [merge_config(system) for system in supported_systems] + + return merged_configs -- cgit v1.2.1