diff options
Diffstat (limited to 'src/aiet/cli/common.py')
-rw-r--r-- | src/aiet/cli/common.py | 173 |
1 files changed, 173 insertions, 0 deletions
diff --git a/src/aiet/cli/common.py b/src/aiet/cli/common.py new file mode 100644 index 0000000..1d157b6 --- /dev/null +++ b/src/aiet/cli/common.py @@ -0,0 +1,173 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Common functions for cli module.""" +import enum +import logging +from functools import wraps +from signal import SIG_IGN +from signal import SIGINT +from signal import signal as signal_handler +from signal import SIGTERM +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict + +from click import ClickException +from click import Context +from click import UsageError + +from aiet.backend.common import ConfigurationException +from aiet.backend.execution import AnotherInstanceIsRunningException +from aiet.backend.execution import ConnectionException +from aiet.backend.protocol import SSHConnectionException +from aiet.utils.proc import CommandFailedException + + +class MiddlewareExitCode(enum.IntEnum): + """Middleware exit codes.""" + + SUCCESS = 0 + # exit codes 1 and 2 are used by click + SHUTDOWN_REQUESTED = 3 + BACKEND_ERROR = 4 + CONCURRENT_ERROR = 5 + CONNECTION_ERROR = 6 + CONFIGURATION_ERROR = 7 + MODEL_OPTIMISED_ERROR = 8 + INVALID_TFLITE_FILE_ERROR = 9 + + +class CustomClickException(ClickException): + """Custom click exception.""" + + def show(self, file: Any = None) -> None: + """Override show method.""" + super().show(file) + + logging.debug("Execution failed with following exception: ", exc_info=self) + + +class MiddlewareShutdownException(CustomClickException): + """Exception indicates that user requested middleware shutdown.""" + + exit_code = int(MiddlewareExitCode.SHUTDOWN_REQUESTED) + + +class BackendException(CustomClickException): + """Exception indicates that command failed.""" + + exit_code = int(MiddlewareExitCode.BACKEND_ERROR) + + +class ConcurrentErrorException(CustomClickException): + """Exception indicates concurrent execution error.""" + + exit_code = int(MiddlewareExitCode.CONCURRENT_ERROR) + + +class BackendConnectionException(CustomClickException): + """Exception indicates that connection could not be established.""" + + exit_code = int(MiddlewareExitCode.CONNECTION_ERROR) + + +class BackendConfigurationException(CustomClickException): + """Exception indicates some configuration issue.""" + + exit_code = int(MiddlewareExitCode.CONFIGURATION_ERROR) + + +class ModelOptimisedException(CustomClickException): + """Exception indicates input file has previously been Vela optimised.""" + + exit_code = int(MiddlewareExitCode.MODEL_OPTIMISED_ERROR) + + +class InvalidTFLiteFileError(CustomClickException): + """Exception indicates input TFLite file is misformatted.""" + + exit_code = int(MiddlewareExitCode.INVALID_TFLITE_FILE_ERROR) + + +def print_command_details(command: Dict) -> None: + """Print command details including parameters.""" + command_strings = command["command_strings"] + print("Commands: {}".format(command_strings)) + user_params = command["user_params"] + for i, param in enumerate(user_params, 1): + print("User parameter #{}".format(i)) + print("\tName: {}".format(param.get("name", "-"))) + print("\tDescription: {}".format(param["description"])) + print("\tPossible values: {}".format(param.get("values", "-"))) + print("\tDefault value: {}".format(param.get("default_value", "-"))) + print("\tAlias: {}".format(param.get("alias", "-"))) + + +def raise_exception_at_signal( + signum: int, frame: Any # pylint: disable=unused-argument +) -> None: + """Handle signals.""" + # Disable both SIGINT and SIGTERM signals. Further SIGINT and SIGTERM + # signals will be ignored as we allow a graceful shutdown. + # Unused arguments must be present here in definition as used in signal handler + # callback + + signal_handler(SIGINT, SIG_IGN) + signal_handler(SIGTERM, SIG_IGN) + raise MiddlewareShutdownException("Middleware shutdown requested") + + +def middleware_exception_handler(func: Callable) -> Callable: + """Handle backend exceptions decorator.""" + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + try: + return func(*args, **kwargs) + except (MiddlewareShutdownException, UsageError, ClickException) as error: + # click should take care of these exceptions + raise error + except ValueError as error: + raise ClickException(str(error)) from error + except AnotherInstanceIsRunningException as error: + raise ConcurrentErrorException( + "Another instance of the system is running" + ) from error + except (SSHConnectionException, ConnectionException) as error: + raise BackendConnectionException(str(error)) from error + except ConfigurationException as error: + raise BackendConfigurationException(str(error)) from error + except (CommandFailedException, Exception) as error: + raise BackendException( + "Execution failed. Please check output for the details." + ) from error + + return wrapper + + +def middleware_signal_handler(func: Callable) -> Callable: + """Handle signals decorator.""" + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + # Set up signal handlers for SIGINT (ctrl-c) and SIGTERM (kill command) + # The handler ignores further signals and it raises an exception + signal_handler(SIGINT, raise_exception_at_signal) + signal_handler(SIGTERM, raise_exception_at_signal) + + return func(*args, **kwargs) + + return wrapper + + +def set_format(ctx: Context, format_: str) -> None: + """Save format in click context.""" + ctx_obj = ctx.ensure_object(dict) + ctx_obj["format"] = format_ + + +def get_format(ctx: Context) -> str: + """Get format from click context.""" + ctx_obj = cast(Dict[str, str], ctx.ensure_object(dict)) + return ctx_obj["format"] |