diff options
Diffstat (limited to 'src/aiet/cli/tool.py')
-rw-r--r-- | src/aiet/cli/tool.py | 143 |
1 files changed, 143 insertions, 0 deletions
diff --git a/src/aiet/cli/tool.py b/src/aiet/cli/tool.py new file mode 100644 index 0000000..2c80821 --- /dev/null +++ b/src/aiet/cli/tool.py @@ -0,0 +1,143 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module to manage the CLI interface of tools.""" +import json +from typing import Any +from typing import List +from typing import Optional + +import click + +from aiet.backend.execution import execute_tool_command +from aiet.backend.tool import get_tool +from aiet.backend.tool import get_unique_tool_names +from aiet.cli.common import get_format +from aiet.cli.common import middleware_exception_handler +from aiet.cli.common import middleware_signal_handler +from aiet.cli.common import print_command_details +from aiet.cli.common import set_format + + +@click.group(name="tool") +@click.option( + "-f", + "--format", + "format_", + type=click.Choice(["cli", "json"]), + default="cli", + show_default=True, +) +@click.pass_context +def tool_cmd(ctx: click.Context, format_: str) -> None: + """Sub command to manage tools.""" + set_format(ctx, format_) + + +@tool_cmd.command(name="list") +@click.pass_context +def list_cmd(ctx: click.Context) -> None: + """List all available tools.""" + # raise NotImplementedError("TODO") + tool_names = get_unique_tool_names() + tool_names.sort() + if get_format(ctx) == "json": + data = {"type": "tool", "available": tool_names} + print(json.dumps(data)) + else: + print("Available tools:\n") + print(*tool_names, sep="\n") + + +def validate_system( + ctx: click.Context, + _: click.Parameter, # param is not used + value: Any, +) -> Any: + """Validate provided system name depending on the the tool name.""" + tool_name = ctx.params["tool_name"] + tools = get_tool(tool_name, value) + if not tools: + supported_systems = [tool.supported_systems[0] for tool in get_tool(tool_name)] + raise click.BadParameter( + message="'{}' is not one of {}.".format( + value, + ", ".join("'{}'".format(system) for system in supported_systems), + ), + ctx=ctx, + ) + return value + + +@tool_cmd.command(name="details") +@click.option( + "-n", + "--name", + "tool_name", + type=click.Choice(get_unique_tool_names()), + required=True, +) +@click.option( + "-s", + "--system", + "system_name", + callback=validate_system, + required=False, +) +@click.pass_context +@middleware_signal_handler +@middleware_exception_handler +def details_cmd(ctx: click.Context, tool_name: str, system_name: Optional[str]) -> None: + """Details of a specific tool.""" + tools = get_tool(tool_name, system_name) + if get_format(ctx) == "json": + tools_details = [s.get_details() for s in tools] + print(json.dumps(tools_details)) + else: + for tool in tools: + tool_details = tool.get_details() + tool_details_template = 'Tool "{name}" details\nDescription: {description}' + + print( + tool_details_template.format( + name=tool_details["name"], + description=tool_details["description"], + ) + ) + + print( + "\nSupported systems: {}".format( + ", ".join(tool_details["supported_systems"]) + ) + ) + + command_details = tool_details["commands"] + + for command, details in command_details.items(): + print("\n{} commands:".format(command)) + print_command_details(details) + + +# pylint: disable=too-many-arguments +@tool_cmd.command(name="execute") +@click.option( + "-n", + "--name", + "tool_name", + type=click.Choice(get_unique_tool_names()), + required=True, +) +@click.option("-p", "--param", "tool_params", multiple=True) +@click.option( + "-s", + "--system", + "system_name", + callback=validate_system, + required=False, +) +@middleware_signal_handler +@middleware_exception_handler +def execute_cmd( + tool_name: str, tool_params: List[str], system_name: Optional[str] +) -> None: + """Execute tool commands.""" + execute_tool_command(tool_name, tool_params, system_name) |