aboutsummaryrefslogtreecommitdiff
path: root/src/aiet/backend/tool.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/aiet/backend/tool.py')
-rw-r--r--src/aiet/backend/tool.py109
1 files changed, 109 insertions, 0 deletions
diff --git a/src/aiet/backend/tool.py b/src/aiet/backend/tool.py
new file mode 100644
index 0000000..d643665
--- /dev/null
+++ b/src/aiet/backend/tool.py
@@ -0,0 +1,109 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tool backend module."""
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import Optional
+
+from aiet.backend.common import Backend
+from aiet.backend.common import ConfigurationException
+from aiet.backend.common import get_backend_configs
+from aiet.backend.common import get_backend_directories
+from aiet.backend.common import load_application_or_tool_configs
+from aiet.backend.common import load_config
+from aiet.backend.config import ExtendedToolConfig
+from aiet.backend.config import ToolConfig
+
+
+def get_available_tool_directory_names() -> List[str]:
+ """Return a list of directory names for all available tools."""
+ return [entry.name for entry in get_backend_directories("tools")]
+
+
+def get_available_tools() -> List["Tool"]:
+ """Return a list with all available tools."""
+ available_tools = []
+ for config_json in get_backend_configs("tools"):
+ config_entries = cast(List[ExtendedToolConfig], load_config(config_json))
+ for config_entry in config_entries:
+ config_entry["config_location"] = config_json.parent.absolute()
+ tools = load_tools(config_entry)
+ available_tools += tools
+
+ return sorted(available_tools, key=lambda tool: tool.name)
+
+
+def get_tool(tool_name: str, system_name: Optional[str] = None) -> List["Tool"]:
+ """Return a tool instance with the same name passed as argument."""
+ return [
+ tool
+ for tool in get_available_tools()
+ if tool.name == tool_name and (not system_name or tool.can_run_on(system_name))
+ ]
+
+
+def get_unique_tool_names(system_name: Optional[str] = None) -> List[str]:
+ """Extract a list of unique tool names of all tools available."""
+ return list(
+ set(
+ tool.name
+ for tool in get_available_tools()
+ if not system_name or tool.can_run_on(system_name)
+ )
+ )
+
+
+class Tool(Backend):
+ """Class for representing a single tool component."""
+
+ def __init__(self, config: ToolConfig) -> None:
+ """Construct a Tool instance from a dict."""
+ super().__init__(config)
+
+ self.supported_systems = config.get("supported_systems", [])
+
+ if "run" not in self.commands:
+ raise ConfigurationException("A Tool must have a 'run' command.")
+
+ def __eq__(self, other: object) -> bool:
+ """Overload operator ==."""
+ if not isinstance(other, Tool):
+ return False
+
+ return (
+ super().__eq__(other)
+ and self.name == other.name
+ and set(self.supported_systems) == set(other.supported_systems)
+ )
+
+ def can_run_on(self, system_name: str) -> bool:
+ """Check if the tool can run on the system passed as argument."""
+ return system_name in self.supported_systems
+
+ def get_details(self) -> Dict[str, Any]:
+ """Return dictionary with all relevant information of the Tool instance."""
+ output = {
+ "type": "tool",
+ "name": self.name,
+ "description": self.description,
+ "supported_systems": self.supported_systems,
+ "commands": self._get_command_details(),
+ }
+
+ return output
+
+
+def load_tools(config: ExtendedToolConfig) -> List[Tool]:
+ """Load tool.
+
+ Tool configuration could contain different parameters/commands for different
+ supported systems. For each supported system this function will return separate
+ Tool instance with appropriate configuration.
+ """
+ configs = load_application_or_tool_configs(
+ config, ToolConfig, is_system_required=False
+ )
+ tools = [Tool(cfg) for cfg in configs]
+ return tools