aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/cli
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2022-09-08 14:24:39 +0100
committerDmitrii Agibov <dmitrii.agibov@arm.com>2022-09-09 17:21:48 +0100
commitf5b293d0927506c2a979a091bf0d07ecc78fa181 (patch)
tree4de585b7cb6ed34da8237063752270189a730a41 /src/mlia/cli
parentcde0c6ee140bd108849bff40467d8f18ffc332ef (diff)
downloadmlia-f5b293d0927506c2a979a091bf0d07ecc78fa181.tar.gz
MLIA-386 Simplify typing in the source code
- Enable deferred annotations evaluation - Use builtin types for type hints whenever possible - Use | syntax for union types - Rename mlia.core._typing into mlia.core.typing Change-Id: I3f6ffc02fa069c589bdd9e8bddbccd504285427a
Diffstat (limited to 'src/mlia/cli')
-rw-r--r--src/mlia/cli/commands.py26
-rw-r--r--src/mlia/cli/common.py9
-rw-r--r--src/mlia/cli/config.py7
-rw-r--r--src/mlia/cli/helpers.py28
-rw-r--r--src/mlia/cli/logging.py17
-rw-r--r--src/mlia/cli/main.py18
-rw-r--r--src/mlia/cli/options.py15
7 files changed, 58 insertions, 62 deletions
diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py
index 45c7c32..5dd39f9 100644
--- a/src/mlia/cli/commands.py
+++ b/src/mlia/cli/commands.py
@@ -16,11 +16,11 @@ be configured. Function 'setup_logging' from module
>>> mlia.all_tests(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256",
"path/to/model")
"""
+from __future__ import annotations
+
import logging
from pathlib import Path
from typing import cast
-from typing import List
-from typing import Optional
from mlia.api import ExecutionContext
from mlia.api import get_advice
@@ -42,8 +42,8 @@ def all_tests(
model: str,
optimization_type: str = "pruning,clustering",
optimization_target: str = "0.5,32",
- output: Optional[PathOrFileLike] = None,
- evaluate_on: Optional[List[str]] = None,
+ output: PathOrFileLike | None = None,
+ evaluate_on: list[str] | None = None,
) -> None:
"""Generate a full report on the input model.
@@ -99,8 +99,8 @@ def all_tests(
def operators(
ctx: ExecutionContext,
target_profile: str,
- model: Optional[str] = None,
- output: Optional[PathOrFileLike] = None,
+ model: str | None = None,
+ output: PathOrFileLike | None = None,
supported_ops_report: bool = False,
) -> None:
"""Print the model's operator list.
@@ -149,8 +149,8 @@ def performance(
ctx: ExecutionContext,
target_profile: str,
model: str,
- output: Optional[PathOrFileLike] = None,
- evaluate_on: Optional[List[str]] = None,
+ output: PathOrFileLike | None = None,
+ evaluate_on: list[str] | None = None,
) -> None:
"""Print the model's performance stats.
@@ -192,9 +192,9 @@ def optimization(
model: str,
optimization_type: str,
optimization_target: str,
- layers_to_optimize: Optional[List[str]] = None,
- output: Optional[PathOrFileLike] = None,
- evaluate_on: Optional[List[str]] = None,
+ layers_to_optimize: list[str] | None = None,
+ output: PathOrFileLike | None = None,
+ evaluate_on: list[str] | None = None,
) -> None:
"""Show the performance improvements (if any) after applying the optimizations.
@@ -245,9 +245,9 @@ def optimization(
def backend(
backend_action: str,
- path: Optional[Path] = None,
+ path: Path | None = None,
download: bool = False,
- name: Optional[str] = None,
+ name: str | None = None,
i_agree_to_the_contained_eula: bool = False,
noninteractive: bool = False,
) -> None:
diff --git a/src/mlia/cli/common.py b/src/mlia/cli/common.py
index 54bd457..3f60668 100644
--- a/src/mlia/cli/common.py
+++ b/src/mlia/cli/common.py
@@ -1,10 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""CLI common module."""
+from __future__ import annotations
+
import argparse
from dataclasses import dataclass
from typing import Callable
-from typing import List
@dataclass
@@ -12,8 +13,8 @@ class CommandInfo:
"""Command description."""
func: Callable
- aliases: List[str]
- opt_groups: List[Callable[[argparse.ArgumentParser], None]]
+ aliases: list[str]
+ opt_groups: list[Callable[[argparse.ArgumentParser], None]]
is_default: bool = False
@property
@@ -22,7 +23,7 @@ class CommandInfo:
return self.func.__name__
@property
- def command_name_and_aliases(self) -> List[str]:
+ def command_name_and_aliases(self) -> list[str]:
"""Return list of command name and aliases."""
return [self.command_name, *self.aliases]
diff --git a/src/mlia/cli/config.py b/src/mlia/cli/config.py
index a673230..dc28fa2 100644
--- a/src/mlia/cli/config.py
+++ b/src/mlia/cli/config.py
@@ -1,9 +1,10 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Environment configuration functions."""
+from __future__ import annotations
+
import logging
from functools import lru_cache
-from typing import List
import mlia.backend.manager as backend_manager
from mlia.tools.metadata.common import DefaultInstallationManager
@@ -21,7 +22,7 @@ def get_installation_manager(noninteractive: bool = False) -> InstallationManage
@lru_cache
-def get_available_backends() -> List[str]:
+def get_available_backends() -> list[str]:
"""Return list of the available backends."""
available_backends = ["Vela"]
@@ -42,7 +43,7 @@ def get_available_backends() -> List[str]:
_CORSTONE_EXCLUSIVE_PRIORITY = ("Corstone-310", "Corstone-300")
-def get_default_backends() -> List[str]:
+def get_default_backends() -> list[str]:
"""Get default backends for evaluation."""
backends = get_available_backends()
diff --git a/src/mlia/cli/helpers.py b/src/mlia/cli/helpers.py
index 81d5a15..acec837 100644
--- a/src/mlia/cli/helpers.py
+++ b/src/mlia/cli/helpers.py
@@ -1,11 +1,9 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for various helper classes."""
+from __future__ import annotations
+
from typing import Any
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
from mlia.cli.options import get_target_profile_opts
from mlia.core.helpers import ActionResolver
@@ -17,12 +15,12 @@ from mlia.utils.types import is_list_of
class CLIActionResolver(ActionResolver):
"""Helper class for generating cli commands."""
- def __init__(self, args: Dict[str, Any]) -> None:
+ def __init__(self, args: dict[str, Any]) -> None:
"""Init action resolver."""
self.args = args
@staticmethod
- def _general_optimization_command(model_path: Optional[str]) -> List[str]:
+ def _general_optimization_command(model_path: str | None) -> list[str]:
"""Return general optimization command description."""
keras_note = []
if model_path is None or not is_keras_model(model_path):
@@ -40,8 +38,8 @@ class CLIActionResolver(ActionResolver):
def _specific_optimization_command(
model_path: str,
device_opts: str,
- opt_settings: List[OptimizationSettings],
- ) -> List[str]:
+ opt_settings: list[OptimizationSettings],
+ ) -> list[str]:
"""Return specific optimization command description."""
opt_types = ",".join(opt.optimization_type for opt in opt_settings)
opt_targs = ",".join(str(opt.optimization_target) for opt in opt_settings)
@@ -53,7 +51,7 @@ class CLIActionResolver(ActionResolver):
f"--optimization-target {opt_targs}{device_opts} {model_path}",
]
- def apply_optimizations(self, **kwargs: Any) -> List[str]:
+ def apply_optimizations(self, **kwargs: Any) -> list[str]:
"""Return command details for applying optimizations."""
model_path, device_opts = self._get_model_and_device_opts()
@@ -67,14 +65,14 @@ class CLIActionResolver(ActionResolver):
return []
- def supported_operators_info(self) -> List[str]:
+ def supported_operators_info(self) -> list[str]:
"""Return command details for generating supported ops report."""
return [
"For guidance on supported operators, run: mlia operators "
"--supported-ops-report",
]
- def check_performance(self) -> List[str]:
+ def check_performance(self) -> list[str]:
"""Return command details for checking performance."""
model_path, device_opts = self._get_model_and_device_opts()
if not model_path:
@@ -85,7 +83,7 @@ class CLIActionResolver(ActionResolver):
f"mlia performance{device_opts} {model_path}",
]
- def check_operator_compatibility(self) -> List[str]:
+ def check_operator_compatibility(self) -> list[str]:
"""Return command details for op compatibility."""
model_path, device_opts = self._get_model_and_device_opts()
if not model_path:
@@ -96,17 +94,17 @@ class CLIActionResolver(ActionResolver):
f"mlia operators{device_opts} {model_path}",
]
- def operator_compatibility_details(self) -> List[str]:
+ def operator_compatibility_details(self) -> list[str]:
"""Return command details for op compatibility."""
return ["For more details, run: mlia operators --help"]
- def optimization_details(self) -> List[str]:
+ def optimization_details(self) -> list[str]:
"""Return command details for optimization."""
return ["For more info, see: mlia optimization --help"]
def _get_model_and_device_opts(
self, separate_device_opts: bool = True
- ) -> Tuple[Optional[str], str]:
+ ) -> tuple[str | None, str]:
"""Get model and device options."""
device_opts = " ".join(get_target_profile_opts(self.args))
if separate_device_opts and device_opts:
diff --git a/src/mlia/cli/logging.py b/src/mlia/cli/logging.py
index c5fc7bd..40f47d3 100644
--- a/src/mlia/cli/logging.py
+++ b/src/mlia/cli/logging.py
@@ -1,12 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""CLI logging configuration."""
+from __future__ import annotations
+
import logging
import sys
from pathlib import Path
-from typing import List
-from typing import Optional
-from typing import Union
from mlia.utils.logging import attach_handlers
from mlia.utils.logging import create_log_handler
@@ -18,7 +17,7 @@ _FILE_DEBUG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
def setup_logging(
- logs_dir: Optional[Union[str, Path]] = None,
+ logs_dir: str | Path | None = None,
verbose: bool = False,
log_filename: str = "mlia.log",
) -> None:
@@ -49,10 +48,10 @@ def setup_logging(
def _get_mlia_handlers(
- logs_dir: Optional[Union[str, Path]],
+ logs_dir: str | Path | None,
log_filename: str,
verbose: bool,
-) -> List[logging.Handler]:
+) -> list[logging.Handler]:
"""Get handlers for the MLIA loggers."""
result = []
stdout_handler = create_log_handler(
@@ -84,10 +83,10 @@ def _get_mlia_handlers(
def _get_tools_handlers(
- logs_dir: Optional[Union[str, Path]],
+ logs_dir: str | Path | None,
log_filename: str,
verbose: bool,
-) -> List[logging.Handler]:
+) -> list[logging.Handler]:
"""Get handler for the tools loggers."""
result = []
if verbose:
@@ -110,7 +109,7 @@ def _get_tools_handlers(
return result
-def _get_log_file(logs_dir: Union[str, Path], log_filename: str) -> Path:
+def _get_log_file(logs_dir: str | Path, log_filename: str) -> Path:
"""Get the log file path."""
logs_dir_path = Path(logs_dir)
logs_dir_path.mkdir(exist_ok=True)
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py
index f8fc00c..0ece289 100644
--- a/src/mlia/cli/main.py
+++ b/src/mlia/cli/main.py
@@ -1,16 +1,14 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""CLI main entry point."""
+from __future__ import annotations
+
import argparse
import logging
import sys
from functools import partial
from inspect import signature
from pathlib import Path
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
from mlia import __version__
from mlia.cli.commands import all_tests
@@ -50,7 +48,7 @@ Supported targets:
""".strip()
-def get_commands() -> List[CommandInfo]:
+def get_commands() -> list[CommandInfo]:
"""Return commands configuration."""
return [
CommandInfo(
@@ -111,7 +109,7 @@ def get_commands() -> List[CommandInfo]:
]
-def get_default_command() -> Optional[str]:
+def get_default_command() -> str | None:
"""Get name of the default command."""
commands = get_commands()
@@ -121,7 +119,7 @@ def get_default_command() -> Optional[str]:
return next(iter(marked_as_default), None)
-def get_possible_command_names() -> List[str]:
+def get_possible_command_names() -> list[str]:
"""Get all possible command names including aliases."""
return [
name_or_alias
@@ -151,7 +149,7 @@ def init_commands(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
def setup_context(
args: argparse.Namespace, context_var_name: str = "ctx"
-) -> Tuple[ExecutionContext, Dict]:
+) -> tuple[ExecutionContext, dict]:
"""Set up context and resolve function parameters."""
ctx = ExecutionContext(
working_dir=args.working_dir,
@@ -252,7 +250,7 @@ def init_subcommand_parser(parent: argparse.ArgumentParser) -> argparse.Argument
return parser
-def add_default_command_if_needed(args: List[str]) -> None:
+def add_default_command_if_needed(args: list[str]) -> None:
"""Add default command to the list of the arguments if needed."""
default_command = get_default_command()
@@ -265,7 +263,7 @@ def add_default_command_if_needed(args: List[str]) -> None:
args.insert(0, default_command)
-def main(argv: Optional[List[str]] = None) -> int:
+def main(argv: list[str] | None = None) -> int:
"""Entry point of the application."""
common_parser = init_common_parser()
subcommand_parser = init_subcommand_parser(common_parser)
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py
index 29a0d89..3f0dc1f 100644
--- a/src/mlia/cli/options.py
+++ b/src/mlia/cli/options.py
@@ -1,13 +1,12 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for the CLI options."""
+from __future__ import annotations
+
import argparse
from pathlib import Path
from typing import Any
from typing import Callable
-from typing import Dict
-from typing import List
-from typing import Optional
from mlia.cli.config import get_available_backends
from mlia.cli.config import get_default_backends
@@ -17,7 +16,7 @@ from mlia.utils.types import is_number
def add_target_options(
- parser: argparse.ArgumentParser, profiles_to_skip: Optional[List[str]] = None
+ parser: argparse.ArgumentParser, profiles_to_skip: list[str] | None = None
) -> None:
"""Add target specific options."""
target_profiles = get_supported_profile_names()
@@ -217,8 +216,8 @@ def parse_optimization_parameters(
optimization_type: str,
optimization_target: str,
sep: str = ",",
- layers_to_optimize: Optional[List[str]] = None,
-) -> List[Dict[str, Any]]:
+ layers_to_optimize: list[str] | None = None,
+) -> list[dict[str, Any]]:
"""Parse provided optimization parameters."""
if not optimization_type:
raise Exception("Optimization type is not provided")
@@ -250,7 +249,7 @@ def parse_optimization_parameters(
return optimizer_params
-def get_target_profile_opts(device_args: Optional[Dict]) -> List[str]:
+def get_target_profile_opts(device_args: dict | None) -> list[str]:
"""Get non default values passed as parameters for the target profile."""
if not device_args:
return []
@@ -270,7 +269,7 @@ def get_target_profile_opts(device_args: Optional[Dict]) -> List[str]:
if arg_name in args and vars(args)[arg_name] != arg_value
]
- def construct_param(name: str, value: Any) -> List[str]:
+ def construct_param(name: str, value: Any) -> list[str]:
"""Construct parameter."""
if isinstance(value, list):
return [str(item) for v in value for item in [name, v]]