aboutsummaryrefslogtreecommitdiff
path: root/src/aiet/cli/common.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/aiet/cli/common.py')
-rw-r--r--src/aiet/cli/common.py173
1 files changed, 173 insertions, 0 deletions
diff --git a/src/aiet/cli/common.py b/src/aiet/cli/common.py
new file mode 100644
index 0000000..1d157b6
--- /dev/null
+++ b/src/aiet/cli/common.py
@@ -0,0 +1,173 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Common functions for cli module."""
+import enum
+import logging
+from functools import wraps
+from signal import SIG_IGN
+from signal import SIGINT
+from signal import signal as signal_handler
+from signal import SIGTERM
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+
+from click import ClickException
+from click import Context
+from click import UsageError
+
+from aiet.backend.common import ConfigurationException
+from aiet.backend.execution import AnotherInstanceIsRunningException
+from aiet.backend.execution import ConnectionException
+from aiet.backend.protocol import SSHConnectionException
+from aiet.utils.proc import CommandFailedException
+
+
+class MiddlewareExitCode(enum.IntEnum):
+ """Middleware exit codes."""
+
+ SUCCESS = 0
+ # exit codes 1 and 2 are used by click
+ SHUTDOWN_REQUESTED = 3
+ BACKEND_ERROR = 4
+ CONCURRENT_ERROR = 5
+ CONNECTION_ERROR = 6
+ CONFIGURATION_ERROR = 7
+ MODEL_OPTIMISED_ERROR = 8
+ INVALID_TFLITE_FILE_ERROR = 9
+
+
+class CustomClickException(ClickException):
+ """Custom click exception."""
+
+ def show(self, file: Any = None) -> None:
+ """Override show method."""
+ super().show(file)
+
+ logging.debug("Execution failed with following exception: ", exc_info=self)
+
+
+class MiddlewareShutdownException(CustomClickException):
+ """Exception indicates that user requested middleware shutdown."""
+
+ exit_code = int(MiddlewareExitCode.SHUTDOWN_REQUESTED)
+
+
+class BackendException(CustomClickException):
+ """Exception indicates that command failed."""
+
+ exit_code = int(MiddlewareExitCode.BACKEND_ERROR)
+
+
+class ConcurrentErrorException(CustomClickException):
+ """Exception indicates concurrent execution error."""
+
+ exit_code = int(MiddlewareExitCode.CONCURRENT_ERROR)
+
+
+class BackendConnectionException(CustomClickException):
+ """Exception indicates that connection could not be established."""
+
+ exit_code = int(MiddlewareExitCode.CONNECTION_ERROR)
+
+
+class BackendConfigurationException(CustomClickException):
+ """Exception indicates some configuration issue."""
+
+ exit_code = int(MiddlewareExitCode.CONFIGURATION_ERROR)
+
+
+class ModelOptimisedException(CustomClickException):
+ """Exception indicates input file has previously been Vela optimised."""
+
+ exit_code = int(MiddlewareExitCode.MODEL_OPTIMISED_ERROR)
+
+
+class InvalidTFLiteFileError(CustomClickException):
+ """Exception indicates input TFLite file is misformatted."""
+
+ exit_code = int(MiddlewareExitCode.INVALID_TFLITE_FILE_ERROR)
+
+
+def print_command_details(command: Dict) -> None:
+ """Print command details including parameters."""
+ command_strings = command["command_strings"]
+ print("Commands: {}".format(command_strings))
+ user_params = command["user_params"]
+ for i, param in enumerate(user_params, 1):
+ print("User parameter #{}".format(i))
+ print("\tName: {}".format(param.get("name", "-")))
+ print("\tDescription: {}".format(param["description"]))
+ print("\tPossible values: {}".format(param.get("values", "-")))
+ print("\tDefault value: {}".format(param.get("default_value", "-")))
+ print("\tAlias: {}".format(param.get("alias", "-")))
+
+
+def raise_exception_at_signal(
+ signum: int, frame: Any # pylint: disable=unused-argument
+) -> None:
+ """Handle signals."""
+ # Disable both SIGINT and SIGTERM signals. Further SIGINT and SIGTERM
+ # signals will be ignored as we allow a graceful shutdown.
+ # Unused arguments must be present here in definition as used in signal handler
+ # callback
+
+ signal_handler(SIGINT, SIG_IGN)
+ signal_handler(SIGTERM, SIG_IGN)
+ raise MiddlewareShutdownException("Middleware shutdown requested")
+
+
+def middleware_exception_handler(func: Callable) -> Callable:
+ """Handle backend exceptions decorator."""
+
+ @wraps(func)
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
+ try:
+ return func(*args, **kwargs)
+ except (MiddlewareShutdownException, UsageError, ClickException) as error:
+ # click should take care of these exceptions
+ raise error
+ except ValueError as error:
+ raise ClickException(str(error)) from error
+ except AnotherInstanceIsRunningException as error:
+ raise ConcurrentErrorException(
+ "Another instance of the system is running"
+ ) from error
+ except (SSHConnectionException, ConnectionException) as error:
+ raise BackendConnectionException(str(error)) from error
+ except ConfigurationException as error:
+ raise BackendConfigurationException(str(error)) from error
+ except (CommandFailedException, Exception) as error:
+ raise BackendException(
+ "Execution failed. Please check output for the details."
+ ) from error
+
+ return wrapper
+
+
+def middleware_signal_handler(func: Callable) -> Callable:
+ """Handle signals decorator."""
+
+ @wraps(func)
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
+ # Set up signal handlers for SIGINT (ctrl-c) and SIGTERM (kill command)
+ # The handler ignores further signals and it raises an exception
+ signal_handler(SIGINT, raise_exception_at_signal)
+ signal_handler(SIGTERM, raise_exception_at_signal)
+
+ return func(*args, **kwargs)
+
+ return wrapper
+
+
+def set_format(ctx: Context, format_: str) -> None:
+ """Save format in click context."""
+ ctx_obj = ctx.ensure_object(dict)
+ ctx_obj["format"] = format_
+
+
+def get_format(ctx: Context) -> str:
+ """Get format from click context."""
+ ctx_obj = cast(Dict[str, str], ctx.ensure_object(dict))
+ return ctx_obj["format"]