diff options
Diffstat (limited to 'src/aiet/backend/tool.py')
-rw-r--r-- | src/aiet/backend/tool.py | 109 |
1 files changed, 109 insertions, 0 deletions
diff --git a/src/aiet/backend/tool.py b/src/aiet/backend/tool.py new file mode 100644 index 0000000..d643665 --- /dev/null +++ b/src/aiet/backend/tool.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tool backend module.""" +from typing import Any +from typing import cast +from typing import Dict +from typing import List +from typing import Optional + +from aiet.backend.common import Backend +from aiet.backend.common import ConfigurationException +from aiet.backend.common import get_backend_configs +from aiet.backend.common import get_backend_directories +from aiet.backend.common import load_application_or_tool_configs +from aiet.backend.common import load_config +from aiet.backend.config import ExtendedToolConfig +from aiet.backend.config import ToolConfig + + +def get_available_tool_directory_names() -> List[str]: + """Return a list of directory names for all available tools.""" + return [entry.name for entry in get_backend_directories("tools")] + + +def get_available_tools() -> List["Tool"]: + """Return a list with all available tools.""" + available_tools = [] + for config_json in get_backend_configs("tools"): + config_entries = cast(List[ExtendedToolConfig], load_config(config_json)) + for config_entry in config_entries: + config_entry["config_location"] = config_json.parent.absolute() + tools = load_tools(config_entry) + available_tools += tools + + return sorted(available_tools, key=lambda tool: tool.name) + + +def get_tool(tool_name: str, system_name: Optional[str] = None) -> List["Tool"]: + """Return a tool instance with the same name passed as argument.""" + return [ + tool + for tool in get_available_tools() + if tool.name == tool_name and (not system_name or tool.can_run_on(system_name)) + ] + + +def get_unique_tool_names(system_name: Optional[str] = None) -> List[str]: + """Extract a list of unique tool names of all tools available.""" + return list( + set( + tool.name + for tool in get_available_tools() + if not system_name or tool.can_run_on(system_name) + ) + ) + + +class Tool(Backend): + """Class for representing a single tool component.""" + + def __init__(self, config: ToolConfig) -> None: + """Construct a Tool instance from a dict.""" + super().__init__(config) + + self.supported_systems = config.get("supported_systems", []) + + if "run" not in self.commands: + raise ConfigurationException("A Tool must have a 'run' command.") + + def __eq__(self, other: object) -> bool: + """Overload operator ==.""" + if not isinstance(other, Tool): + return False + + return ( + super().__eq__(other) + and self.name == other.name + and set(self.supported_systems) == set(other.supported_systems) + ) + + def can_run_on(self, system_name: str) -> bool: + """Check if the tool can run on the system passed as argument.""" + return system_name in self.supported_systems + + def get_details(self) -> Dict[str, Any]: + """Return dictionary with all relevant information of the Tool instance.""" + output = { + "type": "tool", + "name": self.name, + "description": self.description, + "supported_systems": self.supported_systems, + "commands": self._get_command_details(), + } + + return output + + +def load_tools(config: ExtendedToolConfig) -> List[Tool]: + """Load tool. + + Tool configuration could contain different parameters/commands for different + supported systems. For each supported system this function will return separate + Tool instance with appropriate configuration. + """ + configs = load_application_or_tool_configs( + config, ToolConfig, is_system_required=False + ) + tools = [Tool(cfg) for cfg in configs] + return tools |