aboutsummaryrefslogtreecommitdiff
path: root/src/aiet/backend/tool.py
blob: d643665c19905b3255ca9ed7f7bc41efc73c5f8a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tool backend module."""
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional

from aiet.backend.common import Backend
from aiet.backend.common import ConfigurationException
from aiet.backend.common import get_backend_configs
from aiet.backend.common import get_backend_directories
from aiet.backend.common import load_application_or_tool_configs
from aiet.backend.common import load_config
from aiet.backend.config import ExtendedToolConfig
from aiet.backend.config import ToolConfig


def get_available_tool_directory_names() -> List[str]:
    """List the directory name of every available tool backend.

    Returns:
        The ``name`` attribute of each entry produced by
        ``get_backend_directories("tools")``, in the order produced.
    """
    tool_directories = get_backend_directories("tools")
    return [directory.name for directory in tool_directories]


def get_available_tools() -> List["Tool"]:
    """Load every tool described by the tool configuration files.

    Each configuration file may hold several entries. Before an entry is
    turned into Tool instances, it is stamped with ``config_location``
    (the absolute path of the directory containing its file).

    Returns:
        All loaded tools, ordered by ``name``.
    """
    loaded: List["Tool"] = []
    for json_path in get_backend_configs("tools"):
        entries = cast(List[ExtendedToolConfig], load_config(json_path))
        for entry in entries:
            # The tool's location must be recorded before load_tools sees the entry.
            entry["config_location"] = json_path.parent.absolute()
            loaded.extend(load_tools(entry))

    loaded.sort(key=lambda loaded_tool: loaded_tool.name)
    return loaded


def get_tool(tool_name: str, system_name: Optional[str] = None) -> List["Tool"]:
    """Return every available tool whose name equals ``tool_name``.

    Args:
        tool_name: Exact tool name to look for.
        system_name: If given (and non-empty), keep only the tools that
            report they can run on this system.

    Returns:
        A list of matching Tool instances, which may be empty. Several
        instances can share one name.
    """
    matches: List["Tool"] = []
    for candidate in get_available_tools():
        if candidate.name != tool_name:
            continue
        # An empty/None system name means "no system filter".
        if system_name and not candidate.can_run_on(system_name):
            continue
        matches.append(candidate)
    return matches


def get_unique_tool_names(system_name: Optional[str] = None) -> List[str]:
    """Return the de-duplicated names of all available tools, sorted.

    Several Tool instances may share a name, so the names are collected
    into a set first. The result is sorted so callers get a stable order
    across runs; the previous ``list(set(...))`` order depended on
    per-process string hashing.

    Args:
        system_name: If given (and non-empty), only consider tools that
            report they can run on this system.

    Returns:
        Sorted list of unique tool names.
    """
    return sorted(
        {
            tool.name
            for tool in get_available_tools()
            if not system_name or tool.can_run_on(system_name)
        }
    )


class Tool(Backend):
    """A single tool backend built from a ToolConfig."""

    def __init__(self, config: ToolConfig) -> None:
        """Build a Tool from its configuration dict.

        Args:
            config: Tool configuration. ``supported_systems`` is optional
                and defaults to an empty list.

        Raises:
            ConfigurationException: if ``self.commands`` (presumably filled
                in by the Backend constructor -- confirm) has no ``run`` entry.
        """
        super().__init__(config)

        self.supported_systems = config.get("supported_systems", [])

        has_run_command = "run" in self.commands
        if not has_run_command:
            raise ConfigurationException("A Tool must have a 'run' command.")

    def __eq__(self, other: object) -> bool:
        """Compare tools by base-class equality, name and supported systems.

        Supported systems are compared as sets, so their order and any
        duplicates are ignored. Non-Tool objects are never equal.
        """
        if isinstance(other, Tool):
            return (
                super().__eq__(other)
                and self.name == other.name
                and set(self.supported_systems) == set(other.supported_systems)
            )
        return False

    def can_run_on(self, system_name: str) -> bool:
        """Tell whether ``system_name`` is one of this tool's supported systems."""
        return system_name in self.supported_systems

    def get_details(self) -> Dict[str, Any]:
        """Describe this tool as a plain dictionary.

        Returns:
            A dict with keys ``type`` (always ``"tool"``), ``name``,
            ``description``, ``supported_systems`` and ``commands``.
        """
        return {
            "type": "tool",
            "name": self.name,
            "description": self.description,
            "supported_systems": self.supported_systems,
            "commands": self._get_command_details(),
        }


def load_tools(config: ExtendedToolConfig) -> List[Tool]:
    """Turn one extended tool configuration into Tool instances.

    An extended configuration may carry different parameters/commands for
    different supported systems. ``load_application_or_tool_configs``
    expands it into plain ToolConfig entries (no system is required), and
    each entry becomes its own Tool.

    Args:
        config: The extended tool configuration to expand.

    Returns:
        One Tool per expanded configuration, in expansion order.
    """
    expanded = load_application_or_tool_configs(
        config, ToolConfig, is_system_required=False
    )
    return list(map(Tool, expanded))