aboutsummaryrefslogtreecommitdiff
path: root/src/aiet/cli/tool.py
blob: 2c8082139be3caa21d1d6ff385dc4516bba4176d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module to manage the CLI interface of tools."""
import json
from typing import Any
from typing import List
from typing import Optional

import click

from aiet.backend.execution import execute_tool_command
from aiet.backend.tool import get_tool
from aiet.backend.tool import get_unique_tool_names
from aiet.cli.common import get_format
from aiet.cli.common import middleware_exception_handler
from aiet.cli.common import middleware_signal_handler
from aiet.cli.common import print_command_details
from aiet.cli.common import set_format


@click.group(name="tool")
@click.option(
    "-f",
    "--format",
    "format_",
    # Output format for every "tool" sub command.
    show_default=True,
    default="cli",
    type=click.Choice(["cli", "json"]),
)
@click.pass_context
def tool_cmd(ctx: click.Context, format_: str) -> None:
    """Sub command to manage tools."""
    # Hand the chosen format to set_format; sub commands read it back
    # through get_format(ctx).
    set_format(ctx, format_)


@tool_cmd.command(name="list")
@click.pass_context
def list_cmd(ctx: click.Context) -> None:
    """List all available tools."""
    # Print the unique tool names in sorted order, either as a JSON object
    # or as a plain newline-separated list depending on the group format.
    # sorted() builds a new list, so whatever list get_unique_tool_names()
    # returns is left unmodified.
    tool_names = sorted(get_unique_tool_names())
    if get_format(ctx) == "json":
        data = {"type": "tool", "available": tool_names}
        print(json.dumps(data))
    else:
        print("Available tools:\n")
        print(*tool_names, sep="\n")


def validate_system(
    ctx: click.Context,
    _: click.Parameter,  # param is not used
    value: Any,
) -> Any:
    """Validate provided system name depending on the tool name.

    Used as the click callback of the ``--system`` option. The system is
    accepted when ``get_tool(tool_name, value)`` finds at least one tool.

    Args:
        ctx: Click context; ``ctx.params["tool_name"]`` must already be set.
        _: The click parameter being validated (unused).
        value: The system name given on the command line, or ``None`` when
            ``--system`` was omitted.

    Returns:
        ``value`` unchanged.

    Raises:
        click.BadParameter: If the tool name has not been processed yet, or
            if no tool with that name supports the given system.
    """
    if value is None:
        # --system is optional; nothing to validate when it is omitted.
        return value

    # Click runs callbacks in command-line order, so "--system" may be seen
    # before "--name" unless "--name" is eager. Report a usage error instead
    # of crashing with a KeyError.
    tool_name = ctx.params.get("tool_name")
    if tool_name is None:
        raise click.BadParameter(
            message="the tool name (-n/--name) must be specified before the system.",
            ctx=ctx,
        )

    if not get_tool(tool_name, value):
        # Gather every system supported by any tool with this name. Duplicates
        # are dropped while first-seen order is kept. Tools with an empty
        # supported_systems list no longer raise IndexError.
        supported_systems = list(
            dict.fromkeys(
                system
                for tool in get_tool(tool_name)
                for system in tool.supported_systems
            )
        )
        raise click.BadParameter(
            message="'{}' is not one of {}.".format(
                value,
                ", ".join("'{}'".format(system) for system in supported_systems),
            ),
            ctx=ctx,
        )
    return value


@tool_cmd.command(name="details")
@click.option(
    "-n",
    "--name",
    "tool_name",
    type=click.Choice(get_unique_tool_names()),
    required=True,
    # Eager so tool_name is already in ctx.params when the --system callback
    # (validate_system) reads it, whatever order the user typed the options.
    is_eager=True,
)
@click.option(
    "-s",
    "--system",
    "system_name",
    callback=validate_system,
    required=False,
)
@click.pass_context
@middleware_signal_handler
@middleware_exception_handler
def details_cmd(ctx: click.Context, tool_name: str, system_name: Optional[str]) -> None:
    """Details of a specific tool."""
    # Print the details of every tool matching the name (and optional system),
    # either as a JSON list or as human-readable text.
    tools = get_tool(tool_name, system_name)
    if get_format(ctx) == "json":
        tools_details = [tool.get_details() for tool in tools]
        print(json.dumps(tools_details))
        return

    tool_details_template = 'Tool "{name}" details\nDescription: {description}'
    for tool in tools:
        tool_details = tool.get_details()

        print(
            tool_details_template.format(
                name=tool_details["name"],
                description=tool_details["description"],
            )
        )

        print(
            "\nSupported systems: {}".format(
                ", ".join(tool_details["supported_systems"])
            )
        )

        # "commands" maps a command name to its details, which are rendered
        # by the shared print_command_details helper.
        for command, details in tool_details["commands"].items():
            print("\n{} commands:".format(command))
            print_command_details(details)


# pylint: disable=too-many-arguments
@tool_cmd.command(name="execute")
@click.option(
    "-n",
    "--name",
    "tool_name",
    type=click.Choice(get_unique_tool_names()),
    required=True,
    # Eager so tool_name is already in ctx.params when the --system callback
    # (validate_system) reads it, whatever order the user typed the options.
    is_eager=True,
)
@click.option("-p", "--param", "tool_params", multiple=True)
@click.option(
    "-s",
    "--system",
    "system_name",
    callback=validate_system,
    required=False,
)
@middleware_signal_handler
@middleware_exception_handler
def execute_cmd(
    tool_name: str, tool_params: List[str], system_name: Optional[str]
) -> None:
    """Execute tool commands."""
    # NOTE(review): click passes a tuple for multiple=True options, despite the
    # List[str] annotation.
    execute_tool_command(tool_name, tool_params, system_name)