author    Dmitrii Agibov <dmitrii.agibov@arm.com>  2022-09-08 14:24:39 +0100
committer Dmitrii Agibov <dmitrii.agibov@arm.com>  2022-09-09 17:21:48 +0100
commit    f5b293d0927506c2a979a091bf0d07ecc78fa181 (patch)
tree      4de585b7cb6ed34da8237063752270189a730a41
parent    cde0c6ee140bd108849bff40467d8f18ffc332ef (diff)
download  mlia-f5b293d0927506c2a979a091bf0d07ecc78fa181.tar.gz
MLIA-386 Simplify typing in the source code
- Enable deferred annotations evaluation
- Use builtin types for type hints whenever possible
- Use | syntax for union types
- Rename mlia.core._typing into mlia.core.typing

Change-Id: I3f6ffc02fa069c589bdd9e8bddbccd504285427a
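To make the intent of these points concrete, the sketch below (module, function, and parameter names are hypothetical, not taken from the MLIA sources) shows the annotation style the commit adopts: once annotation evaluation is deferred, built-in generics and the | union syntax can be used in annotations even on Python versions that predate native PEP 585/604 support.

    from __future__ import annotations  # defer evaluation of annotations (PEP 563)

    from pathlib import Path


    def load_settings(
        config_file: str | Path,                        # previously: Union[str, Path]
        overrides: dict[str, list[str]] | None = None,  # previously: Optional[Dict[str, List[str]]]
    ) -> dict[str, str]:                                # previously: Dict[str, str]
        """Merge overrides into a settings dictionary (placeholder logic)."""
        merged: dict[str, str] = {"source": str(config_file)}
        for key, values in (overrides or {}).items():
            merged[key] = ",".join(values)
        return merged

Note that deferral only covers annotations: expressions evaluated at runtime, such as the Dict/List/Union aliases kept in src/mlia/backend/config.py for TypedDict fields and type aliases, still rely on the typing module in the diff below.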
-rw-r--r--  src/mlia/api.py | 24
-rw-r--r--  src/mlia/backend/application.py | 18
-rw-r--r--  src/mlia/backend/common.py | 64
-rw-r--r--  src/mlia/backend/config.py | 19
-rw-r--r--  src/mlia/backend/execution.py | 41
-rw-r--r--  src/mlia/backend/fs.py | 7
-rw-r--r--  src/mlia/backend/manager.py | 41
-rw-r--r--  src/mlia/backend/output_consumer.py | 5
-rw-r--r--  src/mlia/backend/proc.py | 17
-rw-r--r--  src/mlia/backend/source.py | 22
-rw-r--r--  src/mlia/backend/system.py | 12
-rw-r--r--  src/mlia/cli/commands.py | 26
-rw-r--r--  src/mlia/cli/common.py | 9
-rw-r--r--  src/mlia/cli/config.py | 7
-rw-r--r--  src/mlia/cli/helpers.py | 28
-rw-r--r--  src/mlia/cli/logging.py | 17
-rw-r--r--  src/mlia/cli/main.py | 18
-rw-r--r--  src/mlia/cli/options.py | 15
-rw-r--r--  src/mlia/core/advice_generation.py | 14
-rw-r--r--  src/mlia/core/advisor.py | 11
-rw-r--r--  src/mlia/core/common.py | 4
-rw-r--r--  src/mlia/core/context.py | 31
-rw-r--r--  src/mlia/core/data_analysis.py | 9
-rw-r--r--  src/mlia/core/events.py | 24
-rw-r--r--  src/mlia/core/handlers.py | 10
-rw-r--r--  src/mlia/core/helpers.py | 15
-rw-r--r--  src/mlia/core/mixins.py | 7
-rw-r--r--  src/mlia/core/performance.py | 19
-rw-r--r--  src/mlia/core/reporting.py | 99
-rw-r--r--  src/mlia/core/typing.py (renamed from src/mlia/core/_typing.py) | 0
-rw-r--r--  src/mlia/core/workflow.py | 15
-rw-r--r--  src/mlia/devices/ethosu/advice_generation.py | 10
-rw-r--r--  src/mlia/devices/ethosu/advisor.py | 32
-rw-r--r--  src/mlia/devices/ethosu/config.py | 7
-rw-r--r--  src/mlia/devices/ethosu/data_analysis.py | 21
-rw-r--r--  src/mlia/devices/ethosu/data_collection.py | 18
-rw-r--r--  src/mlia/devices/ethosu/handlers.py | 7
-rw-r--r--  src/mlia/devices/ethosu/performance.py | 35
-rw-r--r--  src/mlia/devices/ethosu/reporters.py | 19
-rw-r--r--  src/mlia/devices/tosa/advisor.py | 30
-rw-r--r--  src/mlia/devices/tosa/handlers.py | 7
-rw-r--r--  src/mlia/devices/tosa/operators.py | 12
-rw-r--r--  src/mlia/devices/tosa/reporters.py | 7
-rw-r--r--  src/mlia/nn/tensorflow/config.py | 36
-rw-r--r--  src/mlia/nn/tensorflow/optimizations/clustering.py | 9
-rw-r--r--  src/mlia/nn/tensorflow/optimizations/pruning.py | 13
-rw-r--r--  src/mlia/nn/tensorflow/optimizations/select.py | 34
-rw-r--r--  src/mlia/nn/tensorflow/tflite_metrics.py | 20
-rw-r--r--  src/mlia/nn/tensorflow/utils.py | 15
-rw-r--r--  src/mlia/tools/metadata/common.py | 40
-rw-r--r--  src/mlia/tools/metadata/corstone.py | 27
-rw-r--r--  src/mlia/tools/vela_wrapper.py | 33
-rw-r--r--  src/mlia/utils/console.py | 8
-rw-r--r--  src/mlia/utils/download.py | 10
-rw-r--r--  src/mlia/utils/filesystem.py | 20
-rw-r--r--  src/mlia/utils/logging.py | 22
-rw-r--r--  src/mlia/utils/types.py | 7
-rw-r--r--  tests/test_backend_application.py | 5
-rw-r--r--  tests/test_backend_common.py | 16
-rw-r--r--  tests/test_backend_fs.py | 5
-rw-r--r--  tests/test_backend_manager.py | 45
-rw-r--r--  tests/test_backend_output_consumer.py | 9
-rw-r--r--  tests/test_backend_system.py | 12
-rw-r--r--  tests/test_cli_commands.py | 7
-rw-r--r--  tests/test_cli_config.py | 7
-rw-r--r--  tests/test_cli_helpers.py | 14
-rw-r--r--  tests/test_cli_logging.py | 5
-rw-r--r--  tests/test_cli_main.py | 7
-rw-r--r--  tests/test_cli_options.py | 9
-rw-r--r--  tests/test_core_advice_generation.py | 4
-rw-r--r--  tests/test_core_reporting.py | 11
-rw-r--r--  tests/test_devices_ethosu_advice_generation.py | 6
-rw-r--r--  tests/test_devices_ethosu_config.py | 5
-rw-r--r--  tests/test_devices_ethosu_data_analysis.py | 4
-rw-r--r--  tests/test_devices_ethosu_reporters.py | 16
-rw-r--r--  tests/test_devices_tosa_advice_generation.py | 4
-rw-r--r--  tests/test_devices_tosa_data_analysis.py | 4
-rw-r--r--  tests/test_devices_tosa_operators.py | 5
-rw-r--r--  tests/test_nn_tensorflow_optimizations_clustering.py | 12
-rw-r--r--  tests/test_nn_tensorflow_optimizations_pruning.py | 8
-rw-r--r--  tests/test_nn_tensorflow_optimizations_select.py | 6
-rw-r--r--  tests/test_nn_tensorflow_tflite_metrics.py | 5
-rw-r--r--  tests/test_tools_metadata_common.py | 12
-rw-r--r--  tests/test_tools_metadata_corstone.py | 20
-rw-r--r--  tests/test_utils_console.py | 6
-rw-r--r--  tests/test_utils_download.py | 11
-rw-r--r--  tests/test_utils_logging.py | 9
-rw-r--r--  tests/test_utils_types.py | 9
-rw-r--r--  tests/utils/common.py | 4
89 files changed, 709 insertions, 740 deletions
diff --git a/src/mlia/api.py b/src/mlia/api.py
index c720b8d..878e316 100644
--- a/src/mlia/api.py
+++ b/src/mlia/api.py
@@ -1,19 +1,17 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for the API functions."""
+from __future__ import annotations
+
import logging
from pathlib import Path
from typing import Any
-from typing import Dict
-from typing import List
from typing import Literal
-from typing import Optional
-from typing import Union
-from mlia.core._typing import PathOrFileLike
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext
+from mlia.core.typing import PathOrFileLike
from mlia.devices.ethosu.advisor import configure_and_get_ethosu_advisor
from mlia.devices.tosa.advisor import configure_and_get_tosa_advisor
from mlia.utils.filesystem import get_target
@@ -24,13 +22,13 @@ logger = logging.getLogger(__name__)
def get_advice(
target_profile: str,
- model: Union[Path, str],
+ model: str | Path,
category: Literal["all", "operators", "performance", "optimization"] = "all",
- optimization_targets: Optional[List[Dict[str, Any]]] = None,
- working_dir: Union[str, Path] = "mlia_output",
- output: Optional[PathOrFileLike] = None,
- context: Optional[ExecutionContext] = None,
- backends: Optional[List[str]] = None,
+ optimization_targets: list[dict[str, Any]] | None = None,
+ working_dir: str | Path = "mlia_output",
+ output: PathOrFileLike | None = None,
+ context: ExecutionContext | None = None,
+ backends: list[str] | None = None,
) -> None:
"""Get the advice.
@@ -97,8 +95,8 @@ def get_advice(
def get_advisor(
context: ExecutionContext,
target_profile: str,
- model: Union[Path, str],
- output: Optional[PathOrFileLike] = None,
+ model: str | Path,
+ output: PathOrFileLike | None = None,
**extra_args: Any,
) -> InferenceAdvisor:
"""Find appropriate advisor for the target."""
diff --git a/src/mlia/backend/application.py b/src/mlia/backend/application.py
index 4b04324..a093afe 100644
--- a/src/mlia/backend/application.py
+++ b/src/mlia/backend/application.py
@@ -1,13 +1,13 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Application backend module."""
+from __future__ import annotations
+
import re
from pathlib import Path
from typing import Any
from typing import cast
-from typing import Dict
from typing import List
-from typing import Optional
from mlia.backend.common import Backend
from mlia.backend.common import ConfigurationException
@@ -23,12 +23,12 @@ from mlia.backend.source import create_destination_and_install
from mlia.backend.source import get_source
-def get_available_application_directory_names() -> List[str]:
+def get_available_application_directory_names() -> list[str]:
"""Return a list of directory names for all available applications."""
return [entry.name for entry in get_backend_directories("applications")]
-def get_available_applications() -> List["Application"]:
+def get_available_applications() -> list[Application]:
"""Return a list with all available applications."""
available_applications = []
for config_json in get_backend_configs("applications"):
@@ -42,8 +42,8 @@ def get_available_applications() -> List["Application"]:
def get_application(
- application_name: str, system_name: Optional[str] = None
-) -> List["Application"]:
+ application_name: str, system_name: str | None = None
+) -> list[Application]:
"""Return a list of application instances with provided name."""
return [
application
@@ -85,7 +85,7 @@ def remove_application(directory_name: str) -> None:
remove_backend(directory_name, "applications")
-def get_unique_application_names(system_name: Optional[str] = None) -> List[str]:
+def get_unique_application_names(system_name: str | None = None) -> list[str]:
"""Extract a list of unique application names of all application available."""
return list(
set(
@@ -120,7 +120,7 @@ class Application(Backend):
"""Check if the application can run on the system passed as argument."""
return system_name in self.supported_systems
- def get_details(self) -> Dict[str, Any]:
+ def get_details(self) -> dict[str, Any]:
"""Return dictionary with information about the Application instance."""
output = {
"type": "application",
@@ -156,7 +156,7 @@ class Application(Backend):
command.params = used_params
-def load_applications(config: ExtendedApplicationConfig) -> List[Application]:
+def load_applications(config: ExtendedApplicationConfig) -> list[Application]:
"""Load application.
Application configuration could contain different parameters/commands for different
diff --git a/src/mlia/backend/common.py b/src/mlia/backend/common.py
index e61d6b6..697c2a0 100644
--- a/src/mlia/backend/common.py
+++ b/src/mlia/backend/common.py
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Contain all common functions for the backends."""
+from __future__ import annotations
+
import json
import logging
import re
@@ -10,18 +12,12 @@ from pathlib import Path
from typing import Any
from typing import Callable
from typing import cast
-from typing import Dict
from typing import Final
from typing import IO
from typing import Iterable
-from typing import List
from typing import Match
from typing import NamedTuple
-from typing import Optional
from typing import Pattern
-from typing import Tuple
-from typing import Type
-from typing import Union
from mlia.backend.config import BackendConfig
from mlia.backend.config import BaseBackendConfig
@@ -74,7 +70,7 @@ def remove_backend(directory_name: str, resource_type: ResourceType) -> None:
remove_resource(directory_name, resource_type)
-def load_config(config: Union[None, Path, IO[bytes]]) -> BackendConfig:
+def load_config(config: Path | IO[bytes] | None) -> BackendConfig:
"""Return a loaded json file."""
if config is None:
raise Exception("Unable to read config")
@@ -86,7 +82,7 @@ def load_config(config: Union[None, Path, IO[bytes]]) -> BackendConfig:
return cast(BackendConfig, json.load(config))
-def parse_raw_parameter(parameter: str) -> Tuple[str, Optional[str]]:
+def parse_raw_parameter(parameter: str) -> tuple[str, str | None]:
"""Split the parameter string in name and optional value.
It manages the following cases:
@@ -176,7 +172,7 @@ class Backend(ABC):
def _parse_commands_and_params(self, config: BaseBackendConfig) -> None:
"""Parse commands and user parameters."""
- self.commands: Dict[str, Command] = {}
+ self.commands: dict[str, Command] = {}
commands = config.get("commands")
if commands:
@@ -213,15 +209,15 @@ class Backend(ABC):
@classmethod
def _parse_params(
- cls, params: Optional[UserParamsConfig], command: str
- ) -> List["Param"]:
+ cls, params: UserParamsConfig | None, command: str
+ ) -> list[Param]:
if not params:
return []
return [cls._parse_param(p) for p in params.get(command, [])]
@classmethod
- def _parse_param(cls, param: UserParamConfig) -> "Param":
+ def _parse_param(cls, param: UserParamConfig) -> Param:
"""Parse a single parameter."""
name = param.get("name")
if name is not None and not name:
@@ -239,16 +235,14 @@ class Backend(ABC):
alias=alias,
)
- def _get_command_details(self) -> Dict:
+ def _get_command_details(self) -> dict:
command_details = {
command_name: command.get_details()
for command_name, command in self.commands.items()
}
return command_details
- def _get_user_param_value(
- self, user_params: List[str], param: "Param"
- ) -> Optional[str]:
+ def _get_user_param_value(self, user_params: list[str], param: Param) -> str | None:
"""Get the user-specified value of a parameter."""
for user_param in user_params:
user_param_name, user_param_value = parse_raw_parameter(user_param)
@@ -267,7 +261,7 @@ class Backend(ABC):
return None
@staticmethod
- def _same_parameter(user_param_name_or_alias: str, param: "Param") -> bool:
+ def _same_parameter(user_param_name_or_alias: str, param: Param) -> bool:
"""Compare user parameter name with param name or alias."""
# Strip the "=" sign in the param_name. This is needed just for
# comparison with the parameters passed by the user.
@@ -277,10 +271,10 @@ class Backend(ABC):
return user_param_name_or_alias in [param_name, param.alias]
def resolved_parameters(
- self, command_name: str, user_params: List[str]
- ) -> List[Tuple[Optional[str], "Param"]]:
+ self, command_name: str, user_params: list[str]
+ ) -> list[tuple[str | None, Param]]:
"""Return list of parameters with values."""
- result: List[Tuple[Optional[str], "Param"]] = []
+ result: list[tuple[str | None, Param]] = []
command = self.commands.get(command_name)
if not command:
return result
@@ -296,9 +290,9 @@ class Backend(ABC):
def build_command(
self,
command_name: str,
- user_params: List[str],
- param_resolver: Callable[[str, str, List[Tuple[Optional[str], "Param"]]], str],
- ) -> List[str]:
+ user_params: list[str],
+ param_resolver: Callable[[str, str, list[tuple[str | None, Param]]], str],
+ ) -> list[str]:
"""
Return a list of executable command strings.
@@ -328,11 +322,11 @@ class Param:
def __init__( # pylint: disable=too-many-arguments
self,
- name: Optional[str],
+ name: str | None,
description: str,
- values: Optional[List[str]] = None,
- default_value: Optional[str] = None,
- alias: Optional[str] = None,
+ values: list[str] | None = None,
+ default_value: str | None = None,
+ alias: str | None = None,
) -> None:
"""Construct a Param instance."""
if not name and not alias:
@@ -345,7 +339,7 @@ class Param:
self.default_value = default_value
self.alias = alias
- def get_details(self) -> Dict:
+ def get_details(self) -> dict:
"""Return a dictionary with all relevant information of a Param."""
return {key: value for key, value in self.__dict__.items() if value}
@@ -366,7 +360,7 @@ class Command:
"""Class for representing a command."""
def __init__(
- self, command_strings: List[str], params: Optional[List[Param]] = None
+ self, command_strings: list[str], params: list[Param] | None = None
) -> None:
"""Construct a Command instance."""
self.command_strings = command_strings
@@ -404,7 +398,7 @@ class Command:
"as parameter name."
)
- def get_details(self) -> Dict:
+ def get_details(self) -> dict:
"""Return a dictionary with all relevant information of a Command."""
output = {
"command_strings": self.command_strings,
@@ -425,9 +419,9 @@ class Command:
def resolve_all_parameters(
str_val: str,
- param_resolver: Callable[[str, str, List[Tuple[Optional[str], Param]]], str],
- command_name: Optional[str] = None,
- params_values: Optional[List[Tuple[Optional[str], Param]]] = None,
+ param_resolver: Callable[[str, str, list[tuple[str | None, Param]]], str],
+ command_name: str | None = None,
+ params_values: list[tuple[str | None, Param]] | None = None,
) -> str:
"""Resolve all parameters in the string."""
if not str_val:
@@ -446,7 +440,7 @@ def resolve_all_parameters(
def load_application_configs(
config: Any,
- config_type: Type[Any],
+ config_type: type[Any],
is_system_required: bool = True,
) -> Any:
"""Get one config for each system supported by the application.
@@ -456,7 +450,7 @@ def load_application_configs(
config with appropriate configuration.
"""
merged_configs = []
- supported_systems: Optional[List[NamedExecutionConfig]] = config.get(
+ supported_systems: list[NamedExecutionConfig] | None = config.get(
"supported_systems"
)
if not supported_systems:
diff --git a/src/mlia/backend/config.py b/src/mlia/backend/config.py
index 9a56fa9..dca53da 100644
--- a/src/mlia/backend/config.py
+++ b/src/mlia/backend/config.py
@@ -1,10 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Contain definition of backend configuration."""
+from __future__ import annotations
+
from pathlib import Path
from typing import Dict
from typing import List
-from typing import Optional
from typing import TypedDict
from typing import Union
@@ -12,9 +13,9 @@ from typing import Union
class UserParamConfig(TypedDict, total=False):
"""User parameter configuration."""
- name: Optional[str]
+ name: str | None
default_value: str
- values: List[str]
+ values: list[str]
description: str
alias: str
@@ -25,9 +26,9 @@ UserParamsConfig = Dict[str, List[UserParamConfig]]
class ExecutionConfig(TypedDict, total=False):
"""Execution configuration."""
- commands: Dict[str, List[str]]
+ commands: dict[str, list[str]]
user_params: UserParamsConfig
- variables: Dict[str, str]
+ variables: dict[str, str]
class NamedExecutionConfig(ExecutionConfig):
@@ -42,25 +43,25 @@ class BaseBackendConfig(ExecutionConfig, total=False):
name: str
description: str
config_location: Path
- annotations: Dict[str, Union[str, List[str]]]
+ annotations: dict[str, str | list[str]]
class ApplicationConfig(BaseBackendConfig, total=False):
"""Application configuration."""
- supported_systems: List[str]
+ supported_systems: list[str]
class ExtendedApplicationConfig(BaseBackendConfig, total=False):
"""Extended application configuration."""
- supported_systems: List[NamedExecutionConfig]
+ supported_systems: list[NamedExecutionConfig]
class SystemConfig(BaseBackendConfig, total=False):
"""System configuration."""
- reporting: Dict[str, Dict]
+ reporting: dict[str, dict]
BackendItemConfig = Union[ApplicationConfig, SystemConfig]
diff --git a/src/mlia/backend/execution.py b/src/mlia/backend/execution.py
index 5340a47..f3fe401 100644
--- a/src/mlia/backend/execution.py
+++ b/src/mlia/backend/execution.py
@@ -1,12 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Application execution module."""
+from __future__ import annotations
+
import logging
import re
from typing import cast
-from typing import List
-from typing import Optional
-from typing import Tuple
from mlia.backend.application import Application
from mlia.backend.application import get_application
@@ -29,9 +28,9 @@ class ExecutionContext: # pylint: disable=too-few-public-methods
def __init__(
self,
app: Application,
- app_params: List[str],
+ app_params: list[str],
system: System,
- system_params: List[str],
+ system_params: list[str],
):
"""Init execution context."""
self.app = app
@@ -41,8 +40,8 @@ class ExecutionContext: # pylint: disable=too-few-public-methods
self.param_resolver = ParamResolver(self)
- self.stdout: Optional[bytearray] = None
- self.stderr: Optional[bytearray] = None
+ self.stdout: bytearray | None = None
+ self.stderr: bytearray | None = None
class ParamResolver:
@@ -54,16 +53,16 @@ class ParamResolver:
@staticmethod
def resolve_user_params(
- cmd_name: Optional[str],
+ cmd_name: str | None,
index_or_alias: str,
- resolved_params: Optional[List[Tuple[Optional[str], Param]]],
+ resolved_params: list[tuple[str | None, Param]] | None,
) -> str:
"""Resolve user params."""
if not cmd_name or resolved_params is None:
raise ConfigurationException("Unable to resolve user params")
- param_value: Optional[str] = None
- param: Optional[Param] = None
+ param_value: str | None = None
+ param: Param | None = None
if index_or_alias.isnumeric():
i = int(index_or_alias)
@@ -176,8 +175,8 @@ class ParamResolver:
def param_matcher(
self,
param_name: str,
- cmd_name: Optional[str],
- resolved_params: Optional[List[Tuple[Optional[str], Param]]],
+ cmd_name: str | None,
+ resolved_params: list[tuple[str | None, Param]] | None,
) -> str:
"""Regexp to resolve a param from the param_name."""
# this pattern supports parameter names like "application.commands.run:0" and
@@ -224,8 +223,8 @@ class ParamResolver:
def param_resolver(
self,
param_name: str,
- cmd_name: Optional[str] = None,
- resolved_params: Optional[List[Tuple[Optional[str], Param]]] = None,
+ cmd_name: str | None = None,
+ resolved_params: list[tuple[str | None, Param]] | None = None,
) -> str:
"""Resolve parameter value based on current execution context."""
# Note: 'software.*' is included for backward compatibility.
@@ -253,15 +252,15 @@ class ParamResolver:
def __call__(
self,
param_name: str,
- cmd_name: Optional[str] = None,
- resolved_params: Optional[List[Tuple[Optional[str], Param]]] = None,
+ cmd_name: str | None = None,
+ resolved_params: list[tuple[str | None, Param]] | None = None,
) -> str:
"""Resolve provided parameter."""
return self.param_resolver(param_name, cmd_name, resolved_params)
def validate_parameters(
- backend: Backend, command_names: List[str], params: List[str]
+ backend: Backend, command_names: list[str], params: list[str]
) -> None:
"""Check parameters passed to backend."""
for param in params:
@@ -301,7 +300,7 @@ def get_application_by_name_and_system(
def get_application_and_system(
application_name: str, system_name: str
-) -> Tuple[Application, System]:
+) -> tuple[Application, System]:
"""Return application and system by provided names."""
system = get_system(system_name)
if not system:
@@ -314,9 +313,9 @@ def get_application_and_system(
def run_application(
application_name: str,
- application_params: List[str],
+ application_params: list[str],
system_name: str,
- system_params: List[str],
+ system_params: list[str],
) -> ExecutionContext:
"""Run application on the provided system."""
application, system = get_application_and_system(application_name, system_name)
diff --git a/src/mlia/backend/fs.py b/src/mlia/backend/fs.py
index 9fb53b1..3fce19c 100644
--- a/src/mlia/backend/fs.py
+++ b/src/mlia/backend/fs.py
@@ -1,11 +1,12 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module to host all file system related functions."""
+from __future__ import annotations
+
import re
import shutil
from pathlib import Path
from typing import Literal
-from typing import Optional
from mlia.utils.filesystem import get_mlia_resources
@@ -58,7 +59,7 @@ def remove_resource(resource_directory: str, resource_type: ResourceType) -> Non
shutil.rmtree(resource_location)
-def remove_directory(directory_path: Optional[Path]) -> None:
+def remove_directory(directory_path: Path | None) -> None:
"""Remove directory."""
if not directory_path or not directory_path.is_dir():
raise Exception("No directory path provided")
@@ -66,7 +67,7 @@ def remove_directory(directory_path: Optional[Path]) -> None:
shutil.rmtree(directory_path)
-def recreate_directory(directory_path: Optional[Path]) -> None:
+def recreate_directory(directory_path: Path | None) -> None:
"""Recreate directory."""
if not directory_path:
raise Exception("No directory path provided")
diff --git a/src/mlia/backend/manager.py b/src/mlia/backend/manager.py
index 8d8246d..c8fe0f7 100644
--- a/src/mlia/backend/manager.py
+++ b/src/mlia/backend/manager.py
@@ -1,17 +1,14 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for backend integration."""
+from __future__ import annotations
+
import logging
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from pathlib import Path
-from typing import Dict
-from typing import List
from typing import Literal
-from typing import Optional
-from typing import Set
-from typing import Tuple
from mlia.backend.application import get_available_applications
from mlia.backend.application import install_application
@@ -58,7 +55,7 @@ def get_system_name(backend: str, device_type: str) -> str:
return _SUPPORTED_SYSTEMS[backend][device_type]
-def is_supported(backend: str, device_type: Optional[str] = None) -> bool:
+def is_supported(backend: str, device_type: str | None = None) -> bool:
"""Check if the backend (and optionally device type) is supported."""
if device_type is None:
return backend in _SUPPORTED_SYSTEMS
@@ -70,17 +67,17 @@ def is_supported(backend: str, device_type: Optional[str] = None) -> bool:
return False
-def supported_backends() -> List[str]:
+def supported_backends() -> list[str]:
"""Get a list of all backends supported by the backend manager."""
return list(_SUPPORTED_SYSTEMS.keys())
-def get_all_system_names(backend: str) -> List[str]:
+def get_all_system_names(backend: str) -> list[str]:
"""Get all systems supported by the backend."""
return list(_SUPPORTED_SYSTEMS.get(backend, {}).values())
-def get_all_application_names(backend: str) -> List[str]:
+def get_all_application_names(backend: str) -> list[str]:
"""Get all applications supported by the backend."""
app_set = {
app
@@ -124,8 +121,8 @@ class ExecutionParams:
application: str
system: str
- application_params: List[str]
- system_params: List[str]
+ application_params: list[str]
+ system_params: list[str]
class LogWriter(OutputConsumer):
@@ -153,7 +150,7 @@ class GenericInferenceOutputParser(Base64OutputConsumer):
}
@property
- def result(self) -> Dict:
+ def result(self) -> dict:
"""Merge the raw results and map the names to the right output names."""
merged_result = {}
for raw_result in self.parsed_output:
@@ -172,7 +169,7 @@ class GenericInferenceOutputParser(Base64OutputConsumer):
"""Return true if all expected data has been parsed."""
return set(self.result.keys()) == set(self._map.values())
- def missed_keys(self) -> Set[str]:
+ def missed_keys(self) -> set[str]:
"""Return a set of the keys that have not been found in the output."""
return set(self._map.values()) - set(self.result.keys())
@@ -184,12 +181,12 @@ class BackendRunner:
"""Init BackendRunner instance."""
@staticmethod
- def get_installed_systems() -> List[str]:
+ def get_installed_systems() -> list[str]:
"""Get list of the installed systems."""
return [system.name for system in get_available_systems()]
@staticmethod
- def get_installed_applications(system: Optional[str] = None) -> List[str]:
+ def get_installed_applications(system: str | None = None) -> list[str]:
"""Get list of the installed application."""
return [
app.name
@@ -205,7 +202,7 @@ class BackendRunner:
"""Return true if requested system installed."""
return system in self.get_installed_systems()
- def systems_installed(self, systems: List[str]) -> bool:
+ def systems_installed(self, systems: list[str]) -> bool:
"""Check if all provided systems are installed."""
if not systems:
return False
@@ -213,7 +210,7 @@ class BackendRunner:
installed_systems = self.get_installed_systems()
return all(system in installed_systems for system in systems)
- def applications_installed(self, applications: List[str]) -> bool:
+ def applications_installed(self, applications: list[str]) -> bool:
"""Check if all provided applications are installed."""
if not applications:
return False
@@ -221,7 +218,7 @@ class BackendRunner:
installed_apps = self.get_installed_applications()
return all(app in installed_apps for app in applications)
- def all_installed(self, systems: List[str], apps: List[str]) -> bool:
+ def all_installed(self, systems: list[str], apps: list[str]) -> bool:
"""Check if all provided artifacts are installed."""
return self.systems_installed(systems) and self.applications_installed(apps)
@@ -247,7 +244,7 @@ class BackendRunner:
return ctx
@staticmethod
- def _params(name: str, params: List[str]) -> List[str]:
+ def _params(name: str, params: list[str]) -> list[str]:
return [p for item in [(name, param) for param in params] for p in item]
@@ -259,7 +256,7 @@ class GenericInferenceRunner(ABC):
self.backend_runner = backend_runner
def run(
- self, model_info: ModelInfo, output_consumers: List[OutputConsumer]
+ self, model_info: ModelInfo, output_consumers: list[OutputConsumer]
) -> None:
"""Run generic inference for the provided device/model."""
execution_params = self.get_execution_params(model_info)
@@ -284,7 +281,7 @@ class GenericInferenceRunner(ABC):
)
@staticmethod
- def consume_output(output: bytearray, consumers: List[OutputConsumer]) -> bytearray:
+ def consume_output(output: bytearray, consumers: list[OutputConsumer]) -> bytearray:
"""
Pass program's output to the consumers and filter it.
@@ -320,7 +317,7 @@ class GenericInferenceRunnerEthosU(GenericInferenceRunner):
@staticmethod
def resolve_system_and_app(
device_info: DeviceInfo, backend: str
- ) -> Tuple[str, str]:
+ ) -> tuple[str, str]:
"""Find appropriate system and application for the provided device/backend."""
try:
system_name = get_system_name(backend, device_info.device_type)
diff --git a/src/mlia/backend/output_consumer.py b/src/mlia/backend/output_consumer.py
index bac4186..3c3b132 100644
--- a/src/mlia/backend/output_consumer.py
+++ b/src/mlia/backend/output_consumer.py
@@ -1,10 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Output consumers module."""
+from __future__ import annotations
+
import base64
import json
import re
-from typing import List
from typing import Protocol
from typing import runtime_checkable
@@ -37,7 +38,7 @@ class Base64OutputConsumer(OutputConsumer):
def __init__(self) -> None:
"""Set up the regular expression to extract tagged strings."""
self._regex = re.compile(rf"<{self.TAG_NAME}>(.*)</{self.TAG_NAME}>")
- self.parsed_output: List = []
+ self.parsed_output: list = []
def feed(self, line: str) -> bool:
"""
diff --git a/src/mlia/backend/proc.py b/src/mlia/backend/proc.py
index 911d672..7b3e92a 100644
--- a/src/mlia/backend/proc.py
+++ b/src/mlia/backend/proc.py
@@ -5,6 +5,8 @@
This module contains all classes and functions for dealing with Linux
processes.
"""
+from __future__ import annotations
+
import datetime
import logging
import shlex
@@ -13,9 +15,6 @@ import tempfile
import time
from pathlib import Path
from typing import Any
-from typing import List
-from typing import Optional
-from typing import Tuple
from sh import Command
from sh import CommandNotFound
@@ -38,12 +37,12 @@ class ShellCommand:
self,
cmd: str,
*args: str,
- _cwd: Optional[Path] = None,
+ _cwd: Path | None = None,
_tee: bool = True,
_bg: bool = True,
_out: Any = None,
_err: Any = None,
- _search_paths: Optional[List[Path]] = None,
+ _search_paths: list[Path] | None = None,
) -> RunningCommand:
"""Run the shell command with the given arguments.
@@ -72,7 +71,7 @@ class ShellCommand:
return command(_out=out, _err=err, _tee=_tee, _bg=_bg, _bg_exc=False)
@classmethod
- def get_stdout_stderr_paths(cls, cmd: str) -> Tuple[Path, Path]:
+ def get_stdout_stderr_paths(cls, cmd: str) -> tuple[Path, Path]:
"""Construct and returns the paths of stdout/stderr files."""
timestamp = datetime.datetime.now().timestamp()
base_path = Path(tempfile.mkdtemp(prefix="mlia-", suffix=f"{timestamp}"))
@@ -88,7 +87,7 @@ class ShellCommand:
return stdout, stderr
-def parse_command(command: str, shell: str = "bash") -> List[str]:
+def parse_command(command: str, shell: str = "bash") -> list[str]:
"""Parse command."""
cmd, *args = shlex.split(command, posix=True)
@@ -130,13 +129,13 @@ def run_and_wait(
terminate_on_error: bool = False,
out: Any = None,
err: Any = None,
-) -> Tuple[int, bytearray, bytearray]:
+) -> tuple[int, bytearray, bytearray]:
"""
Run command and wait while it is executing.
Returns a tuple: (exit_code, stdout, stderr)
"""
- running_cmd: Optional[RunningCommand] = None
+ running_cmd: RunningCommand | None = None
try:
running_cmd = execute_command(command, cwd, bg=True, out=out, err=err)
return running_cmd.exit_code, running_cmd.stdout, running_cmd.stderr
diff --git a/src/mlia/backend/source.py b/src/mlia/backend/source.py
index f80a774..c951eae 100644
--- a/src/mlia/backend/source.py
+++ b/src/mlia/backend/source.py
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Contain source related classes and functions."""
+from __future__ import annotations
+
import os
import shutil
import tarfile
@@ -8,8 +10,6 @@ from abc import ABC
from abc import abstractmethod
from pathlib import Path
from tarfile import TarFile
-from typing import Optional
-from typing import Union
from mlia.backend.common import BACKEND_CONFIG_FILE
from mlia.backend.common import ConfigurationException
@@ -24,11 +24,11 @@ class Source(ABC):
"""Source class."""
@abstractmethod
- def name(self) -> Optional[str]:
+ def name(self) -> str | None:
"""Get source name."""
@abstractmethod
- def config(self) -> Optional[BackendConfig]:
+ def config(self) -> BackendConfig | None:
"""Get configuration file content."""
@abstractmethod
@@ -52,7 +52,7 @@ class DirectorySource(Source):
"""Return name of source."""
return self.directory_path.name
- def config(self) -> Optional[BackendConfig]:
+ def config(self) -> BackendConfig | None:
"""Return configuration file content."""
if not is_backend_directory(self.directory_path):
raise ConfigurationException("No configuration file found")
@@ -84,9 +84,9 @@ class TarArchiveSource(Source):
"""Create the TarArchiveSource class."""
assert isinstance(archive_path, Path)
self.archive_path = archive_path
- self._config: Optional[BackendConfig] = None
- self._has_top_level_folder: Optional[bool] = None
- self._name: Optional[str] = None
+ self._config: BackendConfig | None = None
+ self._has_top_level_folder: bool | None = None
+ self._name: str | None = None
def _read_archive_content(self) -> None:
"""Read various information about archive."""
@@ -125,14 +125,14 @@ class TarArchiveSource(Source):
content = archive.extractfile(config_entry)
self._config = load_config(content)
- def config(self) -> Optional[BackendConfig]:
+ def config(self) -> BackendConfig | None:
"""Return configuration file content."""
if self._config is None:
self._read_archive_content()
return self._config
- def name(self) -> Optional[str]:
+ def name(self) -> str | None:
"""Return name of the source."""
if self._name is None:
self._read_archive_content()
@@ -171,7 +171,7 @@ class TarArchiveSource(Source):
)
-def get_source(source_path: Path) -> Union[TarArchiveSource, DirectorySource]:
+def get_source(source_path: Path) -> TarArchiveSource | DirectorySource:
"""Return appropriate source instance based on provided source path."""
if source_path.is_file():
return TarArchiveSource(source_path)
diff --git a/src/mlia/backend/system.py b/src/mlia/backend/system.py
index ff85bf3..0e51ab2 100644
--- a/src/mlia/backend/system.py
+++ b/src/mlia/backend/system.py
@@ -1,12 +1,12 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""System backend module."""
+from __future__ import annotations
+
from pathlib import Path
from typing import Any
from typing import cast
-from typing import Dict
from typing import List
-from typing import Tuple
from mlia.backend.common import Backend
from mlia.backend.common import ConfigurationException
@@ -33,7 +33,7 @@ class System(Backend):
def _setup_reporting(self, config: SystemConfig) -> None:
self.reporting = config.get("reporting")
- def run(self, command: str) -> Tuple[int, bytearray, bytearray]:
+ def run(self, command: str) -> tuple[int, bytearray, bytearray]:
"""
Run command on the system.
@@ -63,7 +63,7 @@ class System(Backend):
return super().__eq__(other) and self.name == other.name
- def get_details(self) -> Dict[str, Any]:
+ def get_details(self) -> dict[str, Any]:
"""Return a dictionary with all relevant information of a System."""
output = {
"type": "system",
@@ -76,12 +76,12 @@ class System(Backend):
return output
-def get_available_systems_directory_names() -> List[str]:
+def get_available_systems_directory_names() -> list[str]:
"""Return a list of directory names for all avialable systems."""
return [entry.name for entry in get_backend_directories("systems")]
-def get_available_systems() -> List[System]:
+def get_available_systems() -> list[System]:
"""Return a list with all available systems."""
available_systems = []
for config_json in get_backend_configs("systems"):
diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py
index 45c7c32..5dd39f9 100644
--- a/src/mlia/cli/commands.py
+++ b/src/mlia/cli/commands.py
@@ -16,11 +16,11 @@ be configured. Function 'setup_logging' from module
>>> mlia.all_tests(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256",
"path/to/model")
"""
+from __future__ import annotations
+
import logging
from pathlib import Path
from typing import cast
-from typing import List
-from typing import Optional
from mlia.api import ExecutionContext
from mlia.api import get_advice
@@ -42,8 +42,8 @@ def all_tests(
model: str,
optimization_type: str = "pruning,clustering",
optimization_target: str = "0.5,32",
- output: Optional[PathOrFileLike] = None,
- evaluate_on: Optional[List[str]] = None,
+ output: PathOrFileLike | None = None,
+ evaluate_on: list[str] | None = None,
) -> None:
"""Generate a full report on the input model.
@@ -99,8 +99,8 @@ def all_tests(
def operators(
ctx: ExecutionContext,
target_profile: str,
- model: Optional[str] = None,
- output: Optional[PathOrFileLike] = None,
+ model: str | None = None,
+ output: PathOrFileLike | None = None,
supported_ops_report: bool = False,
) -> None:
"""Print the model's operator list.
@@ -149,8 +149,8 @@ def performance(
ctx: ExecutionContext,
target_profile: str,
model: str,
- output: Optional[PathOrFileLike] = None,
- evaluate_on: Optional[List[str]] = None,
+ output: PathOrFileLike | None = None,
+ evaluate_on: list[str] | None = None,
) -> None:
"""Print the model's performance stats.
@@ -192,9 +192,9 @@ def optimization(
model: str,
optimization_type: str,
optimization_target: str,
- layers_to_optimize: Optional[List[str]] = None,
- output: Optional[PathOrFileLike] = None,
- evaluate_on: Optional[List[str]] = None,
+ layers_to_optimize: list[str] | None = None,
+ output: PathOrFileLike | None = None,
+ evaluate_on: list[str] | None = None,
) -> None:
"""Show the performance improvements (if any) after applying the optimizations.
@@ -245,9 +245,9 @@ def optimization(
def backend(
backend_action: str,
- path: Optional[Path] = None,
+ path: Path | None = None,
download: bool = False,
- name: Optional[str] = None,
+ name: str | None = None,
i_agree_to_the_contained_eula: bool = False,
noninteractive: bool = False,
) -> None:
diff --git a/src/mlia/cli/common.py b/src/mlia/cli/common.py
index 54bd457..3f60668 100644
--- a/src/mlia/cli/common.py
+++ b/src/mlia/cli/common.py
@@ -1,10 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""CLI common module."""
+from __future__ import annotations
+
import argparse
from dataclasses import dataclass
from typing import Callable
-from typing import List
@dataclass
@@ -12,8 +13,8 @@ class CommandInfo:
"""Command description."""
func: Callable
- aliases: List[str]
- opt_groups: List[Callable[[argparse.ArgumentParser], None]]
+ aliases: list[str]
+ opt_groups: list[Callable[[argparse.ArgumentParser], None]]
is_default: bool = False
@property
@@ -22,7 +23,7 @@ class CommandInfo:
return self.func.__name__
@property
- def command_name_and_aliases(self) -> List[str]:
+ def command_name_and_aliases(self) -> list[str]:
"""Return list of command name and aliases."""
return [self.command_name, *self.aliases]
diff --git a/src/mlia/cli/config.py b/src/mlia/cli/config.py
index a673230..dc28fa2 100644
--- a/src/mlia/cli/config.py
+++ b/src/mlia/cli/config.py
@@ -1,9 +1,10 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Environment configuration functions."""
+from __future__ import annotations
+
import logging
from functools import lru_cache
-from typing import List
import mlia.backend.manager as backend_manager
from mlia.tools.metadata.common import DefaultInstallationManager
@@ -21,7 +22,7 @@ def get_installation_manager(noninteractive: bool = False) -> InstallationManage
@lru_cache
-def get_available_backends() -> List[str]:
+def get_available_backends() -> list[str]:
"""Return list of the available backends."""
available_backends = ["Vela"]
@@ -42,7 +43,7 @@ def get_available_backends() -> List[str]:
_CORSTONE_EXCLUSIVE_PRIORITY = ("Corstone-310", "Corstone-300")
-def get_default_backends() -> List[str]:
+def get_default_backends() -> list[str]:
"""Get default backends for evaluation."""
backends = get_available_backends()
diff --git a/src/mlia/cli/helpers.py b/src/mlia/cli/helpers.py
index 81d5a15..acec837 100644
--- a/src/mlia/cli/helpers.py
+++ b/src/mlia/cli/helpers.py
@@ -1,11 +1,9 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for various helper classes."""
+from __future__ import annotations
+
from typing import Any
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
from mlia.cli.options import get_target_profile_opts
from mlia.core.helpers import ActionResolver
@@ -17,12 +15,12 @@ from mlia.utils.types import is_list_of
class CLIActionResolver(ActionResolver):
"""Helper class for generating cli commands."""
- def __init__(self, args: Dict[str, Any]) -> None:
+ def __init__(self, args: dict[str, Any]) -> None:
"""Init action resolver."""
self.args = args
@staticmethod
- def _general_optimization_command(model_path: Optional[str]) -> List[str]:
+ def _general_optimization_command(model_path: str | None) -> list[str]:
"""Return general optimization command description."""
keras_note = []
if model_path is None or not is_keras_model(model_path):
@@ -40,8 +38,8 @@ class CLIActionResolver(ActionResolver):
def _specific_optimization_command(
model_path: str,
device_opts: str,
- opt_settings: List[OptimizationSettings],
- ) -> List[str]:
+ opt_settings: list[OptimizationSettings],
+ ) -> list[str]:
"""Return specific optimization command description."""
opt_types = ",".join(opt.optimization_type for opt in opt_settings)
opt_targs = ",".join(str(opt.optimization_target) for opt in opt_settings)
@@ -53,7 +51,7 @@ class CLIActionResolver(ActionResolver):
f"--optimization-target {opt_targs}{device_opts} {model_path}",
]
- def apply_optimizations(self, **kwargs: Any) -> List[str]:
+ def apply_optimizations(self, **kwargs: Any) -> list[str]:
"""Return command details for applying optimizations."""
model_path, device_opts = self._get_model_and_device_opts()
@@ -67,14 +65,14 @@ class CLIActionResolver(ActionResolver):
return []
- def supported_operators_info(self) -> List[str]:
+ def supported_operators_info(self) -> list[str]:
"""Return command details for generating supported ops report."""
return [
"For guidance on supported operators, run: mlia operators "
"--supported-ops-report",
]
- def check_performance(self) -> List[str]:
+ def check_performance(self) -> list[str]:
"""Return command details for checking performance."""
model_path, device_opts = self._get_model_and_device_opts()
if not model_path:
@@ -85,7 +83,7 @@ class CLIActionResolver(ActionResolver):
f"mlia performance{device_opts} {model_path}",
]
- def check_operator_compatibility(self) -> List[str]:
+ def check_operator_compatibility(self) -> list[str]:
"""Return command details for op compatibility."""
model_path, device_opts = self._get_model_and_device_opts()
if not model_path:
@@ -96,17 +94,17 @@ class CLIActionResolver(ActionResolver):
f"mlia operators{device_opts} {model_path}",
]
- def operator_compatibility_details(self) -> List[str]:
+ def operator_compatibility_details(self) -> list[str]:
"""Return command details for op compatibility."""
return ["For more details, run: mlia operators --help"]
- def optimization_details(self) -> List[str]:
+ def optimization_details(self) -> list[str]:
"""Return command details for optimization."""
return ["For more info, see: mlia optimization --help"]
def _get_model_and_device_opts(
self, separate_device_opts: bool = True
- ) -> Tuple[Optional[str], str]:
+ ) -> tuple[str | None, str]:
"""Get model and device options."""
device_opts = " ".join(get_target_profile_opts(self.args))
if separate_device_opts and device_opts:
diff --git a/src/mlia/cli/logging.py b/src/mlia/cli/logging.py
index c5fc7bd..40f47d3 100644
--- a/src/mlia/cli/logging.py
+++ b/src/mlia/cli/logging.py
@@ -1,12 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""CLI logging configuration."""
+from __future__ import annotations
+
import logging
import sys
from pathlib import Path
-from typing import List
-from typing import Optional
-from typing import Union
from mlia.utils.logging import attach_handlers
from mlia.utils.logging import create_log_handler
@@ -18,7 +17,7 @@ _FILE_DEBUG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
def setup_logging(
- logs_dir: Optional[Union[str, Path]] = None,
+ logs_dir: str | Path | None = None,
verbose: bool = False,
log_filename: str = "mlia.log",
) -> None:
@@ -49,10 +48,10 @@ def setup_logging(
def _get_mlia_handlers(
- logs_dir: Optional[Union[str, Path]],
+ logs_dir: str | Path | None,
log_filename: str,
verbose: bool,
-) -> List[logging.Handler]:
+) -> list[logging.Handler]:
"""Get handlers for the MLIA loggers."""
result = []
stdout_handler = create_log_handler(
@@ -84,10 +83,10 @@ def _get_mlia_handlers(
def _get_tools_handlers(
- logs_dir: Optional[Union[str, Path]],
+ logs_dir: str | Path | None,
log_filename: str,
verbose: bool,
-) -> List[logging.Handler]:
+) -> list[logging.Handler]:
"""Get handler for the tools loggers."""
result = []
if verbose:
@@ -110,7 +109,7 @@ def _get_tools_handlers(
return result
-def _get_log_file(logs_dir: Union[str, Path], log_filename: str) -> Path:
+def _get_log_file(logs_dir: str | Path, log_filename: str) -> Path:
"""Get the log file path."""
logs_dir_path = Path(logs_dir)
logs_dir_path.mkdir(exist_ok=True)
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py
index f8fc00c..0ece289 100644
--- a/src/mlia/cli/main.py
+++ b/src/mlia/cli/main.py
@@ -1,16 +1,14 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""CLI main entry point."""
+from __future__ import annotations
+
import argparse
import logging
import sys
from functools import partial
from inspect import signature
from pathlib import Path
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
from mlia import __version__
from mlia.cli.commands import all_tests
@@ -50,7 +48,7 @@ Supported targets:
""".strip()
-def get_commands() -> List[CommandInfo]:
+def get_commands() -> list[CommandInfo]:
"""Return commands configuration."""
return [
CommandInfo(
@@ -111,7 +109,7 @@ def get_commands() -> List[CommandInfo]:
]
-def get_default_command() -> Optional[str]:
+def get_default_command() -> str | None:
"""Get name of the default command."""
commands = get_commands()
@@ -121,7 +119,7 @@ def get_default_command() -> Optional[str]:
return next(iter(marked_as_default), None)
-def get_possible_command_names() -> List[str]:
+def get_possible_command_names() -> list[str]:
"""Get all possible command names including aliases."""
return [
name_or_alias
@@ -151,7 +149,7 @@ def init_commands(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
def setup_context(
args: argparse.Namespace, context_var_name: str = "ctx"
-) -> Tuple[ExecutionContext, Dict]:
+) -> tuple[ExecutionContext, dict]:
"""Set up context and resolve function parameters."""
ctx = ExecutionContext(
working_dir=args.working_dir,
@@ -252,7 +250,7 @@ def init_subcommand_parser(parent: argparse.ArgumentParser) -> argparse.Argument
return parser
-def add_default_command_if_needed(args: List[str]) -> None:
+def add_default_command_if_needed(args: list[str]) -> None:
"""Add default command to the list of the arguments if needed."""
default_command = get_default_command()
@@ -265,7 +263,7 @@ def add_default_command_if_needed(args: List[str]) -> None:
args.insert(0, default_command)
-def main(argv: Optional[List[str]] = None) -> int:
+def main(argv: list[str] | None = None) -> int:
"""Entry point of the application."""
common_parser = init_common_parser()
subcommand_parser = init_subcommand_parser(common_parser)
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py
index 29a0d89..3f0dc1f 100644
--- a/src/mlia/cli/options.py
+++ b/src/mlia/cli/options.py
@@ -1,13 +1,12 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for the CLI options."""
+from __future__ import annotations
+
import argparse
from pathlib import Path
from typing import Any
from typing import Callable
-from typing import Dict
-from typing import List
-from typing import Optional
from mlia.cli.config import get_available_backends
from mlia.cli.config import get_default_backends
@@ -17,7 +16,7 @@ from mlia.utils.types import is_number
def add_target_options(
- parser: argparse.ArgumentParser, profiles_to_skip: Optional[List[str]] = None
+ parser: argparse.ArgumentParser, profiles_to_skip: list[str] | None = None
) -> None:
"""Add target specific options."""
target_profiles = get_supported_profile_names()
@@ -217,8 +216,8 @@ def parse_optimization_parameters(
optimization_type: str,
optimization_target: str,
sep: str = ",",
- layers_to_optimize: Optional[List[str]] = None,
-) -> List[Dict[str, Any]]:
+ layers_to_optimize: list[str] | None = None,
+) -> list[dict[str, Any]]:
"""Parse provided optimization parameters."""
if not optimization_type:
raise Exception("Optimization type is not provided")
@@ -250,7 +249,7 @@ def parse_optimization_parameters(
return optimizer_params
-def get_target_profile_opts(device_args: Optional[Dict]) -> List[str]:
+def get_target_profile_opts(device_args: dict | None) -> list[str]:
"""Get non default values passed as parameters for the target profile."""
if not device_args:
return []
@@ -270,7 +269,7 @@ def get_target_profile_opts(device_args: Optional[Dict]) -> List[str]:
if arg_name in args and vars(args)[arg_name] != arg_value
]
- def construct_param(name: str, value: Any) -> List[str]:
+ def construct_param(name: str, value: Any) -> list[str]:
"""Construct parameter."""
if isinstance(value, list):
return [str(item) for v in value for item in [name, v]]
diff --git a/src/mlia/core/advice_generation.py b/src/mlia/core/advice_generation.py
index 76cc1f2..86285fe 100644
--- a/src/mlia/core/advice_generation.py
+++ b/src/mlia/core/advice_generation.py
@@ -1,14 +1,14 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for advice generation."""
+from __future__ import annotations
+
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from functools import wraps
from typing import Any
from typing import Callable
-from typing import List
-from typing import Union
from mlia.core.common import AdviceCategory
from mlia.core.common import DataItem
@@ -20,7 +20,7 @@ from mlia.core.mixins import ContextMixin
class Advice:
"""Base class for the advice."""
- messages: List[str]
+ messages: list[str]
@dataclass
@@ -56,7 +56,7 @@ class AdviceProducer(ABC):
"""
@abstractmethod
- def get_advice(self) -> Union[Advice, List[Advice]]:
+ def get_advice(self) -> Advice | list[Advice]:
"""Get produced advice."""
@@ -76,13 +76,13 @@ class FactBasedAdviceProducer(ContextAwareAdviceProducer):
def __init__(self) -> None:
"""Init advice producer."""
- self.advice: List[Advice] = []
+ self.advice: list[Advice] = []
- def get_advice(self) -> Union[Advice, List[Advice]]:
+ def get_advice(self) -> Advice | list[Advice]:
"""Get produced advice."""
return self.advice
- def add_advice(self, messages: List[str]) -> None:
+ def add_advice(self, messages: list[str]) -> None:
"""Add advice."""
self.advice.append(Advice(messages))
diff --git a/src/mlia/core/advisor.py b/src/mlia/core/advisor.py
index 13689fa..d684241 100644
--- a/src/mlia/core/advisor.py
+++ b/src/mlia/core/advisor.py
@@ -1,10 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Inference advisor module."""
+from __future__ import annotations
+
from abc import abstractmethod
from pathlib import Path
from typing import cast
-from typing import List
from mlia.core.advice_generation import AdviceProducer
from mlia.core.common import NamedEntity
@@ -44,19 +45,19 @@ class DefaultInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin):
)
@abstractmethod
- def get_collectors(self, context: Context) -> List[DataCollector]:
+ def get_collectors(self, context: Context) -> list[DataCollector]:
"""Return list of the data collectors."""
@abstractmethod
- def get_analyzers(self, context: Context) -> List[DataAnalyzer]:
+ def get_analyzers(self, context: Context) -> list[DataAnalyzer]:
"""Return list of the data analyzers."""
@abstractmethod
- def get_producers(self, context: Context) -> List[AdviceProducer]:
+ def get_producers(self, context: Context) -> list[AdviceProducer]:
"""Return list of the advice producers."""
@abstractmethod
- def get_events(self, context: Context) -> List[Event]:
+ def get_events(self, context: Context) -> list[Event]:
"""Return list of the startup events."""
def get_string_parameter(self, context: Context, param: str) -> str:
diff --git a/src/mlia/core/common.py b/src/mlia/core/common.py
index a11bf9a..63fb324 100644
--- a/src/mlia/core/common.py
+++ b/src/mlia/core/common.py
@@ -5,6 +5,8 @@
This module contains common interfaces/classes shared across
core module.
"""
+from __future__ import annotations
+
from abc import ABC
from abc import abstractmethod
from enum import auto
@@ -30,7 +32,7 @@ class AdviceCategory(Flag):
ALL = OPERATORS | PERFORMANCE | OPTIMIZATION
@classmethod
- def from_string(cls, value: str) -> "AdviceCategory":
+ def from_string(cls, value: str) -> AdviceCategory:
"""Resolve enum value from string value."""
category_names = [item.name for item in AdviceCategory]
if not value or value.upper() not in category_names:
diff --git a/src/mlia/core/context.py b/src/mlia/core/context.py
index 83d2f7c..a4737bb 100644
--- a/src/mlia/core/context.py
+++ b/src/mlia/core/context.py
@@ -7,15 +7,14 @@ Context is an object that describes advisor working environment
and requested behavior (advice categories, input configuration
parameters).
"""
+from __future__ import annotations
+
import logging
from abc import ABC
from abc import abstractmethod
from pathlib import Path
from typing import Any
-from typing import List
from typing import Mapping
-from typing import Optional
-from typing import Union
from mlia.core.common import AdviceCategory
from mlia.core.events import DefaultEventPublisher
@@ -50,7 +49,7 @@ class Context(ABC):
@property
@abstractmethod
- def event_handlers(self) -> Optional[List[EventHandler]]:
+ def event_handlers(self) -> list[EventHandler] | None:
"""Return list of the event_handlers."""
@property
@@ -60,7 +59,7 @@ class Context(ABC):
@property
@abstractmethod
- def config_parameters(self) -> Optional[Mapping[str, Any]]:
+ def config_parameters(self) -> Mapping[str, Any] | None:
"""Return configuration parameters."""
@property
@@ -73,7 +72,7 @@ class Context(ABC):
self,
*,
advice_category: AdviceCategory,
- event_handlers: List[EventHandler],
+ event_handlers: list[EventHandler],
config_parameters: Mapping[str, Any],
) -> None:
"""Update context parameters."""
@@ -98,14 +97,14 @@ class ExecutionContext(Context):
self,
*,
advice_category: AdviceCategory = AdviceCategory.ALL,
- config_parameters: Optional[Mapping[str, Any]] = None,
- working_dir: Optional[Union[str, Path]] = None,
- event_handlers: Optional[List[EventHandler]] = None,
- event_publisher: Optional[EventPublisher] = None,
+ config_parameters: Mapping[str, Any] | None = None,
+ working_dir: str | Path | None = None,
+ event_handlers: list[EventHandler] | None = None,
+ event_publisher: EventPublisher | None = None,
verbose: bool = False,
logs_dir: str = "logs",
models_dir: str = "models",
- action_resolver: Optional[ActionResolver] = None,
+ action_resolver: ActionResolver | None = None,
) -> None:
"""Init execution context.
@@ -151,22 +150,22 @@ class ExecutionContext(Context):
self._advice_category = advice_category
@property
- def config_parameters(self) -> Optional[Mapping[str, Any]]:
+ def config_parameters(self) -> Mapping[str, Any] | None:
"""Return configuration parameters."""
return self._config_parameters
@config_parameters.setter
- def config_parameters(self, config_parameters: Optional[Mapping[str, Any]]) -> None:
+ def config_parameters(self, config_parameters: Mapping[str, Any] | None) -> None:
"""Setter for the configuration parameters."""
self._config_parameters = config_parameters
@property
- def event_handlers(self) -> Optional[List[EventHandler]]:
+ def event_handlers(self) -> list[EventHandler] | None:
"""Return list of the event handlers."""
return self._event_handlers
@event_handlers.setter
- def event_handlers(self, event_handlers: List[EventHandler]) -> None:
+ def event_handlers(self, event_handlers: list[EventHandler]) -> None:
"""Setter for the event handlers."""
self._event_handlers = event_handlers
@@ -196,7 +195,7 @@ class ExecutionContext(Context):
self,
*,
advice_category: AdviceCategory,
- event_handlers: List[EventHandler],
+ event_handlers: list[EventHandler],
config_parameters: Mapping[str, Any],
) -> None:
"""Update context parameters."""
diff --git a/src/mlia/core/data_analysis.py b/src/mlia/core/data_analysis.py
index 6adb41e..0603425 100644
--- a/src/mlia/core/data_analysis.py
+++ b/src/mlia/core/data_analysis.py
@@ -1,10 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for data analysis."""
+from __future__ import annotations
+
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
-from typing import List
from mlia.core.common import DataItem
from mlia.core.mixins import ContextMixin
@@ -29,7 +30,7 @@ class DataAnalyzer(ABC):
"""
@abstractmethod
- def get_analyzed_data(self) -> List[DataItem]:
+ def get_analyzed_data(self) -> list[DataItem]:
"""Get analyzed data."""
@@ -59,9 +60,9 @@ class FactExtractor(ContextAwareDataAnalyzer):
def __init__(self) -> None:
"""Init fact extractor."""
- self.facts: List[Fact] = []
+ self.facts: list[Fact] = []
- def get_analyzed_data(self) -> List[DataItem]:
+ def get_analyzed_data(self) -> list[DataItem]:
"""Return list of the collected facts."""
return self.facts
diff --git a/src/mlia/core/events.py b/src/mlia/core/events.py
index 0b8461b..71c86e2 100644
--- a/src/mlia/core/events.py
+++ b/src/mlia/core/events.py
@@ -9,6 +9,8 @@ calling application.
Each component of the workflow can generate events of specific type.
Application can subscribe and react to those events.
"""
+from __future__ import annotations
+
import traceback
import uuid
from abc import ABC
@@ -19,11 +21,7 @@ from dataclasses import dataclass
from dataclasses import field
from functools import singledispatchmethod
from typing import Any
-from typing import Dict
from typing import Generator
-from typing import List
-from typing import Optional
-from typing import Tuple
from mlia.core.common import DataItem
@@ -41,7 +39,7 @@ class Event:
"""Generate unique ID for the event."""
self.event_id = str(uuid.uuid4())
- def compare_without_id(self, other: "Event") -> bool:
+ def compare_without_id(self, other: Event) -> bool:
"""Compare two events without event_id field."""
if not isinstance(other, Event) or self.__class__ != other.__class__:
return False
@@ -73,7 +71,7 @@ class ActionStartedEvent(Event):
"""
action_type: str
- params: Optional[Dict] = None
+ params: dict | None = None
@dataclass
@@ -84,7 +82,7 @@ class SubActionEvent(ChildEvent):
"""
action_type: str
- params: Optional[Dict] = None
+ params: dict | None = None
@dataclass
@@ -271,8 +269,8 @@ class EventDispatcherMetaclass(type):
def __new__(
cls,
clsname: str,
- bases: Tuple,
- namespace: Dict[str, Any],
+ bases: tuple[type, ...],
+ namespace: dict[str, Any],
event_handler_method_prefix: str = "on_",
) -> Any:
"""Create event dispatcher and link event handlers."""
@@ -321,7 +319,7 @@ class EventPublisher(ABC):
"""
def register_event_handlers(
- self, event_handlers: Optional[List[EventHandler]]
+ self, event_handlers: list[EventHandler] | None
) -> None:
"""Register event handlers.
@@ -354,7 +352,7 @@ class DefaultEventPublisher(EventPublisher):
def __init__(self) -> None:
"""Init the event publisher."""
- self.handlers: List[EventHandler] = []
+ self.handlers: list[EventHandler] = []
def register_event_handler(self, event_handler: EventHandler) -> None:
"""Register the event handler.
@@ -374,7 +372,7 @@ class DefaultEventPublisher(EventPublisher):
@contextmanager
def stage(
- publisher: EventPublisher, events: Tuple[Event, Event]
+ publisher: EventPublisher, events: tuple[Event, Event]
) -> Generator[None, None, None]:
"""Generate events before and after stage.
@@ -390,7 +388,7 @@ def stage(
@contextmanager
def action(
- publisher: EventPublisher, action_type: str, params: Optional[Dict] = None
+ publisher: EventPublisher, action_type: str, params: dict | None = None
) -> Generator[None, None, None]:
"""Generate events before and after action."""
action_started = ActionStartedEvent(action_type, params)
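
Note: a generic sketch of the publish-before-and-after pattern that the stage() and action() helpers above describe. ToyPublisher and the string "events" are placeholders, and the real helpers may treat exceptions differently:

    from __future__ import annotations

    from contextlib import contextmanager
    from typing import Generator


    class ToyPublisher:
        """Stand-in for an event publisher."""

        def publish_event(self, event: str) -> None:
            print("event:", event)


    @contextmanager
    def stage(publisher: ToyPublisher, events: tuple[str, str]) -> Generator[None, None, None]:
        """Publish the first event, run the block, then publish the second one."""
        started, finished = events
        publisher.publish_event(started)
        try:
            yield
        finally:
            publisher.publish_event(finished)


    with stage(ToyPublisher(), ("collection started", "collection finished")):
        print("collecting data ...")
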
diff --git a/src/mlia/core/handlers.py b/src/mlia/core/handlers.py
index e576f74..a3255ae 100644
--- a/src/mlia/core/handlers.py
+++ b/src/mlia/core/handlers.py
@@ -1,13 +1,12 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Event handlers module."""
+from __future__ import annotations
+
import logging
from typing import Any
from typing import Callable
-from typing import List
-from typing import Optional
-from mlia.core._typing import PathOrFileLike
from mlia.core.advice_generation import Advice
from mlia.core.advice_generation import AdviceEvent
from mlia.core.events import ActionFinishedEvent
@@ -28,6 +27,7 @@ from mlia.core.events import ExecutionStartedEvent
from mlia.core.reporting import Report
from mlia.core.reporting import Reporter
from mlia.core.reporting import resolve_output_format
+from mlia.core.typing import PathOrFileLike
from mlia.utils.console import create_section_header
@@ -101,14 +101,14 @@ class WorkflowEventsHandler(SystemEventsHandler):
def __init__(
self,
formatter_resolver: Callable[[Any], Callable[[Any], Report]],
- output: Optional[PathOrFileLike] = None,
+ output: PathOrFileLike | None = None,
) -> None:
"""Init event handler."""
output_format = resolve_output_format(output)
self.reporter = Reporter(formatter_resolver, output_format)
self.output = output
- self.advice: List[Advice] = []
+ self.advice: list[Advice] = []
def on_execution_started(self, event: ExecutionStartedEvent) -> None:
"""Handle ExecutionStarted event."""
diff --git a/src/mlia/core/helpers.py b/src/mlia/core/helpers.py
index d10ea5d..f0c4474 100644
--- a/src/mlia/core/helpers.py
+++ b/src/mlia/core/helpers.py
@@ -2,34 +2,35 @@
# SPDX-License-Identifier: Apache-2.0
"""Module for various helper classes."""
# pylint: disable=no-self-use, unused-argument
+from __future__ import annotations
+
from typing import Any
-from typing import List
class ActionResolver:
"""Helper class for generating actions (e.g. commands with parameters)."""
- def apply_optimizations(self, **kwargs: Any) -> List[str]:
+ def apply_optimizations(self, **kwargs: Any) -> list[str]:
"""Return action details for applying optimizations."""
return []
- def supported_operators_info(self) -> List[str]:
+ def supported_operators_info(self) -> list[str]:
"""Return action details for generating supported ops report."""
return []
- def check_performance(self) -> List[str]:
+ def check_performance(self) -> list[str]:
"""Return action details for checking performance."""
return []
- def check_operator_compatibility(self) -> List[str]:
+ def check_operator_compatibility(self) -> list[str]:
"""Return action details for checking op compatibility."""
return []
- def operator_compatibility_details(self) -> List[str]:
+ def operator_compatibility_details(self) -> list[str]:
"""Return action details for getting more information about op compatibility."""
return []
- def optimization_details(self) -> List[str]:
+ def optimization_details(self) -> list[str]:
"""Return action detail for getting information about optimizations."""
return []
diff --git a/src/mlia/core/mixins.py b/src/mlia/core/mixins.py
index ee03100..5ef9d66 100644
--- a/src/mlia/core/mixins.py
+++ b/src/mlia/core/mixins.py
@@ -1,8 +1,9 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Mixins module."""
+from __future__ import annotations
+
from typing import Any
-from typing import Optional
from mlia.core.context import Context
@@ -27,8 +28,8 @@ class ParameterResolverMixin:
section: str,
name: str,
expected: bool = True,
- expected_type: Optional[type] = None,
- context: Optional[Context] = None,
+ expected_type: type | None = None,
+ context: Context | None = None,
) -> Any:
"""Get parameter value."""
ctx = context or self.context
diff --git a/src/mlia/core/performance.py b/src/mlia/core/performance.py
index 5433d5c..cb12918 100644
--- a/src/mlia/core/performance.py
+++ b/src/mlia/core/performance.py
@@ -1,30 +1,31 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for performance estimation."""
+from __future__ import annotations
+
from abc import abstractmethod
from typing import Callable
from typing import Generic
-from typing import List
from typing import TypeVar
-ModelType = TypeVar("ModelType") # pylint: disable=invalid-name
-PerfMetricsType = TypeVar("PerfMetricsType") # pylint: disable=invalid-name
+M = TypeVar("M") # model type
+P = TypeVar("P") # performance metrics
-class PerformanceEstimator(Generic[ModelType, PerfMetricsType]):
+class PerformanceEstimator(Generic[M, P]):
"""Base class for the performance estimation."""
@abstractmethod
- def estimate(self, model: ModelType) -> PerfMetricsType:
+ def estimate(self, model: M) -> P:
"""Estimate performance."""
def estimate_performance(
- original_model: ModelType,
- estimator: PerformanceEstimator[ModelType, PerfMetricsType],
- model_transformations: List[Callable[[ModelType], ModelType]],
-) -> List[PerfMetricsType]:
+ original_model: M,
+ estimator: PerformanceEstimator[M, P],
+ model_transformations: list[Callable[[M], M]],
+) -> list[P]:
"""Estimate performance impact.
This function estimates the impact on model performance after
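
Note: a usage sketch for the shortened type parameters M and P. WordCountEstimator is invented for illustration, and the body of estimate_performance() below is a guess at the intended behaviour based only on the signature and docstring shown above:

    from __future__ import annotations

    from abc import abstractmethod
    from typing import Callable, Generic, TypeVar

    M = TypeVar("M")  # model type
    P = TypeVar("P")  # performance metrics


    class PerformanceEstimator(Generic[M, P]):
        """Toy copy of the interface above."""

        @abstractmethod
        def estimate(self, model: M) -> P:
            """Estimate performance."""


    class WordCountEstimator(PerformanceEstimator[str, int]):
        """Invented estimator: the model is a string, the metric is its word count."""

        def estimate(self, model: str) -> int:
            return len(model.split())


    def estimate_performance(
        original_model: M,
        estimator: PerformanceEstimator[M, P],
        model_transformations: list[Callable[[M], M]],
    ) -> list[P]:
        """Estimate the original model first, then each transformed variant."""
        results = [estimator.estimate(original_model)]
        results += [estimator.estimate(transform(original_model)) for transform in model_transformations]
        return results


    print(estimate_performance("a b c", WordCountEstimator(), [lambda m: m + " d"]))  # [3, 4]
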
diff --git a/src/mlia/core/reporting.py b/src/mlia/core/reporting.py
index 58a41d3..0c8fabc 100644
--- a/src/mlia/core/reporting.py
+++ b/src/mlia/core/reporting.py
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Reporting module."""
+from __future__ import annotations
+
import csv
import json
import logging
@@ -19,19 +21,14 @@ from typing import Any
from typing import Callable
from typing import cast
from typing import Collection
-from typing import Dict
from typing import Generator
from typing import Iterable
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
import numpy as np
-from mlia.core._typing import FileLike
-from mlia.core._typing import OutputFormat
-from mlia.core._typing import PathOrFileLike
+from mlia.core.typing import FileLike
+from mlia.core.typing import OutputFormat
+from mlia.core.typing import PathOrFileLike
from mlia.utils.console import apply_style
from mlia.utils.console import produce_table
from mlia.utils.logging import LoggerWriter
@@ -48,7 +45,7 @@ class Report(ABC):
"""Convert to json serializible format."""
@abstractmethod
- def to_csv(self, **kwargs: Any) -> List[Any]:
+ def to_csv(self, **kwargs: Any) -> list[Any]:
"""Convert to csv serializible format."""
@abstractmethod
@@ -62,9 +59,9 @@ class ReportItem:
def __init__(
self,
name: str,
- alias: Optional[str] = None,
- value: Optional[Union[str, int, "Cell"]] = None,
- nested_items: Optional[List["ReportItem"]] = None,
+ alias: str | None = None,
+ value: str | int | Cell | None = None,
+ nested_items: list[ReportItem] | None = None,
) -> None:
"""Init the report item."""
self.name = name
@@ -98,9 +95,9 @@ class Format:
:param style: text style
"""
- wrap_width: Optional[int] = None
- str_fmt: Optional[Union[str, Callable[[Any], str]]] = None
- style: Optional[str] = None
+ wrap_width: int | None = None
+ str_fmt: str | Callable[[Any], str] | None = None
+ style: str | None = None
@dataclass
@@ -112,7 +109,7 @@ class Cell:
"""
value: Any
- fmt: Optional[Format] = None
+ fmt: Format | None = None
def _apply_style(self, value: str) -> str:
"""Apply style to the value."""
@@ -151,7 +148,7 @@ class CountAwareCell(Cell):
def __init__(
self,
- value: Optional[Union[int, float]],
+ value: int | float | None,
singular: str,
plural: str,
format_string: str = ",d",
@@ -159,7 +156,7 @@ class CountAwareCell(Cell):
"""Init cell instance."""
self.unit = singular if value == 1 else plural
- def format_value(val: Optional[Union[int, float]]) -> str:
+ def format_value(val: int | float | None) -> str:
"""Provide string representation for the value."""
if val is None:
return ""
@@ -183,7 +180,7 @@ class CountAwareCell(Cell):
class BytesCell(CountAwareCell):
"""Cell that represents memory size."""
- def __init__(self, value: Optional[int]) -> None:
+ def __init__(self, value: int | None) -> None:
"""Init cell instance."""
super().__init__(value, "byte", "bytes")
@@ -191,7 +188,7 @@ class BytesCell(CountAwareCell):
class CyclesCell(CountAwareCell):
"""Cell that represents cycles."""
- def __init__(self, value: Optional[Union[int, float]]) -> None:
+ def __init__(self, value: int | float | None) -> None:
"""Init cell instance."""
super().__init__(value, "cycle", "cycles", ",.0f")
@@ -199,7 +196,7 @@ class CyclesCell(CountAwareCell):
class ClockCell(CountAwareCell):
"""Cell that represents clock value."""
- def __init__(self, value: Optional[Union[int, float]]) -> None:
+ def __init__(self, value: int | float | None) -> None:
"""Init cell instance."""
super().__init__(value, "Hz", "Hz", ",.0f")
@@ -210,9 +207,9 @@ class Column:
def __init__(
self,
header: str,
- alias: Optional[str] = None,
- fmt: Optional[Format] = None,
- only_for: Optional[List[str]] = None,
+ alias: str | None = None,
+ fmt: Format | None = None,
+ only_for: list[str] | None = None,
) -> None:
"""Init column definition.
@@ -228,7 +225,7 @@ class Column:
self.fmt = fmt
self.only_for = only_for
- def supports_format(self, fmt: str) -> bool:
+ def supports_format(self, fmt: OutputFormat) -> bool:
"""Return true if column should be shown."""
return not self.only_for or fmt in self.only_for
@@ -236,20 +233,20 @@ class Column:
class NestedReport(Report):
"""Report with nested items."""
- def __init__(self, name: str, alias: str, items: List[ReportItem]) -> None:
+ def __init__(self, name: str, alias: str, items: list[ReportItem]) -> None:
"""Init nested report."""
self.name = name
self.alias = alias
self.items = items
- def to_csv(self, **kwargs: Any) -> List[Any]:
+ def to_csv(self, **kwargs: Any) -> list[Any]:
"""Convert to csv serializible format."""
result = {}
def collect_item_values(
item: ReportItem,
- _parent: Optional[ReportItem],
- _prev: Optional[ReportItem],
+ _parent: ReportItem | None,
+ _prev: ReportItem | None,
_level: int,
) -> None:
"""Collect item values into a dictionary.."""
@@ -279,13 +276,13 @@ class NestedReport(Report):
def to_json(self, **kwargs: Any) -> Any:
"""Convert to json serializible format."""
- per_parent: Dict[Optional[ReportItem], Dict] = defaultdict(dict)
+ per_parent: dict[ReportItem | None, dict] = defaultdict(dict)
result = per_parent[None]
def collect_as_dicts(
item: ReportItem,
- parent: Optional[ReportItem],
- _prev: Optional[ReportItem],
+ parent: ReportItem | None,
+ _prev: ReportItem | None,
_level: int,
) -> None:
"""Collect item values as nested dictionaries."""
@@ -313,8 +310,8 @@ class NestedReport(Report):
def convert_to_text(
item: ReportItem,
- _parent: Optional[ReportItem],
- prev: Optional[ReportItem],
+ _parent: ReportItem | None,
+ prev: ReportItem | None,
level: int,
) -> None:
"""Convert item to text representation."""
@@ -345,12 +342,12 @@ class NestedReport(Report):
def _traverse(
self,
- items: List[ReportItem],
+ items: list[ReportItem],
visit_item: Callable[
- [ReportItem, Optional[ReportItem], Optional[ReportItem], int], None
+ [ReportItem, ReportItem | None, ReportItem | None, int], None
],
level: int = 1,
- parent: Optional[ReportItem] = None,
+ parent: ReportItem | None = None,
) -> None:
"""Traverse through items."""
prev = None
@@ -369,11 +366,11 @@ class Table(Report):
def __init__(
self,
- columns: List[Column],
+ columns: list[Column],
rows: Collection,
name: str,
- alias: Optional[str] = None,
- notes: Optional[str] = None,
+ alias: str | None = None,
+ notes: str | None = None,
) -> None:
"""Init table definition.
@@ -477,7 +474,7 @@ class Table(Report):
return title + formatted_table + footer
- def to_csv(self, **kwargs: Any) -> List[Any]:
+ def to_csv(self, **kwargs: Any) -> list[Any]:
"""Convert table to csv format."""
headers = [[c.header for c in self.columns if c.supports_format("csv")]]
@@ -528,7 +525,7 @@ class CompoundReport(Report):
This class could be used for producing multiple reports at once.
"""
- def __init__(self, reports: List[Report]) -> None:
+ def __init__(self, reports: list[Report]) -> None:
"""Init compound report instance."""
self.reports = reports
@@ -538,13 +535,13 @@ class CompoundReport(Report):
Method attempts to create compound dictionary based on provided
parts.
"""
- result: Dict[str, Any] = {}
+ result: dict[str, Any] = {}
for item in self.reports:
result.update(item.to_json(**kwargs))
return result
- def to_csv(self, **kwargs: Any) -> List[Any]:
+ def to_csv(self, **kwargs: Any) -> list[Any]:
"""Convert to csv serializible format.
CSV format supports only one table. In order to be able to export
@@ -592,7 +589,7 @@ class CompoundReport(Report):
class CompoundFormatter:
"""Compound data formatter."""
- def __init__(self, formatters: List[Callable]) -> None:
+ def __init__(self, formatters: list[Callable]) -> None:
"""Init compound formatter."""
self.formatters = formatters
@@ -637,7 +634,7 @@ def produce_report(
data: Any,
formatter: Callable[[Any], Report],
fmt: OutputFormat = "plain_text",
- output: Optional[PathOrFileLike] = None,
+ output: PathOrFileLike | None = None,
**kwargs: Any,
) -> None:
"""Produce report based on provided data."""
@@ -679,8 +676,8 @@ class Reporter:
self.output_format = output_format
self.print_as_submitted = print_as_submitted
- self.data: List[Tuple[Any, Callable[[Any], Report]]] = []
- self.delayed: List[Tuple[Any, Callable[[Any], Report]]] = []
+ self.data: list[tuple[Any, Callable[[Any], Report]]] = []
+ self.delayed: list[tuple[Any, Callable[[Any], Report]]] = []
def submit(self, data_item: Any, delay_print: bool = False, **kwargs: Any) -> None:
"""Submit data for the report."""
@@ -713,7 +710,7 @@ class Reporter:
)
self.delayed = []
- def generate_report(self, output: Optional[PathOrFileLike]) -> None:
+ def generate_report(self, output: PathOrFileLike | None) -> None:
"""Generate report."""
already_printed = (
self.print_as_submitted
@@ -735,7 +732,7 @@ class Reporter:
@contextmanager
def get_reporter(
output_format: OutputFormat,
- output: Optional[PathOrFileLike],
+ output: PathOrFileLike | None,
formatter_resolver: Callable[[Any], Callable[[Any], Report]],
) -> Generator[Reporter, None, None]:
"""Get reporter and generate report."""
@@ -762,7 +759,7 @@ def _apply_format_parameters(
return wrapper
-def resolve_output_format(output: Optional[PathOrFileLike]) -> OutputFormat:
+def resolve_output_format(output: PathOrFileLike | None) -> OutputFormat:
"""Resolve output format based on the output name."""
if isinstance(output, (str, Path)):
format_from_filename = Path(output).suffix.lstrip(".")
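
Note: the typing helpers imported above come from the module that is renamed unchanged just below, so their definitions are not part of this diff. A hypothetical re-creation of what such aliases and the suffix-based resolve_output_format() could look like; the real definitions may differ:

    from __future__ import annotations

    from pathlib import Path
    from typing import IO, Literal, Union

    # Assumed aliases, for illustration only.
    FileLike = IO[str]
    OutputFormat = Literal["plain_text", "json", "csv"]
    # Type aliases are ordinary runtime assignments, so Union cannot be replaced with
    # "|" here before Python 3.10; the __future__ import defers annotations only.
    PathOrFileLike = Union[str, Path, FileLike]


    def resolve_output_format(output: PathOrFileLike | None) -> OutputFormat:
        """Pick the output format from the file suffix, defaulting to plain text."""
        if isinstance(output, (str, Path)):
            suffix = Path(output).suffix.lstrip(".")
            if suffix == "json":
                return "json"
            if suffix == "csv":
                return "csv"
        return "plain_text"


    print(resolve_output_format("report.csv"))  # csv
    print(resolve_output_format(None))          # plain_text
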
diff --git a/src/mlia/core/_typing.py b/src/mlia/core/typing.py
index bda995c..bda995c 100644
--- a/src/mlia/core/_typing.py
+++ b/src/mlia/core/typing.py
diff --git a/src/mlia/core/workflow.py b/src/mlia/core/workflow.py
index 03f3d1c..d862a86 100644
--- a/src/mlia/core/workflow.py
+++ b/src/mlia/core/workflow.py
@@ -5,16 +5,15 @@
This module contains implementation of the workflow
executors.
"""
+from __future__ import annotations
+
import itertools
from abc import ABC
from abc import abstractmethod
from functools import wraps
from typing import Any
from typing import Callable
-from typing import List
-from typing import Optional
from typing import Sequence
-from typing import Tuple
from mlia.core.advice_generation import Advice
from mlia.core.advice_generation import AdviceEvent
@@ -57,7 +56,7 @@ STAGE_ANALYSIS = (DataAnalysisStageStartedEvent(), DataAnalysisStageFinishedEven
STAGE_ADVICE = (AdviceStageStartedEvent(), AdviceStageFinishedEvent())
-def on_stage(stage_events: Tuple[Event, Event]) -> Callable:
+def on_stage(stage_events: tuple[Event, Event]) -> Callable:
"""Mark start/finish of the stage with appropriate events."""
def wrapper(method: Callable) -> Callable:
@@ -87,7 +86,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor):
collectors: Sequence[DataCollector],
analyzers: Sequence[DataAnalyzer],
producers: Sequence[AdviceProducer],
- startup_events: Optional[Sequence[Event]] = None,
+ startup_events: Sequence[Event] | None = None,
):
"""Init default workflow executor.
@@ -130,7 +129,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor):
self.publish(event)
@on_stage(STAGE_COLLECTION)
- def collect_data(self) -> List[DataItem]:
+ def collect_data(self) -> list[DataItem]:
"""Collect data.
Run each of the data collector components and return a list of
@@ -148,7 +147,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor):
return collected_data
@on_stage(STAGE_ANALYSIS)
- def analyze_data(self, collected_data: List[DataItem]) -> List[DataItem]:
+ def analyze_data(self, collected_data: list[DataItem]) -> list[DataItem]:
"""Analyze data.
Pass each collected data item into each data analyzer and
@@ -168,7 +167,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor):
return analyzed_data
@on_stage(STAGE_ADVICE)
- def produce_advice(self, analyzed_data: List[DataItem]) -> None:
+ def produce_advice(self, analyzed_data: list[DataItem]) -> None:
"""Produce advice.
Pass each analyzed data item into each advice producer and
diff --git a/src/mlia/devices/ethosu/advice_generation.py b/src/mlia/devices/ethosu/advice_generation.py
index 0b1352b..dee1650 100644
--- a/src/mlia/devices/ethosu/advice_generation.py
+++ b/src/mlia/devices/ethosu/advice_generation.py
@@ -1,9 +1,9 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Ethos-U advice generation."""
+from __future__ import annotations
+
from functools import singledispatchmethod
-from typing import List
-from typing import Union
from mlia.core.advice_generation import Advice
from mlia.core.advice_generation import advice_category
@@ -146,8 +146,8 @@ class EthosUAdviceProducer(FactBasedAdviceProducer):
@staticmethod
def get_next_optimization_targets(
- opt_type: List[OptimizationSettings],
- ) -> List[OptimizationSettings]:
+ opt_type: list[OptimizationSettings],
+ ) -> list[OptimizationSettings]:
"""Get next optimization targets."""
next_targets = (item.next_target() for item in opt_type)
@@ -173,7 +173,7 @@ class EthosUStaticAdviceProducer(ContextAwareAdviceProducer):
def produce_advice(self, data_item: DataItem) -> None:
"""Do not process passed data items."""
- def get_advice(self) -> Union[Advice, List[Advice]]:
+ def get_advice(self) -> Advice | list[Advice]:
"""Return predefined advice based on category."""
advice_per_category = {
AdviceCategory.PERFORMANCE: [
diff --git a/src/mlia/devices/ethosu/advisor.py b/src/mlia/devices/ethosu/advisor.py
index b7b8305..be58de7 100644
--- a/src/mlia/devices/ethosu/advisor.py
+++ b/src/mlia/devices/ethosu/advisor.py
@@ -1,14 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Ethos-U MLIA module."""
+from __future__ import annotations
+
from pathlib import Path
from typing import Any
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Union
-from mlia.core._typing import PathOrFileLike
from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
from mlia.core.advisor import InferenceAdvisor
@@ -18,6 +15,7 @@ from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import DataAnalyzer
from mlia.core.data_collection import DataCollector
from mlia.core.events import Event
+from mlia.core.typing import PathOrFileLike
from mlia.devices.ethosu.advice_generation import EthosUAdviceProducer
from mlia.devices.ethosu.advice_generation import EthosUStaticAdviceProducer
from mlia.devices.ethosu.config import EthosUConfiguration
@@ -40,13 +38,13 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
"""Return name of the advisor."""
return "ethos_u_inference_advisor"
- def get_collectors(self, context: Context) -> List[DataCollector]:
+ def get_collectors(self, context: Context) -> list[DataCollector]:
"""Return list of the data collectors."""
model = self.get_model(context)
device = self._get_device(context)
backends = self._get_backends(context)
- collectors: List[DataCollector] = []
+ collectors: list[DataCollector] = []
if AdviceCategory.OPERATORS in context.advice_category:
collectors.append(EthosUOperatorCompatibility(model, device))
@@ -75,20 +73,20 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
return collectors
- def get_analyzers(self, context: Context) -> List[DataAnalyzer]:
+ def get_analyzers(self, context: Context) -> list[DataAnalyzer]:
"""Return list of the data analyzers."""
return [
EthosUDataAnalyzer(),
]
- def get_producers(self, context: Context) -> List[AdviceProducer]:
+ def get_producers(self, context: Context) -> list[AdviceProducer]:
"""Return list of the advice producers."""
return [
EthosUAdviceProducer(),
EthosUStaticAdviceProducer(),
]
- def get_events(self, context: Context) -> List[Event]:
+ def get_events(self, context: Context) -> list[Event]:
"""Return list of the startup events."""
model = self.get_model(context)
device = self._get_device(context)
@@ -103,7 +101,7 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
return get_target(target_profile)
- def _get_optimization_settings(self, context: Context) -> List[List[dict]]:
+ def _get_optimization_settings(self, context: Context) -> list[list[dict]]:
"""Get optimization settings."""
return self.get_parameter( # type: ignore
EthosUOptimizationPerformance.name(),
@@ -113,7 +111,7 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
context=context,
)
- def _get_backends(self, context: Context) -> Optional[List[str]]:
+ def _get_backends(self, context: Context) -> list[str] | None:
"""Get list of backends."""
return self.get_parameter( # type: ignore
self.name(),
@@ -127,8 +125,8 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
def configure_and_get_ethosu_advisor(
context: ExecutionContext,
target_profile: str,
- model: Union[Path, str],
- output: Optional[PathOrFileLike] = None,
+ model: str | Path,
+ output: PathOrFileLike | None = None,
**extra_args: Any,
) -> InferenceAdvisor:
"""Create and configure Ethos-U advisor."""
@@ -158,12 +156,12 @@ _DEFAULT_OPTIMIZATION_TARGETS = [
def _get_config_parameters(
- model: Union[Path, str],
+ model: str | Path,
target_profile: str,
**extra_args: Any,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
"""Get configuration parameters for the advisor."""
- advisor_parameters: Dict[str, Any] = {
+ advisor_parameters: dict[str, Any] = {
"ethos_u_inference_advisor": {
"model": model,
"target_profile": target_profile,
diff --git a/src/mlia/devices/ethosu/config.py b/src/mlia/devices/ethosu/config.py
index cecbb27..e44dcdc 100644
--- a/src/mlia/devices/ethosu/config.py
+++ b/src/mlia/devices/ethosu/config.py
@@ -1,9 +1,10 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Ethos-U configuration."""
+from __future__ import annotations
+
import logging
from typing import Any
-from typing import Dict
from mlia.devices.config import IPConfiguration
from mlia.tools.vela_wrapper import resolve_compiler_config
@@ -38,7 +39,7 @@ class EthosUConfiguration(IPConfiguration):
)
@property
- def resolved_compiler_config(self) -> Dict[str, Any]:
+ def resolved_compiler_config(self) -> dict[str, Any]:
"""Resolve compiler configuration."""
return resolve_compiler_config(self.compiler_options)
@@ -63,7 +64,7 @@ def get_target(target_profile: str) -> EthosUConfiguration:
return EthosUConfiguration(target_profile)
-def _check_target_data_complete(target_data: Dict[str, Any]) -> None:
+def _check_target_data_complete(target_data: dict[str, Any]) -> None:
"""Check if profile contains all needed data."""
mandatory_keys = {"target", "mac", "system_config", "memory_mode"}
missing_keys = sorted(mandatory_keys - target_data.keys())
diff --git a/src/mlia/devices/ethosu/data_analysis.py b/src/mlia/devices/ethosu/data_analysis.py
index 9ed32ff..8d88cf7 100644
--- a/src/mlia/devices/ethosu/data_analysis.py
+++ b/src/mlia/devices/ethosu/data_analysis.py
@@ -1,11 +1,10 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Ethos-U data analysis module."""
+from __future__ import annotations
+
from dataclasses import dataclass
from functools import singledispatchmethod
-from typing import Dict
-from typing import List
-from typing import Union
from mlia.core.common import DataItem
from mlia.core.data_analysis import Fact
@@ -19,7 +18,7 @@ from mlia.tools.vela_wrapper import Operators
class HasCPUOnlyOperators(Fact):
"""Model has CPU only operators."""
- cpu_only_ops: List[str]
+ cpu_only_ops: list[str]
@dataclass
@@ -38,8 +37,8 @@ class AllOperatorsSupportedOnNPU(Fact):
class PerfMetricDiff:
"""Performance metric difference."""
- original_value: Union[int, float]
- optimized_value: Union[int, float]
+ original_value: int | float
+ optimized_value: int | float
@property
def diff(self) -> float:
@@ -69,15 +68,15 @@ class PerfMetricDiff:
class OptimizationDiff:
"""Optimization performance impact."""
- opt_type: List[OptimizationSettings]
- opt_diffs: Dict[str, PerfMetricDiff]
+ opt_type: list[OptimizationSettings]
+ opt_diffs: dict[str, PerfMetricDiff]
@dataclass
class OptimizationResults(Fact):
"""Optimization results."""
- diffs: List[OptimizationDiff]
+ diffs: list[OptimizationDiff]
class EthosUDataAnalyzer(FactExtractor):
@@ -113,13 +112,13 @@ class EthosUDataAnalyzer(FactExtractor):
orig_memory = orig.memory_usage
orig_cycles = orig.npu_cycles
- diffs: List[OptimizationDiff] = []
+ diffs: list[OptimizationDiff] = []
for opt_type, opt_perf_metrics in optimizations:
opt = opt_perf_metrics.in_kilobytes()
opt_memory = opt.memory_usage
opt_cycles = opt.npu_cycles
- opt_diffs: Dict[str, PerfMetricDiff] = {}
+ opt_diffs: dict[str, PerfMetricDiff] = {}
if orig_memory and opt_memory:
opt_diffs.update(
diff --git a/src/mlia/devices/ethosu/data_collection.py b/src/mlia/devices/ethosu/data_collection.py
index 291f1b8..6ddebac 100644
--- a/src/mlia/devices/ethosu/data_collection.py
+++ b/src/mlia/devices/ethosu/data_collection.py
@@ -1,10 +1,10 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Data collection module for Ethos-U."""
+from __future__ import annotations
+
import logging
from pathlib import Path
-from typing import List
-from typing import Optional
from mlia.core.context import Context
from mlia.core.data_collection import ContextAwareDataCollector
@@ -59,7 +59,7 @@ class EthosUPerformance(ContextAwareDataCollector):
self,
model: Path,
device: EthosUConfiguration,
- backends: Optional[List[str]] = None,
+ backends: list[str] | None = None,
) -> None:
"""Init performance data collector."""
self.model = model
@@ -87,7 +87,7 @@ class OptimizeModel:
"""Helper class for model optimization."""
def __init__(
- self, context: Context, opt_settings: List[OptimizationSettings]
+ self, context: Context, opt_settings: list[OptimizationSettings]
) -> None:
"""Init helper."""
self.context = context
@@ -115,8 +115,8 @@ class EthosUOptimizationPerformance(ContextAwareDataCollector):
self,
model: Path,
device: EthosUConfiguration,
- optimizations: List[List[dict]],
- backends: Optional[List[str]] = None,
+ optimizations: list[list[dict]],
+ backends: list[str] | None = None,
) -> None:
"""Init performance optimizations data collector."""
self.model = model
@@ -124,7 +124,7 @@ class EthosUOptimizationPerformance(ContextAwareDataCollector):
self.optimizations = optimizations
self.backends = backends
- def collect_data(self) -> Optional[OptimizationPerformanceMetrics]:
+ def collect_data(self) -> OptimizationPerformanceMetrics | None:
"""Collect performance metrics for the optimizations."""
logger.info("Estimate performance ...")
@@ -164,8 +164,8 @@ class EthosUOptimizationPerformance(ContextAwareDataCollector):
@staticmethod
def _parse_optimization_params(
- optimizations: List[List[dict]],
- ) -> List[List[OptimizationSettings]]:
+ optimizations: list[list[dict]],
+ ) -> list[list[OptimizationSettings]]:
"""Parse optimization parameters."""
if not is_list_of(optimizations, list):
raise Exception("Optimization parameters expected to be a list")
diff --git a/src/mlia/devices/ethosu/handlers.py b/src/mlia/devices/ethosu/handlers.py
index ee0b809..48f9a2e 100644
--- a/src/mlia/devices/ethosu/handlers.py
+++ b/src/mlia/devices/ethosu/handlers.py
@@ -1,12 +1,13 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Event handler."""
+from __future__ import annotations
+
import logging
-from typing import Optional
-from mlia.core._typing import PathOrFileLike
from mlia.core.events import CollectedDataEvent
from mlia.core.handlers import WorkflowEventsHandler
+from mlia.core.typing import PathOrFileLike
from mlia.devices.ethosu.events import EthosUAdvisorEventHandler
from mlia.devices.ethosu.events import EthosUAdvisorStartedEvent
from mlia.devices.ethosu.performance import OptimizationPerformanceMetrics
@@ -20,7 +21,7 @@ logger = logging.getLogger(__name__)
class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler):
"""CLI event handler."""
- def __init__(self, output: Optional[PathOrFileLike] = None) -> None:
+ def __init__(self, output: PathOrFileLike | None = None) -> None:
"""Init event handler."""
super().__init__(ethos_u_formatters, output)
diff --git a/src/mlia/devices/ethosu/performance.py b/src/mlia/devices/ethosu/performance.py
index a73045a..e89a65a 100644
--- a/src/mlia/devices/ethosu/performance.py
+++ b/src/mlia/devices/ethosu/performance.py
@@ -1,13 +1,12 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Performance estimation."""
+from __future__ import annotations
+
import logging
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
-from typing import List
-from typing import Optional
-from typing import Tuple
from typing import Union
import mlia.backend.manager as backend_manager
@@ -49,11 +48,11 @@ class MemorySizeType(Enum):
class MemoryUsage:
"""Memory usage metrics."""
- sram_memory_area_size: Union[int, float]
- dram_memory_area_size: Union[int, float]
- unknown_memory_area_size: Union[int, float]
- on_chip_flash_memory_area_size: Union[int, float]
- off_chip_flash_memory_area_size: Union[int, float]
+ sram_memory_area_size: int | float
+ dram_memory_area_size: int | float
+ unknown_memory_area_size: int | float
+ on_chip_flash_memory_area_size: int | float
+ off_chip_flash_memory_area_size: int | float
memory_size_type: MemorySizeType = MemorySizeType.BYTES
_default_columns = [
@@ -64,7 +63,7 @@ class MemoryUsage:
"Off chip flash used",
]
- def in_kilobytes(self) -> "MemoryUsage":
+ def in_kilobytes(self) -> MemoryUsage:
"""Return memory usage with values in kilobytes."""
if self.memory_size_type == MemorySizeType.KILOBYTES:
return self
@@ -91,10 +90,10 @@ class PerformanceMetrics:
"""Performance metrics."""
device: EthosUConfiguration
- npu_cycles: Optional[NPUCycles]
- memory_usage: Optional[MemoryUsage]
+ npu_cycles: NPUCycles | None
+ memory_usage: MemoryUsage | None
- def in_kilobytes(self) -> "PerformanceMetrics":
+ def in_kilobytes(self) -> PerformanceMetrics:
"""Return metrics with memory usage in KiB."""
if self.memory_usage is None:
return PerformanceMetrics(self.device, self.npu_cycles, self.memory_usage)
@@ -109,8 +108,8 @@ class OptimizationPerformanceMetrics:
"""Optimization performance metrics."""
original_perf_metrics: PerformanceMetrics
- optimizations_perf_metrics: List[
- Tuple[List[OptimizationSettings], PerformanceMetrics]
+ optimizations_perf_metrics: list[
+ tuple[list[OptimizationSettings], PerformanceMetrics]
]
@@ -124,7 +123,7 @@ class VelaPerformanceEstimator(
self.context = context
self.device = device
- def estimate(self, model: Union[Path, ModelConfiguration]) -> MemoryUsage:
+ def estimate(self, model: Path | ModelConfiguration) -> MemoryUsage:
"""Estimate performance."""
logger.info("Getting the memory usage metrics ...")
@@ -160,7 +159,7 @@ class CorstonePerformanceEstimator(
self.device = device
self.backend = backend
- def estimate(self, model: Union[Path, ModelConfiguration]) -> NPUCycles:
+ def estimate(self, model: Path | ModelConfiguration) -> NPUCycles:
"""Estimate performance."""
logger.info("Getting the performance metrics for '%s' ...", self.backend)
logger.info(
@@ -212,7 +211,7 @@ class EthosUPerformanceEstimator(
self,
context: Context,
device: EthosUConfiguration,
- backends: Optional[List[str]] = None,
+ backends: list[str] | None = None,
) -> None:
"""Init performance estimator."""
self.context = context
@@ -228,7 +227,7 @@ class EthosUPerformanceEstimator(
)
self.backends = set(backends)
- def estimate(self, model: Union[Path, ModelConfiguration]) -> PerformanceMetrics:
+ def estimate(self, model: Path | ModelConfiguration) -> PerformanceMetrics:
"""Estimate performance."""
model_path = (
Path(model.model_path) if isinstance(model, ModelConfiguration) else model
diff --git a/src/mlia/devices/ethosu/reporters.py b/src/mlia/devices/ethosu/reporters.py
index b3aea24..f11430c 100644
--- a/src/mlia/devices/ethosu/reporters.py
+++ b/src/mlia/devices/ethosu/reporters.py
@@ -1,12 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Reports module."""
+from __future__ import annotations
+
from collections import defaultdict
from typing import Any
from typing import Callable
-from typing import List
-from typing import Tuple
-from typing import Union
from mlia.core.advice_generation import Advice
from mlia.core.reporting import BytesCell
@@ -52,7 +51,7 @@ def report_operators_stat(operators: Operators) -> Report:
)
-def report_operators(ops: List[Operator]) -> Report:
+def report_operators(ops: list[Operator]) -> Report:
"""Return table representation for the list of operators."""
columns = [
Column("#", only_for=["plain_text"]),
@@ -235,11 +234,11 @@ def report_device_details(device: EthosUConfiguration) -> Report:
)
-def metrics_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]:
+def metrics_as_records(perf_metrics: list[PerformanceMetrics]) -> list[tuple]:
"""Convert perf metrics object into list of records."""
perf_metrics = [item.in_kilobytes() for item in perf_metrics]
- def _cycles_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]:
+ def _cycles_as_records(perf_metrics: list[PerformanceMetrics]) -> list[tuple]:
metric_map = defaultdict(list)
for metrics in perf_metrics:
if not metrics.npu_cycles:
@@ -253,7 +252,7 @@ def metrics_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]:
for name, values in metric_map.items()
]
- def _memory_usage_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]:
+ def _memory_usage_as_records(perf_metrics: list[PerformanceMetrics]) -> list[tuple]:
metric_map = defaultdict(list)
for metrics in perf_metrics:
if not metrics.memory_usage:
@@ -276,7 +275,7 @@ def metrics_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]:
if all(val > 0 for val in values)
]
- def _data_beats_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]:
+ def _data_beats_as_records(perf_metrics: list[PerformanceMetrics]) -> list[tuple]:
metric_map = defaultdict(list)
for metrics in perf_metrics:
if not metrics.npu_cycles:
@@ -308,7 +307,7 @@ def metrics_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]:
def report_perf_metrics(
- perf_metrics: Union[PerformanceMetrics, List[PerformanceMetrics]]
+ perf_metrics: PerformanceMetrics | list[PerformanceMetrics],
) -> Report:
"""Return comparison table for the performance metrics."""
if isinstance(perf_metrics, PerformanceMetrics):
@@ -361,7 +360,7 @@ def report_perf_metrics(
)
-def report_advice(advice: List[Advice]) -> Report:
+def report_advice(advice: list[Advice]) -> Report:
"""Generate report for the advice."""
return Table(
columns=[
diff --git a/src/mlia/devices/tosa/advisor.py b/src/mlia/devices/tosa/advisor.py
index 6a32b94..53dfa87 100644
--- a/src/mlia/devices/tosa/advisor.py
+++ b/src/mlia/devices/tosa/advisor.py
@@ -1,14 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""TOSA advisor."""
+from __future__ import annotations
+
from pathlib import Path
from typing import Any
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Union
-from mlia.core._typing import PathOrFileLike
from mlia.core.advice_generation import AdviceCategory
from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
@@ -18,6 +15,7 @@ from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import DataAnalyzer
from mlia.core.data_collection import DataCollector
from mlia.core.events import Event
+from mlia.core.typing import PathOrFileLike
from mlia.devices.tosa.advice_generation import TOSAAdviceProducer
from mlia.devices.tosa.config import TOSAConfiguration
from mlia.devices.tosa.data_analysis import TOSADataAnalyzer
@@ -34,30 +32,30 @@ class TOSAInferenceAdvisor(DefaultInferenceAdvisor):
"""Return name of the advisor."""
return "tosa_inference_advisor"
- def get_collectors(self, context: Context) -> List[DataCollector]:
+ def get_collectors(self, context: Context) -> list[DataCollector]:
"""Return list of the data collectors."""
model = self.get_model(context)
- collectors: List[DataCollector] = []
+ collectors: list[DataCollector] = []
if AdviceCategory.OPERATORS in context.advice_category:
collectors.append(TOSAOperatorCompatibility(model))
return collectors
- def get_analyzers(self, context: Context) -> List[DataAnalyzer]:
+ def get_analyzers(self, context: Context) -> list[DataAnalyzer]:
"""Return list of the data analyzers."""
return [
TOSADataAnalyzer(),
]
- def get_producers(self, context: Context) -> List[AdviceProducer]:
+ def get_producers(self, context: Context) -> list[AdviceProducer]:
"""Return list of the advice producers."""
return [
TOSAAdviceProducer(),
]
- def get_events(self, context: Context) -> List[Event]:
+ def get_events(self, context: Context) -> list[Event]:
"""Return list of the startup events."""
model = self.get_model(context)
target_profile = self.get_target_profile(context)
@@ -70,9 +68,9 @@ class TOSAInferenceAdvisor(DefaultInferenceAdvisor):
def configure_and_get_tosa_advisor(
context: ExecutionContext,
target_profile: str,
- model: Union[Path, str],
- output: Optional[PathOrFileLike] = None,
- **_extra_args: Any
+ model: str | Path,
+ output: PathOrFileLike | None = None,
+ **_extra_args: Any,
) -> InferenceAdvisor:
"""Create and configure TOSA advisor."""
if context.event_handlers is None:
@@ -84,11 +82,9 @@ def configure_and_get_tosa_advisor(
return TOSAInferenceAdvisor()
-def _get_config_parameters(
- model: Union[Path, str], target_profile: str
-) -> Dict[str, Any]:
+def _get_config_parameters(model: str | Path, target_profile: str) -> dict[str, Any]:
"""Get configuration parameters for the advisor."""
- advisor_parameters: Dict[str, Any] = {
+ advisor_parameters: dict[str, Any] = {
"tosa_inference_advisor": {
"model": str(model),
"target_profile": target_profile,
diff --git a/src/mlia/devices/tosa/handlers.py b/src/mlia/devices/tosa/handlers.py
index 00c18c5..5f015c4 100644
--- a/src/mlia/devices/tosa/handlers.py
+++ b/src/mlia/devices/tosa/handlers.py
@@ -2,12 +2,13 @@
# SPDX-License-Identifier: Apache-2.0
"""TOSA Advisor event handlers."""
# pylint: disable=R0801
+from __future__ import annotations
+
import logging
-from typing import Optional
-from mlia.core._typing import PathOrFileLike
from mlia.core.events import CollectedDataEvent
from mlia.core.handlers import WorkflowEventsHandler
+from mlia.core.typing import PathOrFileLike
from mlia.devices.tosa.events import TOSAAdvisorEventHandler
from mlia.devices.tosa.events import TOSAAdvisorStartedEvent
from mlia.devices.tosa.operators import TOSACompatibilityInfo
@@ -19,7 +20,7 @@ logger = logging.getLogger(__name__)
class TOSAEventHandler(WorkflowEventsHandler, TOSAAdvisorEventHandler):
"""Event handler for TOSA advisor."""
- def __init__(self, output: Optional[PathOrFileLike] = None) -> None:
+ def __init__(self, output: PathOrFileLike | None = None) -> None:
"""Init event handler."""
super().__init__(tosa_formatters, output)
diff --git a/src/mlia/devices/tosa/operators.py b/src/mlia/devices/tosa/operators.py
index 4f3df10..6cfb87f 100644
--- a/src/mlia/devices/tosa/operators.py
+++ b/src/mlia/devices/tosa/operators.py
@@ -1,14 +1,14 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Operators module."""
+from __future__ import annotations
+
from dataclasses import dataclass
from typing import Any
from typing import cast
-from typing import List
-from typing import Optional
from typing import Protocol
-from mlia.core._typing import PathOrFileLike
+from mlia.core.typing import PathOrFileLike
class TOSAChecker(Protocol):
@@ -17,7 +17,7 @@ class TOSAChecker(Protocol):
def is_tosa_compatible(self) -> bool:
"""Return true if model is TOSA compatible."""
- def _get_tosa_compatibility_for_ops(self) -> List[Any]:
+ def _get_tosa_compatibility_for_ops(self) -> list[Any]:
"""Return list of operators."""
@@ -35,7 +35,7 @@ class TOSACompatibilityInfo:
"""Models' TOSA compatibility information."""
tosa_compatible: bool
- operators: List[Operator]
+ operators: list[Operator]
def get_tosa_compatibility_info(
@@ -59,7 +59,7 @@ def get_tosa_compatibility_info(
return TOSACompatibilityInfo(checker.is_tosa_compatible(), ops)
-def get_tosa_checker(tflite_model_path: PathOrFileLike) -> Optional[TOSAChecker]:
+def get_tosa_checker(tflite_model_path: PathOrFileLike) -> TOSAChecker | None:
"""Return instance of the TOSA checker."""
try:
import tosa_checker as tc # pylint: disable=import-outside-toplevel
diff --git a/src/mlia/devices/tosa/reporters.py b/src/mlia/devices/tosa/reporters.py
index 8fba95c..4363793 100644
--- a/src/mlia/devices/tosa/reporters.py
+++ b/src/mlia/devices/tosa/reporters.py
@@ -1,9 +1,10 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Reports module."""
+from __future__ import annotations
+
from typing import Any
from typing import Callable
-from typing import List
from mlia.core.advice_generation import Advice
from mlia.core.reporting import Cell
@@ -30,7 +31,7 @@ def report_device(device: TOSAConfiguration) -> Report:
)
-def report_advice(advice: List[Advice]) -> Report:
+def report_advice(advice: list[Advice]) -> Report:
"""Generate report for the advice."""
return Table(
columns=[
@@ -43,7 +44,7 @@ def report_advice(advice: List[Advice]) -> Report:
)
-def report_tosa_operators(ops: List[Operator]) -> Report:
+def report_tosa_operators(ops: list[Operator]) -> Report:
"""Generate report for the operators."""
return Table(
[
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index d3235d7..6ee32e7 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -1,12 +1,12 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Model configuration."""
+from __future__ import annotations
+
import logging
from pathlib import Path
from typing import cast
-from typing import Dict
from typing import List
-from typing import Union
import tensorflow as tf
@@ -24,17 +24,17 @@ logger = logging.getLogger(__name__)
class ModelConfiguration:
"""Base class for model configuration."""
- def __init__(self, model_path: Union[str, Path]) -> None:
+ def __init__(self, model_path: str | Path) -> None:
"""Init model configuration instance."""
self.model_path = str(model_path)
def convert_to_tflite(
- self, tflite_model_path: Union[str, Path], quantized: bool = False
- ) -> "TFLiteModel":
+ self, tflite_model_path: str | Path, quantized: bool = False
+ ) -> TFLiteModel:
"""Convert model to TFLite format."""
raise NotImplementedError()
- def convert_to_keras(self, keras_model_path: Union[str, Path]) -> "KerasModel":
+ def convert_to_keras(self, keras_model_path: str | Path) -> KerasModel:
"""Convert model to Keras format."""
raise NotImplementedError()
@@ -50,8 +50,8 @@ class KerasModel(ModelConfiguration):
return tf.keras.models.load_model(self.model_path)
def convert_to_tflite(
- self, tflite_model_path: Union[str, Path], quantized: bool = False
- ) -> "TFLiteModel":
+ self, tflite_model_path: str | Path, quantized: bool = False
+ ) -> TFLiteModel:
"""Convert model to TFLite format."""
logger.info("Converting Keras to TFLite ...")
@@ -65,7 +65,7 @@ class KerasModel(ModelConfiguration):
return TFLiteModel(tflite_model_path)
- def convert_to_keras(self, keras_model_path: Union[str, Path]) -> "KerasModel":
+ def convert_to_keras(self, keras_model_path: str | Path) -> KerasModel:
"""Convert model to Keras format."""
return self
@@ -73,14 +73,14 @@ class KerasModel(ModelConfiguration):
class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
"""TFLite model configuration."""
- def input_details(self) -> List[Dict]:
+ def input_details(self) -> list[dict]:
"""Get model's input details."""
interpreter = tf.lite.Interpreter(model_path=self.model_path)
- return cast(List[Dict], interpreter.get_input_details())
+ return cast(List[dict], interpreter.get_input_details())
def convert_to_tflite(
- self, tflite_model_path: Union[str, Path], quantized: bool = False
- ) -> "TFLiteModel":
+ self, tflite_model_path: str | Path, quantized: bool = False
+ ) -> TFLiteModel:
"""Convert model to TFLite format."""
return self
@@ -92,8 +92,8 @@ class TfModel(ModelConfiguration): # pylint: disable=abstract-method
"""
def convert_to_tflite(
- self, tflite_model_path: Union[str, Path], quantized: bool = False
- ) -> "TFLiteModel":
+ self, tflite_model_path: str | Path, quantized: bool = False
+ ) -> TFLiteModel:
"""Convert model to TFLite format."""
converted_model = convert_tf_to_tflite(self.model_path, quantized)
save_tflite_model(converted_model, tflite_model_path)
@@ -101,7 +101,7 @@ class TfModel(ModelConfiguration): # pylint: disable=abstract-method
return TFLiteModel(tflite_model_path)
-def get_model(model: Union[Path, str]) -> "ModelConfiguration":
+def get_model(model: str | Path) -> ModelConfiguration:
"""Return the model object."""
if is_tflite_model(model):
return TFLiteModel(model)
@@ -118,7 +118,7 @@ def get_model(model: Union[Path, str]) -> "ModelConfiguration":
)
-def get_tflite_model(model: Union[str, Path], ctx: Context) -> "TFLiteModel":
+def get_tflite_model(model: str | Path, ctx: Context) -> TFLiteModel:
"""Convert input model to TFLite and returns TFLiteModel object."""
tflite_model_path = ctx.get_model_path("converted_model.tflite")
converted_model = get_model(model)
@@ -126,7 +126,7 @@ def get_tflite_model(model: Union[str, Path], ctx: Context) -> "TFLiteModel":
return converted_model.convert_to_tflite(tflite_model_path, True)
-def get_keras_model(model: Union[str, Path], ctx: Context) -> "KerasModel":
+def get_keras_model(model: str | Path, ctx: Context) -> KerasModel:
"""Convert input model to Keras and returns KerasModel object."""
keras_model_path = ctx.get_model_path("converted_model.h5")
converted_model = get_model(model)
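
Note on the cast(List[dict], ...) line above: cast() receives its first argument as an ordinary runtime value, so the deferred-annotations import does not help there, which is presumably why typing.List stays imported in this file. A small sketch of the distinction:

    from __future__ import annotations  # defers annotations only, not ordinary expressions

    from typing import List, cast

    details: list[dict] = []           # annotation: never evaluated, fine on Python 3.7/3.8
    typed = cast(List[dict], details)  # runtime expression: typing.List works on every version
    # cast(list[dict], details)        # a builtin generic at runtime needs Python 3.9+
    print(typed)  # []
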
diff --git a/src/mlia/nn/tensorflow/optimizations/clustering.py b/src/mlia/nn/tensorflow/optimizations/clustering.py
index 16d9e4b..4aaa33e 100644
--- a/src/mlia/nn/tensorflow/optimizations/clustering.py
+++ b/src/mlia/nn/tensorflow/optimizations/clustering.py
@@ -7,11 +7,10 @@ In order to do this, we need to have a base model and corresponding training dat
We also have to specify a subset of layers we want to cluster. For more details,
please refer to the documentation for TensorFlow Model Optimization Toolkit.
"""
+from __future__ import annotations
+
from dataclasses import dataclass
from typing import Any
-from typing import Dict
-from typing import List
-from typing import Optional
import tensorflow as tf
import tensorflow_model_optimization as tfmot
@@ -28,7 +27,7 @@ class ClusteringConfiguration(OptimizerConfiguration):
"""Clustering configuration."""
optimization_target: int
- layers_to_optimize: Optional[List[str]] = None
+ layers_to_optimize: list[str] | None = None
def __str__(self) -> str:
"""Return string representation of the configuration."""
@@ -61,7 +60,7 @@ class Clusterer(Optimizer):
"""Return string representation of the optimization config."""
return str(self.optimizer_configuration)
- def _setup_clustering_params(self) -> Dict[str, Any]:
+ def _setup_clustering_params(self) -> dict[str, Any]:
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
return {
"number_of_clusters": self.optimizer_configuration.optimization_target,
diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py
index 0a3fda5..41954b9 100644
--- a/src/mlia/nn/tensorflow/optimizations/pruning.py
+++ b/src/mlia/nn/tensorflow/optimizations/pruning.py
@@ -7,11 +7,10 @@ In order to do this, we need to have a base model and corresponding training dat
We also have to specify a subset of layers we want to prune. For more details,
please refer to the documentation for TensorFlow Model Optimization Toolkit.
"""
+from __future__ import annotations
+
import typing
from dataclasses import dataclass
-from typing import List
-from typing import Optional
-from typing import Tuple
import numpy as np
import tensorflow as tf
@@ -29,9 +28,9 @@ class PruningConfiguration(OptimizerConfiguration):
"""Pruning configuration."""
optimization_target: float
- layers_to_optimize: Optional[List[str]] = None
- x_train: Optional[np.ndarray] = None
- y_train: Optional[np.ndarray] = None
+ layers_to_optimize: list[str] | None = None
+ x_train: np.ndarray | None = None
+ y_train: np.ndarray | None = None
batch_size: int = 1
num_epochs: int = 1
@@ -74,7 +73,7 @@ class Pruner(Optimizer):
"""Return string representation of the optimization config."""
return str(self.optimizer_configuration)
- def _mock_train_data(self) -> Tuple[np.ndarray, np.ndarray]:
+ def _mock_train_data(self) -> tuple[np.ndarray, np.ndarray]:
# get rid of the batch_size dimension in input and output shape
input_shape = tuple(x for x in self.model.input_shape if x is not None)
output_shape = tuple(x for x in self.model.output_shape if x is not None)
diff --git a/src/mlia/nn/tensorflow/optimizations/select.py b/src/mlia/nn/tensorflow/optimizations/select.py
index 1b0c755..d4a8ea4 100644
--- a/src/mlia/nn/tensorflow/optimizations/select.py
+++ b/src/mlia/nn/tensorflow/optimizations/select.py
@@ -1,12 +1,10 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for optimization selection."""
+from __future__ import annotations
+
import math
-from typing import List
from typing import NamedTuple
-from typing import Optional
-from typing import Tuple
-from typing import Union
import tensorflow as tf
@@ -25,14 +23,14 @@ class OptimizationSettings(NamedTuple):
"""Optimization settings."""
optimization_type: str
- optimization_target: Union[int, float]
- layers_to_optimize: Optional[List[str]]
+ optimization_target: int | float
+ layers_to_optimize: list[str] | None
@staticmethod
def create_from(
- optimizer_params: List[Tuple[str, float]],
- layers_to_optimize: Optional[List[str]] = None,
- ) -> List["OptimizationSettings"]:
+ optimizer_params: list[tuple[str, float]],
+ layers_to_optimize: list[str] | None = None,
+ ) -> list[OptimizationSettings]:
"""Create optimization settings from the provided parameters."""
return [
OptimizationSettings(
@@ -47,7 +45,7 @@ class OptimizationSettings(NamedTuple):
"""Return string representation."""
return f"{self.optimization_type}: {self.optimization_target}"
- def next_target(self) -> "OptimizationSettings":
+ def next_target(self) -> OptimizationSettings:
"""Return next optimization target."""
if self.optimization_type == "pruning":
next_target = round(min(self.optimization_target + 0.1, 0.9), 2)
@@ -75,7 +73,7 @@ class MultiStageOptimizer(Optimizer):
def __init__(
self,
model: tf.keras.Model,
- optimizations: List[OptimizerConfiguration],
+ optimizations: list[OptimizerConfiguration],
) -> None:
"""Init MultiStageOptimizer instance."""
self.model = model
@@ -98,10 +96,8 @@ class MultiStageOptimizer(Optimizer):
def get_optimizer(
- model: Union[tf.keras.Model, KerasModel],
- config: Union[
- OptimizerConfiguration, OptimizationSettings, List[OptimizationSettings]
- ],
+ model: tf.keras.Model | KerasModel,
+ config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings],
) -> Optimizer:
"""Get optimizer for provided configuration."""
if isinstance(model, KerasModel):
@@ -123,7 +119,7 @@ def get_optimizer(
def _get_optimizer(
model: tf.keras.Model,
- optimization_settings: Union[OptimizationSettings, List[OptimizationSettings]],
+ optimization_settings: OptimizationSettings | list[OptimizationSettings],
) -> Optimizer:
if isinstance(optimization_settings, OptimizationSettings):
optimization_settings = [optimization_settings]
@@ -145,8 +141,8 @@ def _get_optimizer(
def _get_optimizer_configuration(
optimization_type: str,
- optimization_target: Union[int, float],
- layers_to_optimize: Optional[List[str]] = None,
+ optimization_target: int | float,
+ layers_to_optimize: list[str] | None = None,
) -> OptimizerConfiguration:
"""Get optimizer configuration for provided parameters."""
_check_optimizer_params(optimization_type, optimization_target)
@@ -169,7 +165,7 @@ def _get_optimizer_configuration(
def _check_optimizer_params(
- optimization_type: str, optimization_target: Union[int, float]
+ optimization_type: str, optimization_target: int | float
) -> None:
"""Check optimizer params."""
if not optimization_target:
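
A minimal sketch of the updated select API, assuming a Keras model named model is already loaded; the parameter tuples are illustrative:

from mlia.nn.tensorflow.optimizations.select import get_optimizer
from mlia.nn.tensorflow.optimizations.select import OptimizationSettings

settings = OptimizationSettings.create_from([("pruning", 0.5), ("clustering", 32)])
optimizer = get_optimizer(model, settings)  # accepts a single setting or list[OptimizationSettings]
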
diff --git a/src/mlia/nn/tensorflow/tflite_metrics.py b/src/mlia/nn/tensorflow/tflite_metrics.py
index 3f41487..0af7500 100644
--- a/src/mlia/nn/tensorflow/tflite_metrics.py
+++ b/src/mlia/nn/tensorflow/tflite_metrics.py
@@ -8,13 +8,13 @@ These metrics include:
* Unique weights (clusters) (per layer)
* gzip compression ratio
"""
+from __future__ import annotations
+
import os
import typing
from enum import Enum
from pprint import pprint
from typing import Any
-from typing import List
-from typing import Optional
import numpy as np
import tensorflow as tf
@@ -42,7 +42,7 @@ def calculate_num_unique_weights(weights: np.ndarray) -> int:
return num_unique_weights
-def calculate_num_unique_weights_per_axis(weights: np.ndarray, axis: int) -> List[int]:
+def calculate_num_unique_weights_per_axis(weights: np.ndarray, axis: int) -> list[int]:
"""Calculate unique weights per quantization axis."""
# Make quantized dimension the first dimension
weights_trans = np.swapaxes(weights, 0, axis)
@@ -74,7 +74,7 @@ class SparsityAccumulator:
def calculate_sparsity(
- weights: np.ndarray, accumulator: Optional[SparsityAccumulator] = None
+ weights: np.ndarray, accumulator: SparsityAccumulator | None = None
) -> float:
"""
Calculate the sparsity for the given weights.
@@ -110,9 +110,7 @@ class TFLiteMetrics:
* File compression via gzip
"""
- def __init__(
- self, tflite_file: str, ignore_list: Optional[List[str]] = None
- ) -> None:
+ def __init__(self, tflite_file: str, ignore_list: list[str] | None = None) -> None:
"""Load the TFLite file and filter layers."""
self.tflite_file = tflite_file
if ignore_list is None:
@@ -159,7 +157,7 @@ class TFLiteMetrics:
acc(self.get_tensor(details))
return acc.sparsity()
- def calc_num_clusters_per_axis(self, details: dict) -> List[int]:
+ def calc_num_clusters_per_axis(self, details: dict) -> list[int]:
"""Calculate number of clusters per axis."""
quant_params = details["quantization_parameters"]
per_axis = len(quant_params["zero_points"]) > 1
@@ -178,14 +176,14 @@ class TFLiteMetrics:
aggregation_func = self.calc_num_clusters_per_axis
elif mode == ReportClusterMode.NUM_CLUSTERS_MIN_MAX:
- def cluster_min_max(details: dict) -> List[int]:
+ def cluster_min_max(details: dict) -> list[int]:
num_clusters = self.calc_num_clusters_per_axis(details)
return [min(num_clusters), max(num_clusters)]
aggregation_func = cluster_min_max
elif mode == ReportClusterMode.NUM_CLUSTERS_HISTOGRAM:
- def cluster_hist(details: dict) -> List[int]:
+ def cluster_hist(details: dict) -> list[int]:
num_clusters = self.calc_num_clusters_per_axis(details)
max_num = max(num_clusters)
hist = [0] * (max_num)
@@ -289,7 +287,7 @@ class TFLiteMetrics:
print(f"- {self._prettify_name(name)}: {nums}")
@staticmethod
- def _print_in_outs(ios: List[dict], verbose: bool = False) -> None:
+ def _print_in_outs(ios: list[dict], verbose: bool = False) -> None:
for item in ios:
if verbose:
pprint(item)
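
A minimal sketch of the per-axis helper with its new list[int] return hint; the weights array is illustrative:

import numpy as np

from mlia.nn.tensorflow.tflite_metrics import calculate_num_unique_weights_per_axis

weights = np.zeros((4, 3), dtype=np.float32)                     # illustrative tensor
counts = calculate_num_unique_weights_per_axis(weights, axis=0)  # -> list[int]
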
diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py
index b1034d9..6250f56 100644
--- a/src/mlia/nn/tensorflow/utils.py
+++ b/src/mlia/nn/tensorflow/utils.py
@@ -2,11 +2,12 @@
# SPDX-FileCopyrightText: Copyright The TensorFlow Authors. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""Collection of useful functions for optimizations."""
+from __future__ import annotations
+
import logging
from pathlib import Path
from typing import Callable
from typing import Iterable
-from typing import Union
import numpy as np
import tensorflow as tf
@@ -101,21 +102,19 @@ def convert_tf_to_tflite(model: str, quantized: bool = False) -> Interpreter:
return tflite_model
-def save_keras_model(model: tf.keras.Model, save_path: Union[str, Path]) -> None:
+def save_keras_model(model: tf.keras.Model, save_path: str | Path) -> None:
"""Save Keras model at provided path."""
# Checkpoint: saving the optimizer is necessary.
model.save(save_path, include_optimizer=True)
-def save_tflite_model(
- model: tf.lite.TFLiteConverter, save_path: Union[str, Path]
-) -> None:
+def save_tflite_model(model: tf.lite.TFLiteConverter, save_path: str | Path) -> None:
"""Save TFLite model at provided path."""
with open(save_path, "wb") as file:
file.write(model)
-def is_tflite_model(model: Union[Path, str]) -> bool:
+def is_tflite_model(model: str | Path) -> bool:
"""Check if model type is supported by TFLite API.
TFLite model is indicated by the model file extension .tflite
@@ -124,7 +123,7 @@ def is_tflite_model(model: Union[Path, str]) -> bool:
return model_path.suffix == ".tflite"
-def is_keras_model(model: Union[Path, str]) -> bool:
+def is_keras_model(model: str | Path) -> bool:
"""Check if model type is supported by Keras API.
Keras model is indicated by:
@@ -139,7 +138,7 @@ def is_keras_model(model: Union[Path, str]) -> bool:
return model_path.suffix in (".h5", ".hdf5")
-def is_tf_model(model: Union[Path, str]) -> bool:
+def is_tf_model(model: str | Path) -> bool:
"""Check if model type is supported by TensorFlow API.
TensorFlow model is indicated if its directory (meaning saved model)
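
A minimal sketch of the str | Path helpers; the file names are illustrative:

from pathlib import Path

from mlia.nn.tensorflow.utils import is_keras_model
from mlia.nn.tensorflow.utils import is_tflite_model

print(is_tflite_model("model.tflite"))        # plain strings are accepted
print(is_tflite_model(Path("model.tflite")))  # and so are Path objects
print(is_keras_model("model.h5"))
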
diff --git a/src/mlia/tools/metadata/common.py b/src/mlia/tools/metadata/common.py
index 924e870..32da4a4 100644
--- a/src/mlia/tools/metadata/common.py
+++ b/src/mlia/tools/metadata/common.py
@@ -1,14 +1,14 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for installation process."""
+from __future__ import annotations
+
import logging
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Callable
-from typing import List
-from typing import Optional
from typing import Union
from mlia.utils.misc import yes
@@ -100,7 +100,7 @@ class SupportsInstallTypeFilter:
class SearchByNameFilter:
"""Filter installation by name."""
- def __init__(self, backend_name: Optional[str]) -> None:
+ def __init__(self, backend_name: str | None) -> None:
"""Init filter."""
self.backend_name = backend_name
@@ -113,12 +113,12 @@ class InstallationManager(ABC):
"""Helper class for managing installations."""
@abstractmethod
- def install_from(self, backend_path: Path, backend_name: Optional[str]) -> None:
+ def install_from(self, backend_path: Path, backend_name: str | None) -> None:
"""Install backend from the local directory."""
@abstractmethod
def download_and_install(
- self, backend_name: Optional[str], eula_agreement: bool
+ self, backend_name: str | None, eula_agreement: bool
) -> None:
"""Download and install backends."""
@@ -134,9 +134,9 @@ class InstallationManager(ABC):
class InstallationFiltersMixin:
"""Mixin for filtering installation based on different conditions."""
- installations: List[Installation]
+ installations: list[Installation]
- def filter_by(self, *filters: InstallationFilter) -> List[Installation]:
+ def filter_by(self, *filters: InstallationFilter) -> list[Installation]:
"""Filter installations."""
return [
installation
@@ -145,8 +145,8 @@ class InstallationFiltersMixin:
]
def could_be_installed_from(
- self, backend_path: Path, backend_name: Optional[str]
- ) -> List[Installation]:
+ self, backend_path: Path, backend_name: str | None
+ ) -> list[Installation]:
"""Return installations that could be installed from provided directory."""
return self.filter_by(
SupportsInstallTypeFilter(InstallFromPath(backend_path)),
@@ -154,8 +154,8 @@ class InstallationFiltersMixin:
)
def could_be_downloaded_and_installed(
- self, backend_name: Optional[str] = None
- ) -> List[Installation]:
+ self, backend_name: str | None = None
+ ) -> list[Installation]:
"""Return installations that could be downloaded and installed."""
return self.filter_by(
SupportsInstallTypeFilter(DownloadAndInstall()),
@@ -163,15 +163,13 @@ class InstallationFiltersMixin:
ReadyForInstallationFilter(),
)
- def already_installed(
- self, backend_name: Optional[str] = None
- ) -> List[Installation]:
+ def already_installed(self, backend_name: str | None = None) -> list[Installation]:
"""Return list of backends that are already installed."""
return self.filter_by(
AlreadyInstalledFilter(), SearchByNameFilter(backend_name)
)
- def ready_for_installation(self) -> List[Installation]:
+ def ready_for_installation(self) -> list[Installation]:
"""Return list of the backends that could be installed."""
return self.filter_by(ReadyForInstallationFilter())
@@ -180,15 +178,15 @@ class DefaultInstallationManager(InstallationManager, InstallationFiltersMixin):
"""Interactive installation manager."""
def __init__(
- self, installations: List[Installation], noninteractive: bool = False
+ self, installations: list[Installation], noninteractive: bool = False
) -> None:
"""Init the manager."""
self.installations = installations
self.noninteractive = noninteractive
def choose_installation_for_path(
- self, backend_path: Path, backend_name: Optional[str]
- ) -> Optional[Installation]:
+ self, backend_path: Path, backend_name: str | None
+ ) -> Installation | None:
"""Check available installation and select one if possible."""
installs = self.could_be_installed_from(backend_path, backend_name)
@@ -220,7 +218,7 @@ class DefaultInstallationManager(InstallationManager, InstallationFiltersMixin):
return installation
- def install_from(self, backend_path: Path, backend_name: Optional[str]) -> None:
+ def install_from(self, backend_path: Path, backend_name: str | None) -> None:
"""Install from the provided directory."""
installation = self.choose_installation_for_path(backend_path, backend_name)
@@ -234,7 +232,7 @@ class DefaultInstallationManager(InstallationManager, InstallationFiltersMixin):
self._install(installation, InstallFromPath(backend_path), prompt)
def download_and_install(
- self, backend_name: Optional[str] = None, eula_agreement: bool = True
+ self, backend_name: str | None = None, eula_agreement: bool = True
) -> None:
"""Download and install available backends."""
installations = self.could_be_downloaded_and_installed(backend_name)
@@ -269,7 +267,7 @@ class DefaultInstallationManager(InstallationManager, InstallationFiltersMixin):
@staticmethod
def _print_installation_list(
- header: str, installations: List[Installation], new_section: bool = False
+ header: str, installations: list[Installation], new_section: bool = False
) -> None:
"""Print list of the installations."""
logger.info("%s%s\n", "\n" if new_section else "", header)
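
A minimal sketch of the manager API with the simplified str | None hints; the empty installation list keeps the example self-contained:

from mlia.tools.metadata.common import DefaultInstallationManager

manager = DefaultInstallationManager(installations=[], noninteractive=True)
manager.download_and_install(backend_name=None, eula_agreement=True)
print(manager.already_installed())  # -> list[Installation]
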
diff --git a/src/mlia/tools/metadata/corstone.py b/src/mlia/tools/metadata/corstone.py
index 023369c..feef7ad 100644
--- a/src/mlia/tools/metadata/corstone.py
+++ b/src/mlia/tools/metadata/corstone.py
@@ -6,6 +6,8 @@ The import of subprocess module raises a B404 bandit error. MLIA usage of
subprocess is needed and can be considered safe hence disabling the security
check.
"""
+from __future__ import annotations
+
import logging
import platform
import subprocess # nosec
@@ -14,7 +16,6 @@ from dataclasses import dataclass
from pathlib import Path
from typing import Callable
from typing import Iterable
-from typing import List
from typing import Optional
import mlia.backend.manager as backend_manager
@@ -40,7 +41,7 @@ class BackendInfo:
backend_path: Path
copy_source: bool = True
- system_config: Optional[str] = None
+ system_config: str | None = None
PathChecker = Callable[[Path], Optional[BackendInfo]]
@@ -55,10 +56,10 @@ class BackendMetadata:
name: str,
description: str,
system_config: str,
- apps_resources: List[str],
+ apps_resources: list[str],
fvp_dir_name: str,
- download_artifact: Optional[DownloadArtifact],
- supported_platforms: Optional[List[str]] = None,
+ download_artifact: DownloadArtifact | None,
+ supported_platforms: list[str] | None = None,
) -> None:
"""
Initialize BackendMetadata.
@@ -100,7 +101,7 @@ class BackendInstallation(Installation):
backend_runner: backend_manager.BackendRunner,
metadata: BackendMetadata,
path_checker: PathChecker,
- backend_installer: Optional[BackendInstaller],
+ backend_installer: BackendInstaller | None,
) -> None:
"""Init the backend installation."""
self.backend_runner = backend_runner
@@ -209,13 +210,13 @@ class PackagePathChecker:
"""Package path checker."""
def __init__(
- self, expected_files: List[str], backend_subfolder: Optional[str] = None
+ self, expected_files: list[str], backend_subfolder: str | None = None
) -> None:
"""Init the path checker."""
self.expected_files = expected_files
self.backend_subfolder = backend_subfolder
- def __call__(self, backend_path: Path) -> Optional[BackendInfo]:
+ def __call__(self, backend_path: Path) -> BackendInfo | None:
"""Check if directory contains all expected files."""
resolved_paths = (backend_path / file for file in self.expected_files)
if not all_files_exist(resolved_paths):
@@ -238,9 +239,9 @@ class StaticPathChecker:
def __init__(
self,
static_backend_path: Path,
- expected_files: List[str],
+ expected_files: list[str],
copy_source: bool = False,
- system_config: Optional[str] = None,
+ system_config: str | None = None,
) -> None:
"""Init static path checker."""
self.static_backend_path = static_backend_path
@@ -248,7 +249,7 @@ class StaticPathChecker:
self.copy_source = copy_source
self.system_config = system_config
- def __call__(self, backend_path: Path) -> Optional[BackendInfo]:
+ def __call__(self, backend_path: Path) -> BackendInfo | None:
"""Check if directory equals static backend path with all expected files."""
if backend_path != self.static_backend_path:
return None
@@ -271,7 +272,7 @@ class CompoundPathChecker:
"""Init compound path checker."""
self.path_checkers = path_checkers
- def __call__(self, backend_path: Path) -> Optional[BackendInfo]:
+ def __call__(self, backend_path: Path) -> BackendInfo | None:
"""Iterate over checkers and return first non empty backend info."""
first_resolved_backend_info = (
backend_info
@@ -401,7 +402,7 @@ def get_corstone_310_installation() -> Installation:
return corstone_310
-def get_corstone_installations() -> List[Installation]:
+def get_corstone_installations() -> list[Installation]:
"""Get Corstone installations."""
return [
get_corstone_300_installation(),
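
A minimal sketch combining the Corstone helper with the manager shown above:

from mlia.tools.metadata.common import DefaultInstallationManager
from mlia.tools.metadata.corstone import get_corstone_installations

manager = DefaultInstallationManager(get_corstone_installations(), noninteractive=True)
print(manager.ready_for_installation())  # -> list[Installation]
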
diff --git a/src/mlia/tools/vela_wrapper.py b/src/mlia/tools/vela_wrapper.py
index 7225797..47c15e9 100644
--- a/src/mlia/tools/vela_wrapper.py
+++ b/src/mlia/tools/vela_wrapper.py
@@ -1,18 +1,15 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Vela wrapper module."""
+from __future__ import annotations
+
import itertools
import logging
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any
-from typing import Dict
-from typing import List
from typing import Literal
-from typing import Optional
-from typing import Tuple
-from typing import Union
import numpy as np
from ethosu.vela.architecture_features import ArchitectureFeatures
@@ -70,7 +67,7 @@ class NpuSupported:
"""Operator's npu supported attribute."""
supported: bool
- reasons: List[Tuple[str, str]]
+ reasons: list[tuple[str, str]]
@dataclass
@@ -95,7 +92,7 @@ class Operator:
class Operators:
"""Model's operators."""
- ops: List[Operator]
+ ops: list[Operator]
@property
def npu_supported_ratio(self) -> float:
@@ -150,7 +147,7 @@ class OptimizedModel:
compiler_options: CompilerOptions
scheduler_options: SchedulerOptions
- def save(self, output_filename: Union[str, Path]) -> None:
+ def save(self, output_filename: str | Path) -> None:
"""Save instance of the optimized model to the file."""
write_tflite(self.nng, output_filename)
@@ -173,16 +170,16 @@ OptimizationStrategyType = Literal["Performance", "Size"]
class VelaCompilerOptions: # pylint: disable=too-many-instance-attributes
"""Vela compiler options."""
- config_files: Optional[Union[str, List[str]]] = None
+ config_files: str | list[str] | None = None
system_config: str = ArchitectureFeatures.DEFAULT_CONFIG
memory_mode: str = ArchitectureFeatures.DEFAULT_CONFIG
- accelerator_config: Optional[AcceleratorConfigType] = None
+ accelerator_config: AcceleratorConfigType | None = None
max_block_dependency: int = ArchitectureFeatures.MAX_BLOCKDEP
- arena_cache_size: Optional[int] = None
+ arena_cache_size: int | None = None
tensor_allocator: TensorAllocatorType = "HillClimb"
cpu_tensor_alignment: int = Tensor.AllocationQuantum
optimization_strategy: OptimizationStrategyType = "Performance"
- output_dir: Optional[str] = None
+ output_dir: str | None = None
recursion_limit: int = 1000
@@ -207,14 +204,14 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes
sys.setrecursionlimit(self.recursion_limit)
- def read_model(self, model: Union[str, Path]) -> Model:
+ def read_model(self, model: str | Path) -> Model:
"""Read model."""
logger.debug("Read model %s", model)
nng, network_type = self._read_model(model)
return Model(nng, network_type)
- def compile_model(self, model: Union[str, Path, Model]) -> OptimizedModel:
+ def compile_model(self, model: str | Path | Model) -> OptimizedModel:
"""Compile the model."""
if isinstance(model, (str, Path)):
nng, network_type = self._read_model(model)
@@ -240,7 +237,7 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes
except (SystemExit, Exception) as err:
raise Exception("Model could not be optimized with Vela compiler") from err
- def get_config(self) -> Dict[str, Any]:
+ def get_config(self) -> dict[str, Any]:
"""Get compiler configuration."""
arch = self._architecture_features()
@@ -277,7 +274,7 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes
}
@staticmethod
- def _read_model(model: Union[str, Path]) -> Tuple[Graph, NetworkType]:
+ def _read_model(model: str | Path) -> tuple[Graph, NetworkType]:
"""Read TFLite model."""
try:
model_path = str(model) if isinstance(model, Path) else model
@@ -334,7 +331,7 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes
def resolve_compiler_config(
vela_compiler_options: VelaCompilerOptions,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
"""Resolve passed compiler options.
Vela has number of configuration parameters that being
@@ -397,7 +394,7 @@ def _performance_metrics(optimized_model: OptimizedModel) -> PerformanceMetrics:
def memory_usage(mem_area: MemArea) -> int:
"""Get memory usage for the proviced memory area type."""
- memory_used: Dict[MemArea, int] = optimized_model.nng.memory_used
+ memory_used: dict[MemArea, int] = optimized_model.nng.memory_used
bandwidths = optimized_model.nng.bandwidths
return memory_used.get(mem_area, 0) if np.sum(bandwidths[mem_area]) > 0 else 0
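
A minimal sketch of the compiler options with the updated union hints; whether the defaults fully resolve depends on the selected target profile, so treat the call as illustrative:

from mlia.tools.vela_wrapper import resolve_compiler_config
from mlia.tools.vela_wrapper import VelaCompilerOptions

options = VelaCompilerOptions(optimization_strategy="Size", output_dir="vela_output")
config = resolve_compiler_config(options)  # -> dict[str, Any]
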
diff --git a/src/mlia/utils/console.py b/src/mlia/utils/console.py
index 7cb3d83..1f428a7 100644
--- a/src/mlia/utils/console.py
+++ b/src/mlia/utils/console.py
@@ -1,9 +1,9 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Console output utility functions."""
+from __future__ import annotations
+
from typing import Iterable
-from typing import List
-from typing import Optional
from rich.console import Console
from rich.console import RenderableType
@@ -13,7 +13,7 @@ from rich.text import Text
def create_section_header(
- section_name: Optional[str] = None, length: int = 80, sep: str = "-"
+ section_name: str | None = None, length: int = 80, sep: str = "-"
) -> str:
"""Return section header."""
if not section_name:
@@ -41,7 +41,7 @@ def style_improvement(result: bool) -> str:
def produce_table(
rows: Iterable,
- headers: Optional[List[str]] = None,
+ headers: list[str] | None = None,
table_style: str = "default",
) -> str:
"""Represent data in tabular form."""
diff --git a/src/mlia/utils/download.py b/src/mlia/utils/download.py
index 4658738..9ef2d9e 100644
--- a/src/mlia/utils/download.py
+++ b/src/mlia/utils/download.py
@@ -1,11 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Utils for files downloading."""
+from __future__ import annotations
+
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable
-from typing import List
-from typing import Optional
import requests
from rich.progress import BarColumn
@@ -20,10 +20,10 @@ from mlia.utils.types import parse_int
def download_progress(
- content_chunks: Iterable[bytes], content_length: Optional[int], label: Optional[str]
+ content_chunks: Iterable[bytes], content_length: int | None, label: str | None
) -> Iterable[bytes]:
"""Show progress info while reading content."""
- columns: List[ProgressColumn] = [TextColumn("{task.description}")]
+ columns: list[ProgressColumn] = [TextColumn("{task.description}")]
if content_length is None:
total = float("inf")
@@ -44,7 +44,7 @@ def download(
url: str,
dest: Path,
show_progress: bool = False,
- label: Optional[str] = None,
+ label: str | None = None,
chunk_size: int = 8192,
) -> None:
"""Download the file."""
diff --git a/src/mlia/utils/filesystem.py b/src/mlia/utils/filesystem.py
index 0c28d35..25619c5 100644
--- a/src/mlia/utils/filesystem.py
+++ b/src/mlia/utils/filesystem.py
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Utils related to file management."""
+from __future__ import annotations
+
import hashlib
import importlib.resources as pkg_resources
import json
@@ -12,12 +14,8 @@ from tempfile import mkstemp
from tempfile import TemporaryDirectory
from typing import Any
from typing import cast
-from typing import Dict
from typing import Generator
from typing import Iterable
-from typing import List
-from typing import Optional
-from typing import Union
def get_mlia_resources() -> Path:
@@ -37,7 +35,7 @@ def get_profiles_file() -> Path:
return get_mlia_resources() / "profiles.json"
-def get_profiles_data() -> Dict[str, Dict[str, Any]]:
+def get_profiles_data() -> dict[str, dict[str, Any]]:
"""Get the profile values as a dictionary."""
with open(get_profiles_file(), encoding="utf-8") as json_file:
profiles = json.load(json_file)
@@ -48,7 +46,7 @@ def get_profiles_data() -> Dict[str, Dict[str, Any]]:
return profiles
-def get_profile(target_profile: str) -> Dict[str, Any]:
+def get_profile(target_profile: str) -> dict[str, Any]:
"""Get settings for the provided target profile."""
if not target_profile:
raise Exception("Target profile is not provided")
@@ -61,7 +59,7 @@ def get_profile(target_profile: str) -> Dict[str, Any]:
raise Exception(f"Unable to find target profile {target_profile}") from err
-def get_supported_profile_names() -> List[str]:
+def get_supported_profile_names() -> list[str]:
"""Get the supported Ethos-U profile names."""
return list(get_profiles_data().keys())
@@ -73,7 +71,7 @@ def get_target(target_profile: str) -> str:
@contextmanager
-def temp_file(suffix: Optional[str] = None) -> Generator[Path, None, None]:
+def temp_file(suffix: str | None = None) -> Generator[Path, None, None]:
"""Create temp file and remove it after."""
_, tmp_file = mkstemp(suffix=suffix)
@@ -84,14 +82,14 @@ def temp_file(suffix: Optional[str] = None) -> Generator[Path, None, None]:
@contextmanager
-def temp_directory(suffix: Optional[str] = None) -> Generator[Path, None, None]:
+def temp_directory(suffix: str | None = None) -> Generator[Path, None, None]:
"""Create temp directory and remove it after."""
with TemporaryDirectory(suffix=suffix) as tmpdir:
yield Path(tmpdir)
def file_chunks(
- filepath: Union[Path, str], chunk_size: int = 4096
+ filepath: str | Path, chunk_size: int = 4096
) -> Generator[bytes, None, None]:
"""Return sequence of the file chunks."""
with open(filepath, "rb") as file:
@@ -99,7 +97,7 @@ def file_chunks(
yield data
-def hexdigest(filepath: Union[Path, str], hash_obj: "hashlib._Hash") -> str:
+def hexdigest(filepath: str | Path, hash_obj: "hashlib._Hash") -> str:
"""Return hex digest of the file."""
for chunk in file_chunks(filepath):
hash_obj.update(chunk)
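
A minimal sketch of the filesystem helpers; hexdigest now takes str | Path:

import hashlib

from mlia.utils.filesystem import hexdigest
from mlia.utils.filesystem import temp_file

with temp_file(suffix=".bin") as tmp:
    print(hexdigest(tmp, hashlib.sha256()))  # digest of the (empty) temporary file
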
diff --git a/src/mlia/utils/logging.py b/src/mlia/utils/logging.py
index 86d7567..793500a 100644
--- a/src/mlia/utils/logging.py
+++ b/src/mlia/utils/logging.py
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Logging utility functions."""
+from __future__ import annotations
+
import logging
from contextlib import contextmanager
from contextlib import ExitStack
@@ -10,8 +12,6 @@ from pathlib import Path
from typing import Any
from typing import Callable
from typing import Generator
-from typing import List
-from typing import Optional
class LoggerWriter:
@@ -61,7 +61,7 @@ class LogFilter(logging.Filter):
return self.log_record_filter(record)
@classmethod
- def equals(cls, log_level: int) -> "LogFilter":
+ def equals(cls, log_level: int) -> LogFilter:
"""Return log filter that filters messages by log level."""
def filter_by_level(log_record: logging.LogRecord) -> bool:
@@ -70,7 +70,7 @@ class LogFilter(logging.Filter):
return cls(filter_by_level)
@classmethod
- def skip(cls, log_level: int) -> "LogFilter":
+ def skip(cls, log_level: int) -> LogFilter:
"""Return log filter that skips messages with particular level."""
def skip_by_level(log_record: logging.LogRecord) -> bool:
@@ -81,15 +81,15 @@ class LogFilter(logging.Filter):
def create_log_handler(
*,
- file_path: Optional[Path] = None,
- stream: Optional[Any] = None,
- log_level: Optional[int] = None,
- log_format: Optional[str] = None,
- log_filter: Optional[logging.Filter] = None,
+ file_path: Path | None = None,
+ stream: Any | None = None,
+ log_level: int | None = None,
+ log_format: str | None = None,
+ log_filter: logging.Filter | None = None,
delay: bool = True,
) -> logging.Handler:
"""Create logger handler."""
- handler: Optional[logging.Handler] = None
+ handler: logging.Handler | None = None
if file_path is not None:
handler = logging.FileHandler(file_path, delay=delay)
@@ -112,7 +112,7 @@ def create_log_handler(
def attach_handlers(
- handlers: List[logging.Handler], loggers: List[logging.Logger]
+ handlers: list[logging.Handler], loggers: list[logging.Logger]
) -> None:
"""Attach handlers to the loggers."""
for handler in handlers:
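
A minimal sketch of the handler factory with its keyword-only, now |-typed, arguments; the logger name is illustrative:

import logging
import sys

from mlia.utils.logging import attach_handlers
from mlia.utils.logging import create_log_handler
from mlia.utils.logging import LogFilter

handler = create_log_handler(
    stream=sys.stdout,
    log_level=logging.INFO,
    log_filter=LogFilter.skip(logging.DEBUG),
)
attach_handlers([handler], [logging.getLogger("mlia")])
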
diff --git a/src/mlia/utils/types.py b/src/mlia/utils/types.py
index 9b63928..ea067b8 100644
--- a/src/mlia/utils/types.py
+++ b/src/mlia/utils/types.py
@@ -1,11 +1,12 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Types related utility functions."""
+from __future__ import annotations
+
from typing import Any
-from typing import Optional
-def is_list_of(data: Any, cls: type, elem_num: Optional[int] = None) -> bool:
+def is_list_of(data: Any, cls: type, elem_num: int | None = None) -> bool:
"""Check if data is a list of object of the same class."""
return (
isinstance(data, (tuple, list))
@@ -24,7 +25,7 @@ def is_number(value: str) -> bool:
return True
-def parse_int(value: Any, default: Optional[int] = None) -> Optional[int]:
+def parse_int(value: Any, default: int | None = None) -> int | None:
"""Parse integer value."""
try:
return int(value)
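
A minimal sketch of the small type utilities; the inputs are illustrative:

from mlia.utils.types import is_list_of
from mlia.utils.types import parse_int

print(is_list_of([1, 2, 3], int, elem_num=3))                 # True
print(parse_int("42"), parse_int("not a number", default=0))  # 42 0
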
diff --git a/tests/test_backend_application.py b/tests/test_backend_application.py
index 6860ecb..9606802 100644
--- a/tests/test_backend_application.py
+++ b/tests/test_backend_application.py
@@ -2,11 +2,12 @@
# SPDX-License-Identifier: Apache-2.0
# pylint: disable=no-self-use
"""Tests for the application backend."""
+from __future__ import annotations
+
from collections import Counter
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
-from typing import List
from unittest.mock import MagicMock
import pytest
@@ -289,7 +290,7 @@ class TestApplication:
),
)
def test_remove_unused_params(
- self, config: ApplicationConfig, expected_params: List[Param]
+ self, config: ApplicationConfig, expected_params: list[Param]
) -> None:
"""Test mod remove_unused_parameter."""
application = Application(config)
diff --git a/tests/test_backend_common.py b/tests/test_backend_common.py
index 0533ef6..d11261e 100644
--- a/tests/test_backend_common.py
+++ b/tests/test_backend_common.py
@@ -2,16 +2,14 @@
# SPDX-License-Identifier: Apache-2.0
# pylint: disable=no-self-use,protected-access
"""Tests for the common backend module."""
+from __future__ import annotations
+
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
from typing import cast
-from typing import Dict
from typing import IO
from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
from unittest.mock import MagicMock
import pytest
@@ -62,7 +60,7 @@ def test_load_config(
) -> None:
"""Test load_config."""
with expected_exception:
- configs: List[Optional[Union[Path, IO[bytes]]]] = (
+ configs: list[Path | IO[bytes] | None] = (
[None]
if not filename
else [
@@ -283,8 +281,8 @@ class TestBackend:
def test_resolved_parameters(
self,
class_: type,
- config: Dict,
- expected_output: List[Tuple[Optional[str], Param]],
+ config: dict,
+ expected_output: list[tuple[str | None, Param]],
) -> None:
"""Test command building."""
backend = class_(config)
@@ -343,7 +341,7 @@ class TestBackend:
],
)
def test__parse_raw_parameter(
- self, input_param: str, expected: Tuple[str, Optional[str]]
+ self, input_param: str, expected: tuple[str, str | None]
) -> None:
"""Test internal method of parsing a single raw parameter."""
assert parse_raw_parameter(input_param) == expected
@@ -476,7 +474,7 @@ class TestCommand:
],
],
)
- def test_validate_params(self, params: List[Param], expected_error: Any) -> None:
+ def test_validate_params(self, params: list[Param], expected_error: Any) -> None:
"""Test command validation function."""
with expected_error:
Command([], params)
diff --git a/tests/test_backend_fs.py b/tests/test_backend_fs.py
index 7423222..21226a9 100644
--- a/tests/test_backend_fs.py
+++ b/tests/test_backend_fs.py
@@ -2,10 +2,11 @@
# SPDX-License-Identifier: Apache-2.0
# pylint: disable=no-self-use
"""Module for testing fs.py."""
+from __future__ import annotations
+
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
-from typing import Union
from unittest.mock import MagicMock
import pytest
@@ -108,7 +109,7 @@ def test_recreate_directory(tmpdir: Any) -> None:
def write_to_file(
- write_directory: Any, write_mode: str, write_text: Union[str, bytes]
+ write_directory: Any, write_mode: str, write_text: str | bytes
) -> Path:
"""Write some text to a temporary test file."""
tmpdir_path = Path(write_directory)
diff --git a/tests/test_backend_manager.py b/tests/test_backend_manager.py
index 1b5fea1..a1e9198 100644
--- a/tests/test_backend_manager.py
+++ b/tests/test_backend_manager.py
@@ -1,16 +1,13 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module backend/manager."""
+from __future__ import annotations
+
import base64
import json
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Set
-from typing import Tuple
from unittest.mock import MagicMock
from unittest.mock import PropertyMock
@@ -35,7 +32,7 @@ from mlia.backend.output_consumer import Base64OutputConsumer
from mlia.backend.system import get_system
-def _mock_encode_b64(data: Dict[str, int]) -> str:
+def _mock_encode_b64(data: dict[str, int]) -> str:
"""
Encode the given data into a mock base64-encoded string of JSON.
@@ -138,7 +135,7 @@ def _mock_encode_b64(data: Dict[str, int]) -> str:
],
)
def test_generic_inference_output_parser(
- data: Dict[str, int], is_ready: bool, result: Dict, missed_keys: Set[str]
+ data: dict[str, int], is_ready: bool, result: dict, missed_keys: set[str]
) -> None:
"""Test generic runner output parser."""
parser = GenericInferenceOutputParser()
@@ -157,8 +154,8 @@ class TestBackendRunner:
@staticmethod
def _setup_backends(
monkeypatch: pytest.MonkeyPatch,
- available_systems: Optional[List[str]] = None,
- available_apps: Optional[List[str]] = None,
+ available_systems: list[str] | None = None,
+ available_apps: list[str] | None = None,
) -> None:
"""Set up backend metadata."""
@@ -196,7 +193,7 @@ class TestBackendRunner:
)
def test_is_system_installed(
self,
- available_systems: List,
+ available_systems: list,
system: str,
installed: bool,
monkeypatch: pytest.MonkeyPatch,
@@ -217,8 +214,8 @@ class TestBackendRunner:
)
def test_installed_systems(
self,
- available_systems: List[str],
- systems: List[str],
+ available_systems: list[str],
+ systems: list[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Test method installed_systems."""
@@ -250,8 +247,8 @@ class TestBackendRunner:
)
def test_systems_installed(
self,
- available_systems: List[str],
- systems: List[str],
+ available_systems: list[str],
+ systems: list[str],
expected_result: bool,
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -274,8 +271,8 @@ class TestBackendRunner:
)
def test_applications_installed(
self,
- available_apps: List[str],
- applications: List[str],
+ available_apps: list[str],
+ applications: list[str],
expected_result: bool,
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -297,8 +294,8 @@ class TestBackendRunner:
)
def test_get_installed_applications(
self,
- available_apps: List[str],
- applications: List[str],
+ available_apps: list[str],
+ applications: list[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Test method get_installed_applications."""
@@ -337,7 +334,7 @@ class TestBackendRunner:
)
def test_is_application_installed(
self,
- available_apps: List[str],
+ available_apps: list[str],
application: str,
installed: bool,
monkeypatch: pytest.MonkeyPatch,
@@ -377,7 +374,7 @@ class TestBackendRunner:
def test_run_application_local(
monkeypatch: pytest.MonkeyPatch,
execution_params: ExecutionParams,
- expected_command: List[str],
+ expected_command: list[str],
) -> None:
"""Test method run_application with local systems."""
run_app = MagicMock(wraps=run_application)
@@ -491,8 +488,8 @@ class TestBackendRunner:
)
def test_estimate_performance(
device: DeviceInfo,
- system: Tuple[str, bool],
- application: Tuple[str, bool],
+ system: tuple[str, bool],
+ application: tuple[str, bool],
backend: str,
expected_error: Any,
test_tflite_model: Path,
@@ -588,7 +585,7 @@ def test_estimate_performance_invalid_output(
)
-def create_mock_process(stdout: List[str], stderr: List[str]) -> MagicMock:
+def create_mock_process(stdout: list[str], stderr: list[str]) -> MagicMock:
"""Mock underlying process."""
mock_process = MagicMock()
mock_process.poll.return_value = 0
@@ -597,7 +594,7 @@ def create_mock_process(stdout: List[str], stderr: List[str]) -> MagicMock:
return mock_process
-def create_mock_context(stdout: List[str]) -> ExecutionContext:
+def create_mock_context(stdout: list[str]) -> ExecutionContext:
"""Mock ExecutionContext."""
ctx = ExecutionContext(
app=get_application("application_1")[0],
diff --git a/tests/test_backend_output_consumer.py b/tests/test_backend_output_consumer.py
index 881112e..2ecb07f 100644
--- a/tests/test_backend_output_consumer.py
+++ b/tests/test_backend_output_consumer.py
@@ -1,10 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the output parsing."""
+from __future__ import annotations
+
import base64
import json
from typing import Any
-from typing import Dict
import pytest
@@ -42,7 +43,7 @@ REGEX_CONFIG = {
"FloatValue": {"pattern": r"Float.*: (.*)", "type": "float"},
}
-EMPTY_REGEX_CONFIG: Dict[str, Dict[str, Any]] = {}
+EMPTY_REGEX_CONFIG: dict[str, dict[str, Any]] = {}
EXPECTED_METRICS_ALL = {
"FirstString": "My awesome string!",
@@ -63,7 +64,7 @@ EXPECTED_METRICS_PARTIAL = {
EXPECTED_METRICS_PARTIAL,
],
)
-def test_base64_output_consumer(expected_metrics: Dict) -> None:
+def test_base64_output_consumer(expected_metrics: dict) -> None:
"""
Make sure the Base64OutputConsumer yields valid results.
@@ -73,7 +74,7 @@ def test_base64_output_consumer(expected_metrics: Dict) -> None:
parser = Base64OutputConsumer()
assert isinstance(parser, OutputConsumer)
- def create_base64_output(expected_metrics: Dict) -> bytearray:
+ def create_base64_output(expected_metrics: dict) -> bytearray:
json_str = json.dumps(expected_metrics, indent=4)
json_b64 = base64.b64encode(json_str.encode("utf-8"))
return (
diff --git a/tests/test_backend_system.py b/tests/test_backend_system.py
index 13347c6..7a8b1de 100644
--- a/tests/test_backend_system.py
+++ b/tests/test_backend_system.py
@@ -1,14 +1,12 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for system backend."""
+from __future__ import annotations
+
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
from typing import Callable
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
from unittest.mock import MagicMock
import pytest
@@ -27,12 +25,12 @@ from mlia.backend.system import System
def dummy_resolver(
- values: Optional[Dict[str, str]] = None
-) -> Callable[[str, str, List[Tuple[Optional[str], Param]]], str]:
+ values: dict[str, str] | None = None
+) -> Callable[[str, str, list[tuple[str | None, Param]]], str]:
"""Return dummy parameter resolver implementation."""
# pylint: disable=unused-argument
def resolver(
- param: str, cmd: str, param_values: List[Tuple[Optional[str], Param]]
+ param: str, cmd: str, param_values: list[tuple[str | None, Param]]
) -> str:
"""Implement dummy parameter resolver."""
return values.get(param, "") if values else ""
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py
index bf17339..eaa08e6 100644
--- a/tests/test_cli_commands.py
+++ b/tests/test_cli_commands.py
@@ -1,9 +1,10 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for cli.commands module."""
+from __future__ import annotations
+
from pathlib import Path
from typing import Any
-from typing import Optional
from unittest.mock import call
from unittest.mock import MagicMock
@@ -165,7 +166,7 @@ def test_backend_command_action_status(installation_manager_mock: MagicMock) ->
def test_backend_command_action_add_downoad(
installation_manager_mock: MagicMock,
i_agree_to_the_contained_eula: bool,
- backend_name: Optional[str],
+ backend_name: str | None,
expected_calls: Any,
) -> None:
"""Test backend command "install" with download option."""
@@ -183,7 +184,7 @@ def test_backend_command_action_add_downoad(
def test_backend_command_action_install_from_path(
installation_manager_mock: MagicMock,
tmp_path: Path,
- backend_name: Optional[str],
+ backend_name: str | None,
) -> None:
"""Test backend command "install" with backend path."""
backend(backend_action="install", path=tmp_path, name=backend_name)
diff --git a/tests/test_cli_config.py b/tests/test_cli_config.py
index 6d19eec..1a7cb3f 100644
--- a/tests/test_cli_config.py
+++ b/tests/test_cli_config.py
@@ -1,7 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for cli.config module."""
-from typing import List
+from __future__ import annotations
+
from unittest.mock import MagicMock
import pytest
@@ -30,8 +31,8 @@ from mlia.cli.config import is_corstone_backend
)
def test_get_default_backends(
monkeypatch: pytest.MonkeyPatch,
- available_backends: List[str],
- expected_default_backends: List[str],
+ available_backends: list[str],
+ expected_default_backends: list[str],
) -> None:
"""Test function get_default backends."""
monkeypatch.setattr(
diff --git a/tests/test_cli_helpers.py b/tests/test_cli_helpers.py
index 2c52885..c8aeebe 100644
--- a/tests/test_cli_helpers.py
+++ b/tests/test_cli_helpers.py
@@ -1,9 +1,9 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the helper classes."""
+from __future__ import annotations
+
from typing import Any
-from typing import Dict
-from typing import List
import pytest
@@ -67,9 +67,9 @@ class TestCliActionResolver:
],
)
def test_apply_optimizations(
- args: Dict[str, Any],
- params: Dict[str, Any],
- expected_result: List[str],
+ args: dict[str, Any],
+ params: dict[str, Any],
+ expected_result: list[str],
) -> None:
"""Test action resolving for applying optimizations."""
resolver = CLIActionResolver(args)
@@ -127,7 +127,7 @@ class TestCliActionResolver:
],
)
def test_check_performance(
- args: Dict[str, Any], expected_result: List[str]
+ args: dict[str, Any], expected_result: list[str]
) -> None:
"""Test check performance info."""
resolver = CLIActionResolver(args)
@@ -158,7 +158,7 @@ class TestCliActionResolver:
],
)
def test_check_operator_compatibility(
- args: Dict[str, Any], expected_result: List[str]
+ args: dict[str, Any], expected_result: list[str]
) -> None:
"""Test checking operator compatibility info."""
resolver = CLIActionResolver(args)
diff --git a/tests/test_cli_logging.py b/tests/test_cli_logging.py
index 5d26551..1e2cc85 100644
--- a/tests/test_cli_logging.py
+++ b/tests/test_cli_logging.py
@@ -1,9 +1,10 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the module cli.logging."""
+from __future__ import annotations
+
import logging
from pathlib import Path
-from typing import Optional
import pytest
@@ -78,7 +79,7 @@ def test_setup_logging(
def check_log_assertions(
- logs_dir_path: Optional[Path], expected_log_file_content: str
+ logs_dir_path: Path | None, expected_log_file_content: str
) -> None:
"""Test assertions for log file."""
if logs_dir_path is not None:
diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py
index 28abc7b..78adc53 100644
--- a/tests/test_cli_main.py
+++ b/tests/test_cli_main.py
@@ -1,12 +1,13 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for main module."""
+from __future__ import annotations
+
import argparse
from functools import wraps
from pathlib import Path
from typing import Any
from typing import Callable
-from typing import List
from unittest.mock import ANY
from unittest.mock import call
from unittest.mock import MagicMock
@@ -252,7 +253,7 @@ def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Non
],
)
def test_commands_execution(
- monkeypatch: pytest.MonkeyPatch, params: List[str], expected_call: Any
+ monkeypatch: pytest.MonkeyPatch, params: list[str], expected_call: Any
) -> None:
"""Test calling commands from the main function."""
mock = MagicMock()
@@ -320,7 +321,7 @@ def test_verbose_output(
capsys: pytest.CaptureFixture,
verbose: bool,
exc_mock: MagicMock,
- expected_output: List[str],
+ expected_output: list[str],
) -> None:
"""Test flag --verbose."""
diff --git a/tests/test_cli_options.py b/tests/test_cli_options.py
index a441e58..f898146 100644
--- a/tests/test_cli_options.py
+++ b/tests/test_cli_options.py
@@ -1,13 +1,12 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module options."""
+from __future__ import annotations
+
import argparse
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
-from typing import Dict
-from typing import List
-from typing import Optional
import pytest
@@ -137,7 +136,7 @@ def test_parse_optimization_parameters(
],
],
)
-def test_get_target_opts(args: Optional[Dict], expected_opts: List[str]) -> None:
+def test_get_target_opts(args: dict | None, expected_opts: list[str]) -> None:
"""Test getting target options."""
assert get_target_profile_opts(args) == expected_opts
@@ -153,7 +152,7 @@ def test_get_target_opts(args: Optional[Dict], expected_opts: List[str]) -> None
[["--output", "some_folder/report.csv"], "some_folder/report.csv"],
],
)
-def test_output_options(output_parameters: List[str], expected_path: str) -> None:
+def test_output_options(output_parameters: list[str], expected_path: str) -> None:
"""Test output options resolving."""
parser = argparse.ArgumentParser()
add_output_options(parser)
diff --git a/tests/test_core_advice_generation.py b/tests/test_core_advice_generation.py
index 05db698..f5e2960 100644
--- a/tests/test_core_advice_generation.py
+++ b/tests/test_core_advice_generation.py
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module advice_generation."""
-from typing import List
+from __future__ import annotations
import pytest
@@ -46,7 +46,7 @@ def test_advice_generation() -> None:
)
def test_advice_category_decorator(
category: AdviceCategory,
- expected_advice: List[Advice],
+ expected_advice: list[Advice],
dummy_context: Context,
) -> None:
"""Test for advice_category decorator."""
diff --git a/tests/test_core_reporting.py b/tests/test_core_reporting.py
index d7a6ade..da7998c 100644
--- a/tests/test_core_reporting.py
+++ b/tests/test_core_reporting.py
@@ -1,13 +1,10 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for reporting module."""
-from typing import List
-from typing import Optional
+from __future__ import annotations
import pytest
-from mlia.core._typing import OutputFormat
-from mlia.core._typing import PathOrFileLike
from mlia.core.reporting import BytesCell
from mlia.core.reporting import Cell
from mlia.core.reporting import ClockCell
@@ -19,6 +16,8 @@ from mlia.core.reporting import ReportItem
from mlia.core.reporting import resolve_output_format
from mlia.core.reporting import SingleRow
from mlia.core.reporting import Table
+from mlia.core.typing import OutputFormat
+from mlia.core.typing import PathOrFileLike
from mlia.utils.console import remove_ascii_codes
@@ -370,7 +369,7 @@ def test_nested_report_representation(
report: NestedReport,
expected_plain_text: str,
expected_json_data: dict,
- expected_csv_data: List,
+ expected_csv_data: list,
) -> None:
"""Test representation of the NestedReport."""
plain_text = report.to_plain_text()
@@ -429,7 +428,7 @@ Single row example:
],
)
def test_resolve_output_format(
- output: Optional[PathOrFileLike], expected_output_format: OutputFormat
+ output: PathOrFileLike | None, expected_output_format: OutputFormat
) -> None:
"""Test function resolve_output_format."""
assert resolve_output_format(output) == expected_output_format
diff --git a/tests/test_devices_ethosu_advice_generation.py b/tests/test_devices_ethosu_advice_generation.py
index 5d37376..5a49089 100644
--- a/tests/test_devices_ethosu_advice_generation.py
+++ b/tests/test_devices_ethosu_advice_generation.py
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for Ethos-U advice generation."""
-from typing import List
+from __future__ import annotations
import pytest
@@ -363,7 +363,7 @@ from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
def test_ethosu_advice_producer(
tmpdir: str,
input_data: DataItem,
- expected_advice: List[Advice],
+ expected_advice: list[Advice],
advice_category: AdviceCategory,
action_resolver: ActionResolver,
) -> None:
@@ -468,7 +468,7 @@ def test_ethosu_static_advice_producer(
tmpdir: str,
advice_category: AdviceCategory,
action_resolver: ActionResolver,
- expected_advice: List[Advice],
+ expected_advice: list[Advice],
) -> None:
"""Test static advice generation."""
producer = EthosUStaticAdviceProducer()
diff --git a/tests/test_devices_ethosu_config.py b/tests/test_devices_ethosu_config.py
index 49c999a..d4e043f 100644
--- a/tests/test_devices_ethosu_config.py
+++ b/tests/test_devices_ethosu_config.py
@@ -1,9 +1,10 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for config module."""
+from __future__ import annotations
+
from contextlib import ExitStack as does_not_raise
from typing import Any
-from typing import Dict
from unittest.mock import MagicMock
import pytest
@@ -113,7 +114,7 @@ def test_get_target() -> None:
],
)
def test_ethosu_configuration(
- monkeypatch: pytest.MonkeyPatch, profile_data: Dict[str, Any], expected_error: Any
+ monkeypatch: pytest.MonkeyPatch, profile_data: dict[str, Any], expected_error: Any
) -> None:
"""Test creating Ethos-U configuration."""
monkeypatch.setattr(
diff --git a/tests/test_devices_ethosu_data_analysis.py b/tests/test_devices_ethosu_data_analysis.py
index 4b1d38b..26aae76 100644
--- a/tests/test_devices_ethosu_data_analysis.py
+++ b/tests/test_devices_ethosu_data_analysis.py
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for Ethos-U data analysis module."""
-from typing import List
+from __future__ import annotations
import pytest
@@ -139,7 +139,7 @@ def test_perf_metrics_diff() -> None:
],
)
def test_ethos_u_data_analyzer(
- input_data: DataItem, expected_facts: List[Fact]
+ input_data: DataItem, expected_facts: list[Fact]
) -> None:
"""Test Ethos-U data analyzer."""
analyzer = EthosUDataAnalyzer()
diff --git a/tests/test_devices_ethosu_reporters.py b/tests/test_devices_ethosu_reporters.py
index a63db1c..f8a7d86 100644
--- a/tests/test_devices_ethosu_reporters.py
+++ b/tests/test_devices_ethosu_reporters.py
@@ -1,14 +1,14 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for reports module."""
+from __future__ import annotations
+
import json
import sys
from contextlib import ExitStack as doesnt_raise
from pathlib import Path
from typing import Any
from typing import Callable
-from typing import Dict
-from typing import List
from typing import Literal
import pytest
@@ -91,7 +91,7 @@ from mlia.utils.console import remove_ascii_codes
)
def test_report(
data: Any,
- formatters: List[Callable],
+ formatters: list[Callable],
fmt: Literal["plain_text", "json", "csv"],
output: Any,
expected_error: Any,
@@ -202,10 +202,10 @@ Operators:
],
)
def test_report_operators(
- ops: List[Operator],
+ ops: list[Operator],
expected_plain_text: str,
- expected_json_dict: Dict,
- expected_csv_list: List,
+ expected_json_dict: dict,
+ expected_csv_list: list,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Test report_operatos formatter."""
@@ -380,8 +380,8 @@ def test_report_operators(
def test_report_device_details(
device: EthosUConfiguration,
expected_plain_text: str,
- expected_json_dict: Dict,
- expected_csv_list: List,
+ expected_json_dict: dict,
+ expected_csv_list: list,
) -> None:
"""Test report_operatos formatter."""
report = report_device_details(device)
diff --git a/tests/test_devices_tosa_advice_generation.py b/tests/test_devices_tosa_advice_generation.py
index 018ba57..1b97c8b 100644
--- a/tests/test_devices_tosa_advice_generation.py
+++ b/tests/test_devices_tosa_advice_generation.py
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for advice generation."""
-from typing import List
+from __future__ import annotations
import pytest
@@ -40,7 +40,7 @@ def test_tosa_advice_producer(
tmpdir: str,
input_data: DataItem,
advice_category: AdviceCategory,
- expected_advice: List[Advice],
+ expected_advice: list[Advice],
) -> None:
"""Test TOSA advice producer."""
producer = TOSAAdviceProducer()
diff --git a/tests/test_devices_tosa_data_analysis.py b/tests/test_devices_tosa_data_analysis.py
index 60bcee8..ff95978 100644
--- a/tests/test_devices_tosa_data_analysis.py
+++ b/tests/test_devices_tosa_data_analysis.py
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for TOSA data analysis module."""
-from typing import List
+from __future__ import annotations
import pytest
@@ -26,7 +26,7 @@ from mlia.devices.tosa.operators import TOSACompatibilityInfo
],
],
)
-def test_tosa_data_analyzer(input_data: DataItem, expected_facts: List[Fact]) -> None:
+def test_tosa_data_analyzer(input_data: DataItem, expected_facts: list[Fact]) -> None:
"""Test TOSA data analyzer."""
analyzer = TOSADataAnalyzer()
analyzer.analyze_data(input_data)
diff --git a/tests/test_devices_tosa_operators.py b/tests/test_devices_tosa_operators.py
index b7736d2..d4372aa 100644
--- a/tests/test_devices_tosa_operators.py
+++ b/tests/test_devices_tosa_operators.py
@@ -1,10 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for TOSA compatibility."""
+from __future__ import annotations
+
from pathlib import Path
from types import SimpleNamespace
from typing import Any
-from typing import Optional
from unittest.mock import MagicMock
import pytest
@@ -15,7 +16,7 @@ from mlia.devices.tosa.operators import TOSACompatibilityInfo
def replace_get_tosa_checker_with_mock(
- monkeypatch: pytest.MonkeyPatch, mock: Optional[MagicMock]
+ monkeypatch: pytest.MonkeyPatch, mock: MagicMock | None
) -> None:
"""Replace TOSA checker with mock."""
monkeypatch.setattr(
diff --git a/tests/test_nn_tensorflow_optimizations_clustering.py b/tests/test_nn_tensorflow_optimizations_clustering.py
index c12a1e8..13dfb31 100644
--- a/tests/test_nn_tensorflow_optimizations_clustering.py
+++ b/tests/test_nn_tensorflow_optimizations_clustering.py
@@ -1,9 +1,9 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Test for module optimizations/clustering."""
+from __future__ import annotations
+
from pathlib import Path
-from typing import List
-from typing import Optional
import pytest
import tensorflow as tf
@@ -21,7 +21,7 @@ from tests.utils.common import train_model
def _prune_model(
- model: tf.keras.Model, target_sparsity: float, layers_to_prune: Optional[List[str]]
+ model: tf.keras.Model, target_sparsity: float, layers_to_prune: list[str] | None
) -> tf.keras.Model:
x_train, y_train = get_dataset()
batch_size = 1
@@ -47,7 +47,7 @@ def _prune_model(
def _test_num_unique_weights(
metrics: TFLiteMetrics,
target_num_clusters: int,
- layers_to_cluster: Optional[List[str]],
+ layers_to_cluster: list[str] | None,
) -> None:
clustered_uniqueness_dict = metrics.num_unique_weights(
ReportClusterMode.NUM_CLUSTERS_PER_AXIS
@@ -71,7 +71,7 @@ def _test_num_unique_weights(
def _test_sparsity(
metrics: TFLiteMetrics,
target_sparsity: float,
- layers_to_cluster: Optional[List[str]],
+ layers_to_cluster: list[str] | None,
) -> None:
pruned_sparsity_dict = metrics.sparsity_per_layer()
num_sparse_layers = 0
@@ -95,7 +95,7 @@ def _test_sparsity(
def test_cluster_simple_model_fully(
target_num_clusters: int,
sparsity_aware: bool,
- layers_to_cluster: Optional[List[str]],
+ layers_to_cluster: list[str] | None,
tmp_path: Path,
test_keras_model: Path,
) -> None:
diff --git a/tests/test_nn_tensorflow_optimizations_pruning.py b/tests/test_nn_tensorflow_optimizations_pruning.py
index 5d92f5e..d97b3d3 100644
--- a/tests/test_nn_tensorflow_optimizations_pruning.py
+++ b/tests/test_nn_tensorflow_optimizations_pruning.py
@@ -1,9 +1,9 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Test for module optimizations/pruning."""
+from __future__ import annotations
+
from pathlib import Path
-from typing import List
-from typing import Optional
import pytest
import tensorflow as tf
@@ -21,7 +21,7 @@ from tests.utils.common import train_model
def _test_sparsity(
metrics: TFLiteMetrics,
target_sparsity: float,
- layers_to_prune: Optional[List[str]],
+ layers_to_prune: list[str] | None,
) -> None:
pruned_sparsity_dict = metrics.sparsity_per_layer()
num_sparse_layers = 0
@@ -62,7 +62,7 @@ def _get_tflite_metrics(
def test_prune_simple_model_fully(
target_sparsity: float,
mock_data: bool,
- layers_to_prune: Optional[List[str]],
+ layers_to_prune: list[str] | None,
tmp_path: Path,
test_keras_model: Path,
) -> None:
diff --git a/tests/test_nn_tensorflow_optimizations_select.py b/tests/test_nn_tensorflow_optimizations_select.py
index 5cac8ba..e22a9d8 100644
--- a/tests/test_nn_tensorflow_optimizations_select.py
+++ b/tests/test_nn_tensorflow_optimizations_select.py
@@ -1,11 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module select."""
+from __future__ import annotations
+
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
-from typing import List
-from typing import Tuple
import pytest
import tensorflow as tf
@@ -187,7 +187,7 @@ def test_get_optimizer(
],
)
def test_optimization_settings_create_from(
- params: List[Tuple[str, float]], expected_result: List[OptimizationSettings]
+ params: list[tuple[str, float]], expected_result: list[OptimizationSettings]
) -> None:
"""Test creating settings from parsed params."""
assert OptimizationSettings.create_from(params) == expected_result
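
These select tests import `ExitStack as does_not_raise`, a common pytest idiom that lets error cases and happy-path cases share one parametrized table: each case supplies the context manager the test body should run under. A small sketch with assumed example values, not the actual MLIA parameters:

    from contextlib import ExitStack as does_not_raise
    from typing import Any

    import pytest


    @pytest.mark.parametrize(
        "value, expectation",
        [
            ("2", does_not_raise()),                      # valid input, no exception expected
            ("not a number", pytest.raises(ValueError)),  # invalid input
        ],
    )
    def test_int_conversion(value: str, expectation: Any) -> None:
        with expectation:
            assert int(value) == 2
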
diff --git a/tests/test_nn_tensorflow_tflite_metrics.py b/tests/test_nn_tensorflow_tflite_metrics.py
index 00eacef..a5e7736 100644
--- a/tests/test_nn_tensorflow_tflite_metrics.py
+++ b/tests/test_nn_tensorflow_tflite_metrics.py
@@ -1,12 +1,13 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Test for module utils/tflite_metrics."""
+from __future__ import annotations
+
import os
import tempfile
from math import isclose
from pathlib import Path
from typing import Generator
-from typing import List
import numpy as np
import pytest
@@ -31,7 +32,7 @@ def _dummy_keras_model() -> tf.keras.Model:
def _sparse_binary_keras_model() -> tf.keras.Model:
- def get_sparse_weights(shape: List[int]) -> np.ndarray:
+ def get_sparse_weights(shape: list[int]) -> np.ndarray:
weights = np.zeros(shape)
with np.nditer(weights, op_flags=["writeonly"]) as weight_iterator:
for idx, value in enumerate(weight_iterator):
diff --git a/tests/test_tools_metadata_common.py b/tests/test_tools_metadata_common.py
index 7663b83..69bc3e5 100644
--- a/tests/test_tools_metadata_common.py
+++ b/tests/test_tools_metadata_common.py
@@ -1,10 +1,10 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for commmon installation related functions."""
+from __future__ import annotations
+
from pathlib import Path
from typing import Any
-from typing import List
-from typing import Optional
from unittest.mock import call
from unittest.mock import MagicMock
from unittest.mock import PropertyMock
@@ -22,7 +22,7 @@ def get_installation_mock(
name: str,
already_installed: bool = False,
could_be_installed: bool = False,
- supported_install_type: Optional[type] = None,
+ supported_install_type: type | None = None,
) -> MagicMock:
"""Get mock instance for the installation."""
mock = MagicMock(spec=Installation)
@@ -81,7 +81,7 @@ def _could_be_installed_from_mock() -> MagicMock:
def get_installation_manager(
noninteractive: bool,
- installations: List[Any],
+ installations: list[Any],
monkeypatch: pytest.MonkeyPatch,
yes_response: bool = True,
) -> DefaultInstallationManager:
@@ -146,7 +146,7 @@ def test_installation_manager_download_and_install(
install_mock: MagicMock,
noninteractive: bool,
eula_agreement: bool,
- backend_name: Optional[str],
+ backend_name: str | None,
expected_call: Any,
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -183,7 +183,7 @@ def test_installation_manager_download_and_install(
def test_installation_manager_install_from(
install_mock: MagicMock,
noninteractive: bool,
- backend_name: Optional[str],
+ backend_name: str | None,
expected_call: Any,
monkeypatch: pytest.MonkeyPatch,
) -> None:
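
`get_installation_mock` builds its fakes with `MagicMock(spec=Installation)`, so any attribute the production interface does not expose fails with `AttributeError` instead of silently returning another mock. An illustrative sketch of that spec pattern against a stand-in class (not the real `Installation` protocol):

    from unittest.mock import MagicMock, PropertyMock


    class FakeInstallation:
        """Stand-in interface with the attribute this example cares about."""

        @property
        def already_installed(self) -> bool:
            return False


    mock = MagicMock(spec=FakeInstallation)
    # Properties are mocked on the type, as recommended by the unittest.mock docs.
    type(mock).already_installed = PropertyMock(return_value=True)

    assert mock.already_installed is True
    # mock.not_in_spec would raise AttributeError thanks to the spec.
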
diff --git a/tests/test_tools_metadata_corstone.py b/tests/test_tools_metadata_corstone.py
index 017d0c7..e2b2ae5 100644
--- a/tests/test_tools_metadata_corstone.py
+++ b/tests/test_tools_metadata_corstone.py
@@ -1,10 +1,10 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for Corstone related installation functions.."""
+from __future__ import annotations
+
import tarfile
from pathlib import Path
-from typing import List
-from typing import Optional
from unittest.mock import MagicMock
import pytest
@@ -44,12 +44,12 @@ def get_backend_installation( # pylint: disable=too-many-arguments
backend_runner_mock: MagicMock = MagicMock(),
name: str = "test_name",
description: str = "test_description",
- download_artifact: Optional[MagicMock] = None,
+ download_artifact: MagicMock | None = None,
path_checker: PathChecker = MagicMock(),
- apps_resources: Optional[List[str]] = None,
- system_config: Optional[str] = None,
+ apps_resources: list[str] | None = None,
+ system_config: str | None = None,
backend_installer: BackendInstaller = MagicMock(),
- supported_platforms: Optional[List[str]] = None,
+ supported_platforms: list[str] | None = None,
) -> BackendInstallation:
"""Get backend installation."""
return BackendInstallation(
@@ -79,7 +79,7 @@ def get_backend_installation( # pylint: disable=too-many-arguments
)
def test_could_be_installed_depends_on_platform(
platform: str,
- supported_platforms: Optional[List[str]],
+ supported_platforms: list[str] | None,
expected_result: bool,
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -309,7 +309,7 @@ def test_backend_installation_download_and_install(
],
)
def test_corstone_path_checker_valid_path(
- tmp_path: Path, dir_content: List[str], expected_result: Optional[str]
+ tmp_path: Path, dir_content: list[str], expected_result: str | None
) -> None:
"""Test Corstone path checker valid scenario."""
path_checker = PackagePathChecker(["file1.txt", "file2.txt"], "models")
@@ -333,7 +333,7 @@ def test_corstone_path_checker_valid_path(
@pytest.mark.parametrize("system_config", [None, "system_config"])
@pytest.mark.parametrize("copy_source", [True, False])
def test_static_path_checker(
- tmp_path: Path, copy_source: bool, system_config: Optional[str]
+ tmp_path: Path, copy_source: bool, system_config: str | None
) -> None:
"""Test static path checker."""
static_checker = StaticPathChecker(
@@ -404,7 +404,7 @@ def test_corstone_300_installer(
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
eula_agreement: bool,
- expected_command: List[str],
+ expected_command: list[str],
) -> None:
"""Test Corstone-300 installer."""
command_mock = MagicMock()
diff --git a/tests/test_utils_console.py b/tests/test_utils_console.py
index 36975f8..5b01403 100644
--- a/tests/test_utils_console.py
+++ b/tests/test_utils_console.py
@@ -1,9 +1,9 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for console utility functions."""
+from __future__ import annotations
+
from typing import Iterable
-from typing import List
-from typing import Optional
import pytest
@@ -44,7 +44,7 @@ from mlia.utils.console import remove_ascii_codes
],
)
def test_produce_table(
- rows: Iterable, headers: Optional[List[str]], table_style: str, expected_result: str
+ rows: Iterable, headers: list[str] | None, table_style: str, expected_result: str
) -> None:
"""Test produce_table function."""
result = produce_table(rows, headers, table_style)
diff --git a/tests/test_utils_download.py b/tests/test_utils_download.py
index 4f8e2dc..28af74f 100644
--- a/tests/test_utils_download.py
+++ b/tests/test_utils_download.py
@@ -1,11 +1,12 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for download functionality."""
+from __future__ import annotations
+
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
from typing import Iterable
-from typing import Optional
from unittest.mock import MagicMock
from unittest.mock import PropertyMock
@@ -17,7 +18,7 @@ from mlia.utils.download import DownloadArtifact
def response_mock(
- content_length: Optional[str], content_chunks: Iterable[bytes]
+ content_length: str | None, content_chunks: Iterable[bytes]
) -> MagicMock:
"""Mock response object."""
mock = MagicMock(spec=requests.Response)
@@ -59,9 +60,9 @@ def test_download(
show_progress: bool,
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
- content_length: Optional[str],
+ content_length: str | None,
content_chunks: Iterable[bytes],
- label: Optional[str],
+ label: str | None,
) -> None:
"""Test function download."""
monkeypatch.setattr(
@@ -97,7 +98,7 @@ def test_download(
)
def test_download_artifact_download_to(
monkeypatch: pytest.MonkeyPatch,
- content_length: Optional[str],
+ content_length: str | None,
content_chunks: Iterable[bytes],
sha256_hash: str,
expected_error: Any,
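
`response_mock` above specs its fake on `requests.Response`; the wiring below is an illustrative approximation of such a streamed-response mock, not the exact MLIA fixture:

    from unittest.mock import MagicMock

    import requests

    response = MagicMock(spec=requests.Response)
    response.headers = {"Content-Length": "6"}
    response.iter_content.return_value = [b"abc", b"def"]

    body = b"".join(response.iter_content(chunk_size=3))
    assert body == b"abcdef"
    assert int(response.headers["Content-Length"]) == len(body)
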
diff --git a/tests/test_utils_logging.py b/tests/test_utils_logging.py
index 75ebceb..1e212b2 100644
--- a/tests/test_utils_logging.py
+++ b/tests/test_utils_logging.py
@@ -1,12 +1,13 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the logging utility functions."""
+from __future__ import annotations
+
import logging
import sys
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
-from typing import Optional
import pytest
@@ -43,9 +44,9 @@ from mlia.cli.logging import create_log_handler
],
)
def test_create_log_handler(
- file_path: Optional[Path],
- stream: Optional[Any],
- log_filter: Optional[logging.Filter],
+ file_path: Path | None,
+ stream: Any | None,
+ log_filter: logging.Filter | None,
delay: bool,
expected_error: Any,
expected_class: type,
diff --git a/tests/test_utils_types.py b/tests/test_utils_types.py
index 4909efe..f7e0de8 100644
--- a/tests/test_utils_types.py
+++ b/tests/test_utils_types.py
@@ -1,9 +1,10 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the types related utility functions."""
+from __future__ import annotations
+
from typing import Any
from typing import Iterable
-from typing import Optional
import pytest
@@ -42,7 +43,7 @@ def test_is_number(value: str, expected_result: bool) -> None:
],
)
def test_is_list(
- data: Any, cls: type, elem_num: Optional[int], expected_result: bool
+ data: Any, cls: type, elem_num: int | None, expected_result: bool
) -> None:
"""Test function is_list."""
assert is_list_of(data, cls, elem_num) == expected_result
@@ -70,8 +71,6 @@ def test_only_one_selected(options: Iterable[bool], expected_result: bool) -> No
[None, 11, 11],
],
)
-def test_parse_int(
- value: Any, default: Optional[int], expected_int: Optional[int]
-) -> None:
+def test_parse_int(value: Any, default: int | None, expected_int: int | None) -> None:
"""Test function parse_int."""
assert parse_int(value, default) == expected_int
diff --git a/tests/utils/common.py b/tests/utils/common.py
index 932343e..616a407 100644
--- a/tests/utils/common.py
+++ b/tests/utils/common.py
@@ -1,13 +1,13 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Common test utils module."""
-from typing import Tuple
+from __future__ import annotations
import numpy as np
import tensorflow as tf
-def get_dataset() -> Tuple[np.ndarray, np.ndarray]:
+def get_dataset() -> tuple[np.ndarray, np.ndarray]:
"""Return sample dataset."""
mnist = tf.keras.datasets.mnist
(x_train, y_train), _ = mnist.load_data()