aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/cli/helpers.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/cli/helpers.py')
-rw-r--r--src/mlia/cli/helpers.py116
1 files changed, 116 insertions, 0 deletions
diff --git a/src/mlia/cli/helpers.py b/src/mlia/cli/helpers.py
new file mode 100644
index 0000000..81d5a15
--- /dev/null
+++ b/src/mlia/cli/helpers.py
@@ -0,0 +1,116 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for various helper classes."""
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+from mlia.cli.options import get_target_profile_opts
+from mlia.core.helpers import ActionResolver
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.nn.tensorflow.utils import is_keras_model
+from mlia.utils.types import is_list_of
+
+
+class CLIActionResolver(ActionResolver):
+ """Helper class for generating cli commands."""
+
+ def __init__(self, args: Dict[str, Any]) -> None:
+ """Init action resolver."""
+ self.args = args
+
+ @staticmethod
+ def _general_optimization_command(model_path: Optional[str]) -> List[str]:
+ """Return general optimization command description."""
+ keras_note = []
+ if model_path is None or not is_keras_model(model_path):
+ model_path = "/path/to/keras_model"
+ keras_note = ["Note: you will need a Keras model for that."]
+
+ return [
+ *keras_note,
+ "For example: mlia optimization --optimization-type "
+ f"pruning,clustering --optimization-target 0.5,32 {model_path}",
+ "For more info: mlia optimization --help",
+ ]
+
+ @staticmethod
+ def _specific_optimization_command(
+ model_path: str,
+ device_opts: str,
+ opt_settings: List[OptimizationSettings],
+ ) -> List[str]:
+ """Return specific optimization command description."""
+ opt_types = ",".join(opt.optimization_type for opt in opt_settings)
+ opt_targs = ",".join(str(opt.optimization_target) for opt in opt_settings)
+
+ return [
+ "For more info: mlia optimization --help",
+ "Optimization command: "
+ f"mlia optimization --optimization-type {opt_types} "
+ f"--optimization-target {opt_targs}{device_opts} {model_path}",
+ ]
+
+ def apply_optimizations(self, **kwargs: Any) -> List[str]:
+ """Return command details for applying optimizations."""
+ model_path, device_opts = self._get_model_and_device_opts()
+
+ if (opt_settings := kwargs.pop("opt_settings", None)) is None:
+ return self._general_optimization_command(model_path)
+
+ if is_list_of(opt_settings, OptimizationSettings) and model_path:
+ return self._specific_optimization_command(
+ model_path, device_opts, opt_settings
+ )
+
+ return []
+
+ def supported_operators_info(self) -> List[str]:
+ """Return command details for generating supported ops report."""
+ return [
+ "For guidance on supported operators, run: mlia operators "
+ "--supported-ops-report",
+ ]
+
+ def check_performance(self) -> List[str]:
+ """Return command details for checking performance."""
+ model_path, device_opts = self._get_model_and_device_opts()
+ if not model_path:
+ return []
+
+ return [
+ "Check the estimated performance by running the following command: ",
+ f"mlia performance{device_opts} {model_path}",
+ ]
+
+ def check_operator_compatibility(self) -> List[str]:
+ """Return command details for op compatibility."""
+ model_path, device_opts = self._get_model_and_device_opts()
+ if not model_path:
+ return []
+
+ return [
+ "Try running the following command to verify that:",
+ f"mlia operators{device_opts} {model_path}",
+ ]
+
+ def operator_compatibility_details(self) -> List[str]:
+ """Return command details for op compatibility."""
+ return ["For more details, run: mlia operators --help"]
+
+ def optimization_details(self) -> List[str]:
+ """Return command details for optimization."""
+ return ["For more info, see: mlia optimization --help"]
+
+ def _get_model_and_device_opts(
+ self, separate_device_opts: bool = True
+ ) -> Tuple[Optional[str], str]:
+ """Get model and device options."""
+ device_opts = " ".join(get_target_profile_opts(self.args))
+ if separate_device_opts and device_opts:
+ device_opts = f" {device_opts}"
+
+ model_path = self.args.get("model")
+ return model_path, device_opts