From f5b293d0927506c2a979a091bf0d07ecc78fa181 Mon Sep 17 00:00:00 2001 From: Dmitrii Agibov Date: Thu, 8 Sep 2022 14:24:39 +0100 Subject: MLIA-386 Simplify typing in the source code - Enable deferred annotations evaluation - Use builtin types for type hints whenever possible - Use | syntax for union types - Rename mlia.core._typing into mlia.core.typing Change-Id: I3f6ffc02fa069c589bdd9e8bddbccd504285427a --- src/mlia/api.py | 24 +++--- src/mlia/backend/application.py | 18 ++-- src/mlia/backend/common.py | 64 +++++++------- src/mlia/backend/config.py | 19 +++-- src/mlia/backend/execution.py | 41 +++++---- src/mlia/backend/fs.py | 7 +- src/mlia/backend/manager.py | 41 +++++---- src/mlia/backend/output_consumer.py | 5 +- src/mlia/backend/proc.py | 17 ++-- src/mlia/backend/source.py | 22 ++--- src/mlia/backend/system.py | 12 +-- src/mlia/cli/commands.py | 26 +++--- src/mlia/cli/common.py | 9 +- src/mlia/cli/config.py | 7 +- src/mlia/cli/helpers.py | 28 +++--- src/mlia/cli/logging.py | 17 ++-- src/mlia/cli/main.py | 18 ++-- src/mlia/cli/options.py | 15 ++-- src/mlia/core/_typing.py | 12 --- src/mlia/core/advice_generation.py | 14 +-- src/mlia/core/advisor.py | 11 +-- src/mlia/core/common.py | 4 +- src/mlia/core/context.py | 31 ++++--- src/mlia/core/data_analysis.py | 9 +- src/mlia/core/events.py | 24 +++--- src/mlia/core/handlers.py | 10 +-- src/mlia/core/helpers.py | 15 ++-- src/mlia/core/mixins.py | 7 +- src/mlia/core/performance.py | 19 +++-- src/mlia/core/reporting.py | 99 +++++++++++----------- src/mlia/core/typing.py | 12 +++ src/mlia/core/workflow.py | 15 ++-- src/mlia/devices/ethosu/advice_generation.py | 10 +-- src/mlia/devices/ethosu/advisor.py | 32 ++++--- src/mlia/devices/ethosu/config.py | 7 +- src/mlia/devices/ethosu/data_analysis.py | 21 +++-- src/mlia/devices/ethosu/data_collection.py | 18 ++-- src/mlia/devices/ethosu/handlers.py | 7 +- src/mlia/devices/ethosu/performance.py | 35 ++++---- src/mlia/devices/ethosu/reporters.py | 19 ++--- src/mlia/devices/tosa/advisor.py | 30 +++---- src/mlia/devices/tosa/handlers.py | 7 +- src/mlia/devices/tosa/operators.py | 12 +-- src/mlia/devices/tosa/reporters.py | 7 +- src/mlia/nn/tensorflow/config.py | 36 ++++---- src/mlia/nn/tensorflow/optimizations/clustering.py | 9 +- src/mlia/nn/tensorflow/optimizations/pruning.py | 13 ++- src/mlia/nn/tensorflow/optimizations/select.py | 34 ++++---- src/mlia/nn/tensorflow/tflite_metrics.py | 20 ++--- src/mlia/nn/tensorflow/utils.py | 15 ++-- src/mlia/tools/metadata/common.py | 40 +++++---- src/mlia/tools/metadata/corstone.py | 27 +++--- src/mlia/tools/vela_wrapper.py | 33 ++++---- src/mlia/utils/console.py | 8 +- src/mlia/utils/download.py | 10 +-- src/mlia/utils/filesystem.py | 20 ++--- src/mlia/utils/logging.py | 22 ++--- src/mlia/utils/types.py | 7 +- 58 files changed, 569 insertions(+), 602 deletions(-) delete mode 100644 src/mlia/core/_typing.py create mode 100644 src/mlia/core/typing.py (limited to 'src') diff --git a/src/mlia/api.py b/src/mlia/api.py index c720b8d..878e316 100644 --- a/src/mlia/api.py +++ b/src/mlia/api.py @@ -1,19 +1,17 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for the API functions.""" +from __future__ import annotations + import logging from pathlib import Path from typing import Any -from typing import Dict -from typing import List from typing import Literal -from typing import Optional -from typing import Union -from mlia.core._typing import PathOrFileLike from mlia.core.advisor import InferenceAdvisor from mlia.core.common import AdviceCategory from mlia.core.context import ExecutionContext +from mlia.core.typing import PathOrFileLike from mlia.devices.ethosu.advisor import configure_and_get_ethosu_advisor from mlia.devices.tosa.advisor import configure_and_get_tosa_advisor from mlia.utils.filesystem import get_target @@ -24,13 +22,13 @@ logger = logging.getLogger(__name__) def get_advice( target_profile: str, - model: Union[Path, str], + model: str | Path, category: Literal["all", "operators", "performance", "optimization"] = "all", - optimization_targets: Optional[List[Dict[str, Any]]] = None, - working_dir: Union[str, Path] = "mlia_output", - output: Optional[PathOrFileLike] = None, - context: Optional[ExecutionContext] = None, - backends: Optional[List[str]] = None, + optimization_targets: list[dict[str, Any]] | None = None, + working_dir: str | Path = "mlia_output", + output: PathOrFileLike | None = None, + context: ExecutionContext | None = None, + backends: list[str] | None = None, ) -> None: """Get the advice. @@ -97,8 +95,8 @@ def get_advice( def get_advisor( context: ExecutionContext, target_profile: str, - model: Union[Path, str], - output: Optional[PathOrFileLike] = None, + model: str | Path, + output: PathOrFileLike | None = None, **extra_args: Any, ) -> InferenceAdvisor: """Find appropriate advisor for the target.""" diff --git a/src/mlia/backend/application.py b/src/mlia/backend/application.py index 4b04324..a093afe 100644 --- a/src/mlia/backend/application.py +++ b/src/mlia/backend/application.py @@ -1,13 +1,13 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Application backend module.""" +from __future__ import annotations + import re from pathlib import Path from typing import Any from typing import cast -from typing import Dict from typing import List -from typing import Optional from mlia.backend.common import Backend from mlia.backend.common import ConfigurationException @@ -23,12 +23,12 @@ from mlia.backend.source import create_destination_and_install from mlia.backend.source import get_source -def get_available_application_directory_names() -> List[str]: +def get_available_application_directory_names() -> list[str]: """Return a list of directory names for all available applications.""" return [entry.name for entry in get_backend_directories("applications")] -def get_available_applications() -> List["Application"]: +def get_available_applications() -> list[Application]: """Return a list with all available applications.""" available_applications = [] for config_json in get_backend_configs("applications"): @@ -42,8 +42,8 @@ def get_available_applications() -> List["Application"]: def get_application( - application_name: str, system_name: Optional[str] = None -) -> List["Application"]: + application_name: str, system_name: str | None = None +) -> list[Application]: """Return a list of application instances with provided name.""" return [ application @@ -85,7 +85,7 @@ def remove_application(directory_name: str) -> None: remove_backend(directory_name, "applications") -def get_unique_application_names(system_name: Optional[str] = None) -> List[str]: +def get_unique_application_names(system_name: str | None = None) -> list[str]: """Extract a list of unique application names of all application available.""" return list( set( @@ -120,7 +120,7 @@ class Application(Backend): """Check if the application can run on the system passed as argument.""" return system_name in self.supported_systems - def get_details(self) -> Dict[str, Any]: + def get_details(self) -> dict[str, Any]: """Return dictionary with information about the Application instance.""" output = { "type": "application", @@ -156,7 +156,7 @@ class Application(Backend): command.params = used_params -def load_applications(config: ExtendedApplicationConfig) -> List[Application]: +def load_applications(config: ExtendedApplicationConfig) -> list[Application]: """Load application. Application configuration could contain different parameters/commands for different diff --git a/src/mlia/backend/common.py b/src/mlia/backend/common.py index e61d6b6..697c2a0 100644 --- a/src/mlia/backend/common.py +++ b/src/mlia/backend/common.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Contain all common functions for the backends.""" +from __future__ import annotations + import json import logging import re @@ -10,18 +12,12 @@ from pathlib import Path from typing import Any from typing import Callable from typing import cast -from typing import Dict from typing import Final from typing import IO from typing import Iterable -from typing import List from typing import Match from typing import NamedTuple -from typing import Optional from typing import Pattern -from typing import Tuple -from typing import Type -from typing import Union from mlia.backend.config import BackendConfig from mlia.backend.config import BaseBackendConfig @@ -74,7 +70,7 @@ def remove_backend(directory_name: str, resource_type: ResourceType) -> None: remove_resource(directory_name, resource_type) -def load_config(config: Union[None, Path, IO[bytes]]) -> BackendConfig: +def load_config(config: Path | IO[bytes] | None) -> BackendConfig: """Return a loaded json file.""" if config is None: raise Exception("Unable to read config") @@ -86,7 +82,7 @@ def load_config(config: Union[None, Path, IO[bytes]]) -> BackendConfig: return cast(BackendConfig, json.load(config)) -def parse_raw_parameter(parameter: str) -> Tuple[str, Optional[str]]: +def parse_raw_parameter(parameter: str) -> tuple[str, str | None]: """Split the parameter string in name and optional value. It manages the following cases: @@ -176,7 +172,7 @@ class Backend(ABC): def _parse_commands_and_params(self, config: BaseBackendConfig) -> None: """Parse commands and user parameters.""" - self.commands: Dict[str, Command] = {} + self.commands: dict[str, Command] = {} commands = config.get("commands") if commands: @@ -213,15 +209,15 @@ class Backend(ABC): @classmethod def _parse_params( - cls, params: Optional[UserParamsConfig], command: str - ) -> List["Param"]: + cls, params: UserParamsConfig | None, command: str + ) -> list[Param]: if not params: return [] return [cls._parse_param(p) for p in params.get(command, [])] @classmethod - def _parse_param(cls, param: UserParamConfig) -> "Param": + def _parse_param(cls, param: UserParamConfig) -> Param: """Parse a single parameter.""" name = param.get("name") if name is not None and not name: @@ -239,16 +235,14 @@ class Backend(ABC): alias=alias, ) - def _get_command_details(self) -> Dict: + def _get_command_details(self) -> dict: command_details = { command_name: command.get_details() for command_name, command in self.commands.items() } return command_details - def _get_user_param_value( - self, user_params: List[str], param: "Param" - ) -> Optional[str]: + def _get_user_param_value(self, user_params: list[str], param: Param) -> str | None: """Get the user-specified value of a parameter.""" for user_param in user_params: user_param_name, user_param_value = parse_raw_parameter(user_param) @@ -267,7 +261,7 @@ class Backend(ABC): return None @staticmethod - def _same_parameter(user_param_name_or_alias: str, param: "Param") -> bool: + def _same_parameter(user_param_name_or_alias: str, param: Param) -> bool: """Compare user parameter name with param name or alias.""" # Strip the "=" sign in the param_name. This is needed just for # comparison with the parameters passed by the user. @@ -277,10 +271,10 @@ class Backend(ABC): return user_param_name_or_alias in [param_name, param.alias] def resolved_parameters( - self, command_name: str, user_params: List[str] - ) -> List[Tuple[Optional[str], "Param"]]: + self, command_name: str, user_params: list[str] + ) -> list[tuple[str | None, Param]]: """Return list of parameters with values.""" - result: List[Tuple[Optional[str], "Param"]] = [] + result: list[tuple[str | None, Param]] = [] command = self.commands.get(command_name) if not command: return result @@ -296,9 +290,9 @@ class Backend(ABC): def build_command( self, command_name: str, - user_params: List[str], - param_resolver: Callable[[str, str, List[Tuple[Optional[str], "Param"]]], str], - ) -> List[str]: + user_params: list[str], + param_resolver: Callable[[str, str, list[tuple[str | None, Param]]], str], + ) -> list[str]: """ Return a list of executable command strings. @@ -328,11 +322,11 @@ class Param: def __init__( # pylint: disable=too-many-arguments self, - name: Optional[str], + name: str | None, description: str, - values: Optional[List[str]] = None, - default_value: Optional[str] = None, - alias: Optional[str] = None, + values: list[str] | None = None, + default_value: str | None = None, + alias: str | None = None, ) -> None: """Construct a Param instance.""" if not name and not alias: @@ -345,7 +339,7 @@ class Param: self.default_value = default_value self.alias = alias - def get_details(self) -> Dict: + def get_details(self) -> dict: """Return a dictionary with all relevant information of a Param.""" return {key: value for key, value in self.__dict__.items() if value} @@ -366,7 +360,7 @@ class Command: """Class for representing a command.""" def __init__( - self, command_strings: List[str], params: Optional[List[Param]] = None + self, command_strings: list[str], params: list[Param] | None = None ) -> None: """Construct a Command instance.""" self.command_strings = command_strings @@ -404,7 +398,7 @@ class Command: "as parameter name." ) - def get_details(self) -> Dict: + def get_details(self) -> dict: """Return a dictionary with all relevant information of a Command.""" output = { "command_strings": self.command_strings, @@ -425,9 +419,9 @@ class Command: def resolve_all_parameters( str_val: str, - param_resolver: Callable[[str, str, List[Tuple[Optional[str], Param]]], str], - command_name: Optional[str] = None, - params_values: Optional[List[Tuple[Optional[str], Param]]] = None, + param_resolver: Callable[[str, str, list[tuple[str | None, Param]]], str], + command_name: str | None = None, + params_values: list[tuple[str | None, Param]] | None = None, ) -> str: """Resolve all parameters in the string.""" if not str_val: @@ -446,7 +440,7 @@ def resolve_all_parameters( def load_application_configs( config: Any, - config_type: Type[Any], + config_type: type[Any], is_system_required: bool = True, ) -> Any: """Get one config for each system supported by the application. @@ -456,7 +450,7 @@ def load_application_configs( config with appropriate configuration. """ merged_configs = [] - supported_systems: Optional[List[NamedExecutionConfig]] = config.get( + supported_systems: list[NamedExecutionConfig] | None = config.get( "supported_systems" ) if not supported_systems: diff --git a/src/mlia/backend/config.py b/src/mlia/backend/config.py index 9a56fa9..dca53da 100644 --- a/src/mlia/backend/config.py +++ b/src/mlia/backend/config.py @@ -1,10 +1,11 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Contain definition of backend configuration.""" +from __future__ import annotations + from pathlib import Path from typing import Dict from typing import List -from typing import Optional from typing import TypedDict from typing import Union @@ -12,9 +13,9 @@ from typing import Union class UserParamConfig(TypedDict, total=False): """User parameter configuration.""" - name: Optional[str] + name: str | None default_value: str - values: List[str] + values: list[str] description: str alias: str @@ -25,9 +26,9 @@ UserParamsConfig = Dict[str, List[UserParamConfig]] class ExecutionConfig(TypedDict, total=False): """Execution configuration.""" - commands: Dict[str, List[str]] + commands: dict[str, list[str]] user_params: UserParamsConfig - variables: Dict[str, str] + variables: dict[str, str] class NamedExecutionConfig(ExecutionConfig): @@ -42,25 +43,25 @@ class BaseBackendConfig(ExecutionConfig, total=False): name: str description: str config_location: Path - annotations: Dict[str, Union[str, List[str]]] + annotations: dict[str, str | list[str]] class ApplicationConfig(BaseBackendConfig, total=False): """Application configuration.""" - supported_systems: List[str] + supported_systems: list[str] class ExtendedApplicationConfig(BaseBackendConfig, total=False): """Extended application configuration.""" - supported_systems: List[NamedExecutionConfig] + supported_systems: list[NamedExecutionConfig] class SystemConfig(BaseBackendConfig, total=False): """System configuration.""" - reporting: Dict[str, Dict] + reporting: dict[str, dict] BackendItemConfig = Union[ApplicationConfig, SystemConfig] diff --git a/src/mlia/backend/execution.py b/src/mlia/backend/execution.py index 5340a47..f3fe401 100644 --- a/src/mlia/backend/execution.py +++ b/src/mlia/backend/execution.py @@ -1,12 +1,11 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Application execution module.""" +from __future__ import annotations + import logging import re from typing import cast -from typing import List -from typing import Optional -from typing import Tuple from mlia.backend.application import Application from mlia.backend.application import get_application @@ -29,9 +28,9 @@ class ExecutionContext: # pylint: disable=too-few-public-methods def __init__( self, app: Application, - app_params: List[str], + app_params: list[str], system: System, - system_params: List[str], + system_params: list[str], ): """Init execution context.""" self.app = app @@ -41,8 +40,8 @@ class ExecutionContext: # pylint: disable=too-few-public-methods self.param_resolver = ParamResolver(self) - self.stdout: Optional[bytearray] = None - self.stderr: Optional[bytearray] = None + self.stdout: bytearray | None = None + self.stderr: bytearray | None = None class ParamResolver: @@ -54,16 +53,16 @@ class ParamResolver: @staticmethod def resolve_user_params( - cmd_name: Optional[str], + cmd_name: str | None, index_or_alias: str, - resolved_params: Optional[List[Tuple[Optional[str], Param]]], + resolved_params: list[tuple[str | None, Param]] | None, ) -> str: """Resolve user params.""" if not cmd_name or resolved_params is None: raise ConfigurationException("Unable to resolve user params") - param_value: Optional[str] = None - param: Optional[Param] = None + param_value: str | None = None + param: Param | None = None if index_or_alias.isnumeric(): i = int(index_or_alias) @@ -176,8 +175,8 @@ class ParamResolver: def param_matcher( self, param_name: str, - cmd_name: Optional[str], - resolved_params: Optional[List[Tuple[Optional[str], Param]]], + cmd_name: str | None, + resolved_params: list[tuple[str | None, Param]] | None, ) -> str: """Regexp to resolve a param from the param_name.""" # this pattern supports parameter names like "application.commands.run:0" and @@ -224,8 +223,8 @@ class ParamResolver: def param_resolver( self, param_name: str, - cmd_name: Optional[str] = None, - resolved_params: Optional[List[Tuple[Optional[str], Param]]] = None, + cmd_name: str | None = None, + resolved_params: list[tuple[str | None, Param]] | None = None, ) -> str: """Resolve parameter value based on current execution context.""" # Note: 'software.*' is included for backward compatibility. @@ -253,15 +252,15 @@ class ParamResolver: def __call__( self, param_name: str, - cmd_name: Optional[str] = None, - resolved_params: Optional[List[Tuple[Optional[str], Param]]] = None, + cmd_name: str | None = None, + resolved_params: list[tuple[str | None, Param]] | None = None, ) -> str: """Resolve provided parameter.""" return self.param_resolver(param_name, cmd_name, resolved_params) def validate_parameters( - backend: Backend, command_names: List[str], params: List[str] + backend: Backend, command_names: list[str], params: list[str] ) -> None: """Check parameters passed to backend.""" for param in params: @@ -301,7 +300,7 @@ def get_application_by_name_and_system( def get_application_and_system( application_name: str, system_name: str -) -> Tuple[Application, System]: +) -> tuple[Application, System]: """Return application and system by provided names.""" system = get_system(system_name) if not system: @@ -314,9 +313,9 @@ def get_application_and_system( def run_application( application_name: str, - application_params: List[str], + application_params: list[str], system_name: str, - system_params: List[str], + system_params: list[str], ) -> ExecutionContext: """Run application on the provided system.""" application, system = get_application_and_system(application_name, system_name) diff --git a/src/mlia/backend/fs.py b/src/mlia/backend/fs.py index 9fb53b1..3fce19c 100644 --- a/src/mlia/backend/fs.py +++ b/src/mlia/backend/fs.py @@ -1,11 +1,12 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module to host all file system related functions.""" +from __future__ import annotations + import re import shutil from pathlib import Path from typing import Literal -from typing import Optional from mlia.utils.filesystem import get_mlia_resources @@ -58,7 +59,7 @@ def remove_resource(resource_directory: str, resource_type: ResourceType) -> Non shutil.rmtree(resource_location) -def remove_directory(directory_path: Optional[Path]) -> None: +def remove_directory(directory_path: Path | None) -> None: """Remove directory.""" if not directory_path or not directory_path.is_dir(): raise Exception("No directory path provided") @@ -66,7 +67,7 @@ def remove_directory(directory_path: Optional[Path]) -> None: shutil.rmtree(directory_path) -def recreate_directory(directory_path: Optional[Path]) -> None: +def recreate_directory(directory_path: Path | None) -> None: """Recreate directory.""" if not directory_path: raise Exception("No directory path provided") diff --git a/src/mlia/backend/manager.py b/src/mlia/backend/manager.py index 8d8246d..c8fe0f7 100644 --- a/src/mlia/backend/manager.py +++ b/src/mlia/backend/manager.py @@ -1,17 +1,14 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for backend integration.""" +from __future__ import annotations + import logging from abc import ABC from abc import abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import Dict -from typing import List from typing import Literal -from typing import Optional -from typing import Set -from typing import Tuple from mlia.backend.application import get_available_applications from mlia.backend.application import install_application @@ -58,7 +55,7 @@ def get_system_name(backend: str, device_type: str) -> str: return _SUPPORTED_SYSTEMS[backend][device_type] -def is_supported(backend: str, device_type: Optional[str] = None) -> bool: +def is_supported(backend: str, device_type: str | None = None) -> bool: """Check if the backend (and optionally device type) is supported.""" if device_type is None: return backend in _SUPPORTED_SYSTEMS @@ -70,17 +67,17 @@ def is_supported(backend: str, device_type: Optional[str] = None) -> bool: return False -def supported_backends() -> List[str]: +def supported_backends() -> list[str]: """Get a list of all backends supported by the backend manager.""" return list(_SUPPORTED_SYSTEMS.keys()) -def get_all_system_names(backend: str) -> List[str]: +def get_all_system_names(backend: str) -> list[str]: """Get all systems supported by the backend.""" return list(_SUPPORTED_SYSTEMS.get(backend, {}).values()) -def get_all_application_names(backend: str) -> List[str]: +def get_all_application_names(backend: str) -> list[str]: """Get all applications supported by the backend.""" app_set = { app @@ -124,8 +121,8 @@ class ExecutionParams: application: str system: str - application_params: List[str] - system_params: List[str] + application_params: list[str] + system_params: list[str] class LogWriter(OutputConsumer): @@ -153,7 +150,7 @@ class GenericInferenceOutputParser(Base64OutputConsumer): } @property - def result(self) -> Dict: + def result(self) -> dict: """Merge the raw results and map the names to the right output names.""" merged_result = {} for raw_result in self.parsed_output: @@ -172,7 +169,7 @@ class GenericInferenceOutputParser(Base64OutputConsumer): """Return true if all expected data has been parsed.""" return set(self.result.keys()) == set(self._map.values()) - def missed_keys(self) -> Set[str]: + def missed_keys(self) -> set[str]: """Return a set of the keys that have not been found in the output.""" return set(self._map.values()) - set(self.result.keys()) @@ -184,12 +181,12 @@ class BackendRunner: """Init BackendRunner instance.""" @staticmethod - def get_installed_systems() -> List[str]: + def get_installed_systems() -> list[str]: """Get list of the installed systems.""" return [system.name for system in get_available_systems()] @staticmethod - def get_installed_applications(system: Optional[str] = None) -> List[str]: + def get_installed_applications(system: str | None = None) -> list[str]: """Get list of the installed application.""" return [ app.name @@ -205,7 +202,7 @@ class BackendRunner: """Return true if requested system installed.""" return system in self.get_installed_systems() - def systems_installed(self, systems: List[str]) -> bool: + def systems_installed(self, systems: list[str]) -> bool: """Check if all provided systems are installed.""" if not systems: return False @@ -213,7 +210,7 @@ class BackendRunner: installed_systems = self.get_installed_systems() return all(system in installed_systems for system in systems) - def applications_installed(self, applications: List[str]) -> bool: + def applications_installed(self, applications: list[str]) -> bool: """Check if all provided applications are installed.""" if not applications: return False @@ -221,7 +218,7 @@ class BackendRunner: installed_apps = self.get_installed_applications() return all(app in installed_apps for app in applications) - def all_installed(self, systems: List[str], apps: List[str]) -> bool: + def all_installed(self, systems: list[str], apps: list[str]) -> bool: """Check if all provided artifacts are installed.""" return self.systems_installed(systems) and self.applications_installed(apps) @@ -247,7 +244,7 @@ class BackendRunner: return ctx @staticmethod - def _params(name: str, params: List[str]) -> List[str]: + def _params(name: str, params: list[str]) -> list[str]: return [p for item in [(name, param) for param in params] for p in item] @@ -259,7 +256,7 @@ class GenericInferenceRunner(ABC): self.backend_runner = backend_runner def run( - self, model_info: ModelInfo, output_consumers: List[OutputConsumer] + self, model_info: ModelInfo, output_consumers: list[OutputConsumer] ) -> None: """Run generic inference for the provided device/model.""" execution_params = self.get_execution_params(model_info) @@ -284,7 +281,7 @@ class GenericInferenceRunner(ABC): ) @staticmethod - def consume_output(output: bytearray, consumers: List[OutputConsumer]) -> bytearray: + def consume_output(output: bytearray, consumers: list[OutputConsumer]) -> bytearray: """ Pass program's output to the consumers and filter it. @@ -320,7 +317,7 @@ class GenericInferenceRunnerEthosU(GenericInferenceRunner): @staticmethod def resolve_system_and_app( device_info: DeviceInfo, backend: str - ) -> Tuple[str, str]: + ) -> tuple[str, str]: """Find appropriate system and application for the provided device/backend.""" try: system_name = get_system_name(backend, device_info.device_type) diff --git a/src/mlia/backend/output_consumer.py b/src/mlia/backend/output_consumer.py index bac4186..3c3b132 100644 --- a/src/mlia/backend/output_consumer.py +++ b/src/mlia/backend/output_consumer.py @@ -1,10 +1,11 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Output consumers module.""" +from __future__ import annotations + import base64 import json import re -from typing import List from typing import Protocol from typing import runtime_checkable @@ -37,7 +38,7 @@ class Base64OutputConsumer(OutputConsumer): def __init__(self) -> None: """Set up the regular expression to extract tagged strings.""" self._regex = re.compile(rf"<{self.TAG_NAME}>(.*)") - self.parsed_output: List = [] + self.parsed_output: list = [] def feed(self, line: str) -> bool: """ diff --git a/src/mlia/backend/proc.py b/src/mlia/backend/proc.py index 911d672..7b3e92a 100644 --- a/src/mlia/backend/proc.py +++ b/src/mlia/backend/proc.py @@ -5,6 +5,8 @@ This module contains all classes and functions for dealing with Linux processes. """ +from __future__ import annotations + import datetime import logging import shlex @@ -13,9 +15,6 @@ import tempfile import time from pathlib import Path from typing import Any -from typing import List -from typing import Optional -from typing import Tuple from sh import Command from sh import CommandNotFound @@ -38,12 +37,12 @@ class ShellCommand: self, cmd: str, *args: str, - _cwd: Optional[Path] = None, + _cwd: Path | None = None, _tee: bool = True, _bg: bool = True, _out: Any = None, _err: Any = None, - _search_paths: Optional[List[Path]] = None, + _search_paths: list[Path] | None = None, ) -> RunningCommand: """Run the shell command with the given arguments. @@ -72,7 +71,7 @@ class ShellCommand: return command(_out=out, _err=err, _tee=_tee, _bg=_bg, _bg_exc=False) @classmethod - def get_stdout_stderr_paths(cls, cmd: str) -> Tuple[Path, Path]: + def get_stdout_stderr_paths(cls, cmd: str) -> tuple[Path, Path]: """Construct and returns the paths of stdout/stderr files.""" timestamp = datetime.datetime.now().timestamp() base_path = Path(tempfile.mkdtemp(prefix="mlia-", suffix=f"{timestamp}")) @@ -88,7 +87,7 @@ class ShellCommand: return stdout, stderr -def parse_command(command: str, shell: str = "bash") -> List[str]: +def parse_command(command: str, shell: str = "bash") -> list[str]: """Parse command.""" cmd, *args = shlex.split(command, posix=True) @@ -130,13 +129,13 @@ def run_and_wait( terminate_on_error: bool = False, out: Any = None, err: Any = None, -) -> Tuple[int, bytearray, bytearray]: +) -> tuple[int, bytearray, bytearray]: """ Run command and wait while it is executing. Returns a tuple: (exit_code, stdout, stderr) """ - running_cmd: Optional[RunningCommand] = None + running_cmd: RunningCommand | None = None try: running_cmd = execute_command(command, cwd, bg=True, out=out, err=err) return running_cmd.exit_code, running_cmd.stdout, running_cmd.stderr diff --git a/src/mlia/backend/source.py b/src/mlia/backend/source.py index f80a774..c951eae 100644 --- a/src/mlia/backend/source.py +++ b/src/mlia/backend/source.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Contain source related classes and functions.""" +from __future__ import annotations + import os import shutil import tarfile @@ -8,8 +10,6 @@ from abc import ABC from abc import abstractmethod from pathlib import Path from tarfile import TarFile -from typing import Optional -from typing import Union from mlia.backend.common import BACKEND_CONFIG_FILE from mlia.backend.common import ConfigurationException @@ -24,11 +24,11 @@ class Source(ABC): """Source class.""" @abstractmethod - def name(self) -> Optional[str]: + def name(self) -> str | None: """Get source name.""" @abstractmethod - def config(self) -> Optional[BackendConfig]: + def config(self) -> BackendConfig | None: """Get configuration file content.""" @abstractmethod @@ -52,7 +52,7 @@ class DirectorySource(Source): """Return name of source.""" return self.directory_path.name - def config(self) -> Optional[BackendConfig]: + def config(self) -> BackendConfig | None: """Return configuration file content.""" if not is_backend_directory(self.directory_path): raise ConfigurationException("No configuration file found") @@ -84,9 +84,9 @@ class TarArchiveSource(Source): """Create the TarArchiveSource class.""" assert isinstance(archive_path, Path) self.archive_path = archive_path - self._config: Optional[BackendConfig] = None - self._has_top_level_folder: Optional[bool] = None - self._name: Optional[str] = None + self._config: BackendConfig | None = None + self._has_top_level_folder: bool | None = None + self._name: str | None = None def _read_archive_content(self) -> None: """Read various information about archive.""" @@ -125,14 +125,14 @@ class TarArchiveSource(Source): content = archive.extractfile(config_entry) self._config = load_config(content) - def config(self) -> Optional[BackendConfig]: + def config(self) -> BackendConfig | None: """Return configuration file content.""" if self._config is None: self._read_archive_content() return self._config - def name(self) -> Optional[str]: + def name(self) -> str | None: """Return name of the source.""" if self._name is None: self._read_archive_content() @@ -171,7 +171,7 @@ class TarArchiveSource(Source): ) -def get_source(source_path: Path) -> Union[TarArchiveSource, DirectorySource]: +def get_source(source_path: Path) -> TarArchiveSource | DirectorySource: """Return appropriate source instance based on provided source path.""" if source_path.is_file(): return TarArchiveSource(source_path) diff --git a/src/mlia/backend/system.py b/src/mlia/backend/system.py index ff85bf3..0e51ab2 100644 --- a/src/mlia/backend/system.py +++ b/src/mlia/backend/system.py @@ -1,12 +1,12 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """System backend module.""" +from __future__ import annotations + from pathlib import Path from typing import Any from typing import cast -from typing import Dict from typing import List -from typing import Tuple from mlia.backend.common import Backend from mlia.backend.common import ConfigurationException @@ -33,7 +33,7 @@ class System(Backend): def _setup_reporting(self, config: SystemConfig) -> None: self.reporting = config.get("reporting") - def run(self, command: str) -> Tuple[int, bytearray, bytearray]: + def run(self, command: str) -> tuple[int, bytearray, bytearray]: """ Run command on the system. @@ -63,7 +63,7 @@ class System(Backend): return super().__eq__(other) and self.name == other.name - def get_details(self) -> Dict[str, Any]: + def get_details(self) -> dict[str, Any]: """Return a dictionary with all relevant information of a System.""" output = { "type": "system", @@ -76,12 +76,12 @@ class System(Backend): return output -def get_available_systems_directory_names() -> List[str]: +def get_available_systems_directory_names() -> list[str]: """Return a list of directory names for all avialable systems.""" return [entry.name for entry in get_backend_directories("systems")] -def get_available_systems() -> List[System]: +def get_available_systems() -> list[System]: """Return a list with all available systems.""" available_systems = [] for config_json in get_backend_configs("systems"): diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py index 45c7c32..5dd39f9 100644 --- a/src/mlia/cli/commands.py +++ b/src/mlia/cli/commands.py @@ -16,11 +16,11 @@ be configured. Function 'setup_logging' from module >>> mlia.all_tests(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256", "path/to/model") """ +from __future__ import annotations + import logging from pathlib import Path from typing import cast -from typing import List -from typing import Optional from mlia.api import ExecutionContext from mlia.api import get_advice @@ -42,8 +42,8 @@ def all_tests( model: str, optimization_type: str = "pruning,clustering", optimization_target: str = "0.5,32", - output: Optional[PathOrFileLike] = None, - evaluate_on: Optional[List[str]] = None, + output: PathOrFileLike | None = None, + evaluate_on: list[str] | None = None, ) -> None: """Generate a full report on the input model. @@ -99,8 +99,8 @@ def all_tests( def operators( ctx: ExecutionContext, target_profile: str, - model: Optional[str] = None, - output: Optional[PathOrFileLike] = None, + model: str | None = None, + output: PathOrFileLike | None = None, supported_ops_report: bool = False, ) -> None: """Print the model's operator list. @@ -149,8 +149,8 @@ def performance( ctx: ExecutionContext, target_profile: str, model: str, - output: Optional[PathOrFileLike] = None, - evaluate_on: Optional[List[str]] = None, + output: PathOrFileLike | None = None, + evaluate_on: list[str] | None = None, ) -> None: """Print the model's performance stats. @@ -192,9 +192,9 @@ def optimization( model: str, optimization_type: str, optimization_target: str, - layers_to_optimize: Optional[List[str]] = None, - output: Optional[PathOrFileLike] = None, - evaluate_on: Optional[List[str]] = None, + layers_to_optimize: list[str] | None = None, + output: PathOrFileLike | None = None, + evaluate_on: list[str] | None = None, ) -> None: """Show the performance improvements (if any) after applying the optimizations. @@ -245,9 +245,9 @@ def optimization( def backend( backend_action: str, - path: Optional[Path] = None, + path: Path | None = None, download: bool = False, - name: Optional[str] = None, + name: str | None = None, i_agree_to_the_contained_eula: bool = False, noninteractive: bool = False, ) -> None: diff --git a/src/mlia/cli/common.py b/src/mlia/cli/common.py index 54bd457..3f60668 100644 --- a/src/mlia/cli/common.py +++ b/src/mlia/cli/common.py @@ -1,10 +1,11 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """CLI common module.""" +from __future__ import annotations + import argparse from dataclasses import dataclass from typing import Callable -from typing import List @dataclass @@ -12,8 +13,8 @@ class CommandInfo: """Command description.""" func: Callable - aliases: List[str] - opt_groups: List[Callable[[argparse.ArgumentParser], None]] + aliases: list[str] + opt_groups: list[Callable[[argparse.ArgumentParser], None]] is_default: bool = False @property @@ -22,7 +23,7 @@ class CommandInfo: return self.func.__name__ @property - def command_name_and_aliases(self) -> List[str]: + def command_name_and_aliases(self) -> list[str]: """Return list of command name and aliases.""" return [self.command_name, *self.aliases] diff --git a/src/mlia/cli/config.py b/src/mlia/cli/config.py index a673230..dc28fa2 100644 --- a/src/mlia/cli/config.py +++ b/src/mlia/cli/config.py @@ -1,9 +1,10 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Environment configuration functions.""" +from __future__ import annotations + import logging from functools import lru_cache -from typing import List import mlia.backend.manager as backend_manager from mlia.tools.metadata.common import DefaultInstallationManager @@ -21,7 +22,7 @@ def get_installation_manager(noninteractive: bool = False) -> InstallationManage @lru_cache -def get_available_backends() -> List[str]: +def get_available_backends() -> list[str]: """Return list of the available backends.""" available_backends = ["Vela"] @@ -42,7 +43,7 @@ def get_available_backends() -> List[str]: _CORSTONE_EXCLUSIVE_PRIORITY = ("Corstone-310", "Corstone-300") -def get_default_backends() -> List[str]: +def get_default_backends() -> list[str]: """Get default backends for evaluation.""" backends = get_available_backends() diff --git a/src/mlia/cli/helpers.py b/src/mlia/cli/helpers.py index 81d5a15..acec837 100644 --- a/src/mlia/cli/helpers.py +++ b/src/mlia/cli/helpers.py @@ -1,11 +1,9 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for various helper classes.""" +from __future__ import annotations + from typing import Any -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple from mlia.cli.options import get_target_profile_opts from mlia.core.helpers import ActionResolver @@ -17,12 +15,12 @@ from mlia.utils.types import is_list_of class CLIActionResolver(ActionResolver): """Helper class for generating cli commands.""" - def __init__(self, args: Dict[str, Any]) -> None: + def __init__(self, args: dict[str, Any]) -> None: """Init action resolver.""" self.args = args @staticmethod - def _general_optimization_command(model_path: Optional[str]) -> List[str]: + def _general_optimization_command(model_path: str | None) -> list[str]: """Return general optimization command description.""" keras_note = [] if model_path is None or not is_keras_model(model_path): @@ -40,8 +38,8 @@ class CLIActionResolver(ActionResolver): def _specific_optimization_command( model_path: str, device_opts: str, - opt_settings: List[OptimizationSettings], - ) -> List[str]: + opt_settings: list[OptimizationSettings], + ) -> list[str]: """Return specific optimization command description.""" opt_types = ",".join(opt.optimization_type for opt in opt_settings) opt_targs = ",".join(str(opt.optimization_target) for opt in opt_settings) @@ -53,7 +51,7 @@ class CLIActionResolver(ActionResolver): f"--optimization-target {opt_targs}{device_opts} {model_path}", ] - def apply_optimizations(self, **kwargs: Any) -> List[str]: + def apply_optimizations(self, **kwargs: Any) -> list[str]: """Return command details for applying optimizations.""" model_path, device_opts = self._get_model_and_device_opts() @@ -67,14 +65,14 @@ class CLIActionResolver(ActionResolver): return [] - def supported_operators_info(self) -> List[str]: + def supported_operators_info(self) -> list[str]: """Return command details for generating supported ops report.""" return [ "For guidance on supported operators, run: mlia operators " "--supported-ops-report", ] - def check_performance(self) -> List[str]: + def check_performance(self) -> list[str]: """Return command details for checking performance.""" model_path, device_opts = self._get_model_and_device_opts() if not model_path: @@ -85,7 +83,7 @@ class CLIActionResolver(ActionResolver): f"mlia performance{device_opts} {model_path}", ] - def check_operator_compatibility(self) -> List[str]: + def check_operator_compatibility(self) -> list[str]: """Return command details for op compatibility.""" model_path, device_opts = self._get_model_and_device_opts() if not model_path: @@ -96,17 +94,17 @@ class CLIActionResolver(ActionResolver): f"mlia operators{device_opts} {model_path}", ] - def operator_compatibility_details(self) -> List[str]: + def operator_compatibility_details(self) -> list[str]: """Return command details for op compatibility.""" return ["For more details, run: mlia operators --help"] - def optimization_details(self) -> List[str]: + def optimization_details(self) -> list[str]: """Return command details for optimization.""" return ["For more info, see: mlia optimization --help"] def _get_model_and_device_opts( self, separate_device_opts: bool = True - ) -> Tuple[Optional[str], str]: + ) -> tuple[str | None, str]: """Get model and device options.""" device_opts = " ".join(get_target_profile_opts(self.args)) if separate_device_opts and device_opts: diff --git a/src/mlia/cli/logging.py b/src/mlia/cli/logging.py index c5fc7bd..40f47d3 100644 --- a/src/mlia/cli/logging.py +++ b/src/mlia/cli/logging.py @@ -1,12 +1,11 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """CLI logging configuration.""" +from __future__ import annotations + import logging import sys from pathlib import Path -from typing import List -from typing import Optional -from typing import Union from mlia.utils.logging import attach_handlers from mlia.utils.logging import create_log_handler @@ -18,7 +17,7 @@ _FILE_DEBUG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" def setup_logging( - logs_dir: Optional[Union[str, Path]] = None, + logs_dir: str | Path | None = None, verbose: bool = False, log_filename: str = "mlia.log", ) -> None: @@ -49,10 +48,10 @@ def setup_logging( def _get_mlia_handlers( - logs_dir: Optional[Union[str, Path]], + logs_dir: str | Path | None, log_filename: str, verbose: bool, -) -> List[logging.Handler]: +) -> list[logging.Handler]: """Get handlers for the MLIA loggers.""" result = [] stdout_handler = create_log_handler( @@ -84,10 +83,10 @@ def _get_mlia_handlers( def _get_tools_handlers( - logs_dir: Optional[Union[str, Path]], + logs_dir: str | Path | None, log_filename: str, verbose: bool, -) -> List[logging.Handler]: +) -> list[logging.Handler]: """Get handler for the tools loggers.""" result = [] if verbose: @@ -110,7 +109,7 @@ def _get_tools_handlers( return result -def _get_log_file(logs_dir: Union[str, Path], log_filename: str) -> Path: +def _get_log_file(logs_dir: str | Path, log_filename: str) -> Path: """Get the log file path.""" logs_dir_path = Path(logs_dir) logs_dir_path.mkdir(exist_ok=True) diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py index f8fc00c..0ece289 100644 --- a/src/mlia/cli/main.py +++ b/src/mlia/cli/main.py @@ -1,16 +1,14 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """CLI main entry point.""" +from __future__ import annotations + import argparse import logging import sys from functools import partial from inspect import signature from pathlib import Path -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple from mlia import __version__ from mlia.cli.commands import all_tests @@ -50,7 +48,7 @@ Supported targets: """.strip() -def get_commands() -> List[CommandInfo]: +def get_commands() -> list[CommandInfo]: """Return commands configuration.""" return [ CommandInfo( @@ -111,7 +109,7 @@ def get_commands() -> List[CommandInfo]: ] -def get_default_command() -> Optional[str]: +def get_default_command() -> str | None: """Get name of the default command.""" commands = get_commands() @@ -121,7 +119,7 @@ def get_default_command() -> Optional[str]: return next(iter(marked_as_default), None) -def get_possible_command_names() -> List[str]: +def get_possible_command_names() -> list[str]: """Get all possible command names including aliases.""" return [ name_or_alias @@ -151,7 +149,7 @@ def init_commands(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: def setup_context( args: argparse.Namespace, context_var_name: str = "ctx" -) -> Tuple[ExecutionContext, Dict]: +) -> tuple[ExecutionContext, dict]: """Set up context and resolve function parameters.""" ctx = ExecutionContext( working_dir=args.working_dir, @@ -252,7 +250,7 @@ def init_subcommand_parser(parent: argparse.ArgumentParser) -> argparse.Argument return parser -def add_default_command_if_needed(args: List[str]) -> None: +def add_default_command_if_needed(args: list[str]) -> None: """Add default command to the list of the arguments if needed.""" default_command = get_default_command() @@ -265,7 +263,7 @@ def add_default_command_if_needed(args: List[str]) -> None: args.insert(0, default_command) -def main(argv: Optional[List[str]] = None) -> int: +def main(argv: list[str] | None = None) -> int: """Entry point of the application.""" common_parser = init_common_parser() subcommand_parser = init_subcommand_parser(common_parser) diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py index 29a0d89..3f0dc1f 100644 --- a/src/mlia/cli/options.py +++ b/src/mlia/cli/options.py @@ -1,13 +1,12 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for the CLI options.""" +from __future__ import annotations + import argparse from pathlib import Path from typing import Any from typing import Callable -from typing import Dict -from typing import List -from typing import Optional from mlia.cli.config import get_available_backends from mlia.cli.config import get_default_backends @@ -17,7 +16,7 @@ from mlia.utils.types import is_number def add_target_options( - parser: argparse.ArgumentParser, profiles_to_skip: Optional[List[str]] = None + parser: argparse.ArgumentParser, profiles_to_skip: list[str] | None = None ) -> None: """Add target specific options.""" target_profiles = get_supported_profile_names() @@ -217,8 +216,8 @@ def parse_optimization_parameters( optimization_type: str, optimization_target: str, sep: str = ",", - layers_to_optimize: Optional[List[str]] = None, -) -> List[Dict[str, Any]]: + layers_to_optimize: list[str] | None = None, +) -> list[dict[str, Any]]: """Parse provided optimization parameters.""" if not optimization_type: raise Exception("Optimization type is not provided") @@ -250,7 +249,7 @@ def parse_optimization_parameters( return optimizer_params -def get_target_profile_opts(device_args: Optional[Dict]) -> List[str]: +def get_target_profile_opts(device_args: dict | None) -> list[str]: """Get non default values passed as parameters for the target profile.""" if not device_args: return [] @@ -270,7 +269,7 @@ def get_target_profile_opts(device_args: Optional[Dict]) -> List[str]: if arg_name in args and vars(args)[arg_name] != arg_value ] - def construct_param(name: str, value: Any) -> List[str]: + def construct_param(name: str, value: Any) -> list[str]: """Construct parameter.""" if isinstance(value, list): return [str(item) for v in value for item in [name, v]] diff --git a/src/mlia/core/_typing.py b/src/mlia/core/_typing.py deleted file mode 100644 index bda995c..0000000 --- a/src/mlia/core/_typing.py +++ /dev/null @@ -1,12 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Module for custom type hints.""" -from pathlib import Path -from typing import Literal -from typing import TextIO -from typing import Union - - -FileLike = TextIO -PathOrFileLike = Union[str, Path, FileLike] -OutputFormat = Literal["plain_text", "csv", "json"] diff --git a/src/mlia/core/advice_generation.py b/src/mlia/core/advice_generation.py index 76cc1f2..86285fe 100644 --- a/src/mlia/core/advice_generation.py +++ b/src/mlia/core/advice_generation.py @@ -1,14 +1,14 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for advice generation.""" +from __future__ import annotations + from abc import ABC from abc import abstractmethod from dataclasses import dataclass from functools import wraps from typing import Any from typing import Callable -from typing import List -from typing import Union from mlia.core.common import AdviceCategory from mlia.core.common import DataItem @@ -20,7 +20,7 @@ from mlia.core.mixins import ContextMixin class Advice: """Base class for the advice.""" - messages: List[str] + messages: list[str] @dataclass @@ -56,7 +56,7 @@ class AdviceProducer(ABC): """ @abstractmethod - def get_advice(self) -> Union[Advice, List[Advice]]: + def get_advice(self) -> Advice | list[Advice]: """Get produced advice.""" @@ -76,13 +76,13 @@ class FactBasedAdviceProducer(ContextAwareAdviceProducer): def __init__(self) -> None: """Init advice producer.""" - self.advice: List[Advice] = [] + self.advice: list[Advice] = [] - def get_advice(self) -> Union[Advice, List[Advice]]: + def get_advice(self) -> Advice | list[Advice]: """Get produced advice.""" return self.advice - def add_advice(self, messages: List[str]) -> None: + def add_advice(self, messages: list[str]) -> None: """Add advice.""" self.advice.append(Advice(messages)) diff --git a/src/mlia/core/advisor.py b/src/mlia/core/advisor.py index 13689fa..d684241 100644 --- a/src/mlia/core/advisor.py +++ b/src/mlia/core/advisor.py @@ -1,10 +1,11 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Inference advisor module.""" +from __future__ import annotations + from abc import abstractmethod from pathlib import Path from typing import cast -from typing import List from mlia.core.advice_generation import AdviceProducer from mlia.core.common import NamedEntity @@ -44,19 +45,19 @@ class DefaultInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin): ) @abstractmethod - def get_collectors(self, context: Context) -> List[DataCollector]: + def get_collectors(self, context: Context) -> list[DataCollector]: """Return list of the data collectors.""" @abstractmethod - def get_analyzers(self, context: Context) -> List[DataAnalyzer]: + def get_analyzers(self, context: Context) -> list[DataAnalyzer]: """Return list of the data analyzers.""" @abstractmethod - def get_producers(self, context: Context) -> List[AdviceProducer]: + def get_producers(self, context: Context) -> list[AdviceProducer]: """Return list of the advice producers.""" @abstractmethod - def get_events(self, context: Context) -> List[Event]: + def get_events(self, context: Context) -> list[Event]: """Return list of the startup events.""" def get_string_parameter(self, context: Context, param: str) -> str: diff --git a/src/mlia/core/common.py b/src/mlia/core/common.py index a11bf9a..63fb324 100644 --- a/src/mlia/core/common.py +++ b/src/mlia/core/common.py @@ -5,6 +5,8 @@ This module contains common interfaces/classess shared across core module. """ +from __future__ import annotations + from abc import ABC from abc import abstractmethod from enum import auto @@ -30,7 +32,7 @@ class AdviceCategory(Flag): ALL = OPERATORS | PERFORMANCE | OPTIMIZATION @classmethod - def from_string(cls, value: str) -> "AdviceCategory": + def from_string(cls, value: str) -> AdviceCategory: """Resolve enum value from string value.""" category_names = [item.name for item in AdviceCategory] if not value or value.upper() not in category_names: diff --git a/src/mlia/core/context.py b/src/mlia/core/context.py index 83d2f7c..a4737bb 100644 --- a/src/mlia/core/context.py +++ b/src/mlia/core/context.py @@ -7,15 +7,14 @@ Context is an object that describes advisor working environment and requested behavior (advice categories, input configuration parameters). """ +from __future__ import annotations + import logging from abc import ABC from abc import abstractmethod from pathlib import Path from typing import Any -from typing import List from typing import Mapping -from typing import Optional -from typing import Union from mlia.core.common import AdviceCategory from mlia.core.events import DefaultEventPublisher @@ -50,7 +49,7 @@ class Context(ABC): @property @abstractmethod - def event_handlers(self) -> Optional[List[EventHandler]]: + def event_handlers(self) -> list[EventHandler] | None: """Return list of the event_handlers.""" @property @@ -60,7 +59,7 @@ class Context(ABC): @property @abstractmethod - def config_parameters(self) -> Optional[Mapping[str, Any]]: + def config_parameters(self) -> Mapping[str, Any] | None: """Return configuration parameters.""" @property @@ -73,7 +72,7 @@ class Context(ABC): self, *, advice_category: AdviceCategory, - event_handlers: List[EventHandler], + event_handlers: list[EventHandler], config_parameters: Mapping[str, Any], ) -> None: """Update context parameters.""" @@ -98,14 +97,14 @@ class ExecutionContext(Context): self, *, advice_category: AdviceCategory = AdviceCategory.ALL, - config_parameters: Optional[Mapping[str, Any]] = None, - working_dir: Optional[Union[str, Path]] = None, - event_handlers: Optional[List[EventHandler]] = None, - event_publisher: Optional[EventPublisher] = None, + config_parameters: Mapping[str, Any] | None = None, + working_dir: str | Path | None = None, + event_handlers: list[EventHandler] | None = None, + event_publisher: EventPublisher | None = None, verbose: bool = False, logs_dir: str = "logs", models_dir: str = "models", - action_resolver: Optional[ActionResolver] = None, + action_resolver: ActionResolver | None = None, ) -> None: """Init execution context. @@ -151,22 +150,22 @@ class ExecutionContext(Context): self._advice_category = advice_category @property - def config_parameters(self) -> Optional[Mapping[str, Any]]: + def config_parameters(self) -> Mapping[str, Any] | None: """Return configuration parameters.""" return self._config_parameters @config_parameters.setter - def config_parameters(self, config_parameters: Optional[Mapping[str, Any]]) -> None: + def config_parameters(self, config_parameters: Mapping[str, Any] | None) -> None: """Setter for the configuration parameters.""" self._config_parameters = config_parameters @property - def event_handlers(self) -> Optional[List[EventHandler]]: + def event_handlers(self) -> list[EventHandler] | None: """Return list of the event handlers.""" return self._event_handlers @event_handlers.setter - def event_handlers(self, event_handlers: List[EventHandler]) -> None: + def event_handlers(self, event_handlers: list[EventHandler]) -> None: """Setter for the event handlers.""" self._event_handlers = event_handlers @@ -196,7 +195,7 @@ class ExecutionContext(Context): self, *, advice_category: AdviceCategory, - event_handlers: List[EventHandler], + event_handlers: list[EventHandler], config_parameters: Mapping[str, Any], ) -> None: """Update context parameters.""" diff --git a/src/mlia/core/data_analysis.py b/src/mlia/core/data_analysis.py index 6adb41e..0603425 100644 --- a/src/mlia/core/data_analysis.py +++ b/src/mlia/core/data_analysis.py @@ -1,10 +1,11 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for data analysis.""" +from __future__ import annotations + from abc import ABC from abc import abstractmethod from dataclasses import dataclass -from typing import List from mlia.core.common import DataItem from mlia.core.mixins import ContextMixin @@ -29,7 +30,7 @@ class DataAnalyzer(ABC): """ @abstractmethod - def get_analyzed_data(self) -> List[DataItem]: + def get_analyzed_data(self) -> list[DataItem]: """Get analyzed data.""" @@ -59,9 +60,9 @@ class FactExtractor(ContextAwareDataAnalyzer): def __init__(self) -> None: """Init fact extractor.""" - self.facts: List[Fact] = [] + self.facts: list[Fact] = [] - def get_analyzed_data(self) -> List[DataItem]: + def get_analyzed_data(self) -> list[DataItem]: """Return list of the collected facts.""" return self.facts diff --git a/src/mlia/core/events.py b/src/mlia/core/events.py index 0b8461b..71c86e2 100644 --- a/src/mlia/core/events.py +++ b/src/mlia/core/events.py @@ -9,6 +9,8 @@ calling application. Each component of the workflow can generate events of specific type. Application can subscribe and react to those events. """ +from __future__ import annotations + import traceback import uuid from abc import ABC @@ -19,11 +21,7 @@ from dataclasses import dataclass from dataclasses import field from functools import singledispatchmethod from typing import Any -from typing import Dict from typing import Generator -from typing import List -from typing import Optional -from typing import Tuple from mlia.core.common import DataItem @@ -41,7 +39,7 @@ class Event: """Generate unique ID for the event.""" self.event_id = str(uuid.uuid4()) - def compare_without_id(self, other: "Event") -> bool: + def compare_without_id(self, other: Event) -> bool: """Compare two events without event_id field.""" if not isinstance(other, Event) or self.__class__ != other.__class__: return False @@ -73,7 +71,7 @@ class ActionStartedEvent(Event): """ action_type: str - params: Optional[Dict] = None + params: dict | None = None @dataclass @@ -84,7 +82,7 @@ class SubActionEvent(ChildEvent): """ action_type: str - params: Optional[Dict] = None + params: dict | None = None @dataclass @@ -271,8 +269,8 @@ class EventDispatcherMetaclass(type): def __new__( cls, clsname: str, - bases: Tuple, - namespace: Dict[str, Any], + bases: tuple[type, ...], + namespace: dict[str, Any], event_handler_method_prefix: str = "on_", ) -> Any: """Create event dispatcher and link event handlers.""" @@ -321,7 +319,7 @@ class EventPublisher(ABC): """ def register_event_handlers( - self, event_handlers: Optional[List[EventHandler]] + self, event_handlers: list[EventHandler] | None ) -> None: """Register event handlers. @@ -354,7 +352,7 @@ class DefaultEventPublisher(EventPublisher): def __init__(self) -> None: """Init the event publisher.""" - self.handlers: List[EventHandler] = [] + self.handlers: list[EventHandler] = [] def register_event_handler(self, event_handler: EventHandler) -> None: """Register the event handler. @@ -374,7 +372,7 @@ class DefaultEventPublisher(EventPublisher): @contextmanager def stage( - publisher: EventPublisher, events: Tuple[Event, Event] + publisher: EventPublisher, events: tuple[Event, Event] ) -> Generator[None, None, None]: """Generate events before and after stage. @@ -390,7 +388,7 @@ def stage( @contextmanager def action( - publisher: EventPublisher, action_type: str, params: Optional[Dict] = None + publisher: EventPublisher, action_type: str, params: dict | None = None ) -> Generator[None, None, None]: """Generate events before and after action.""" action_started = ActionStartedEvent(action_type, params) diff --git a/src/mlia/core/handlers.py b/src/mlia/core/handlers.py index e576f74..a3255ae 100644 --- a/src/mlia/core/handlers.py +++ b/src/mlia/core/handlers.py @@ -1,13 +1,12 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Event handlers module.""" +from __future__ import annotations + import logging from typing import Any from typing import Callable -from typing import List -from typing import Optional -from mlia.core._typing import PathOrFileLike from mlia.core.advice_generation import Advice from mlia.core.advice_generation import AdviceEvent from mlia.core.events import ActionFinishedEvent @@ -28,6 +27,7 @@ from mlia.core.events import ExecutionStartedEvent from mlia.core.reporting import Report from mlia.core.reporting import Reporter from mlia.core.reporting import resolve_output_format +from mlia.core.typing import PathOrFileLike from mlia.utils.console import create_section_header @@ -101,14 +101,14 @@ class WorkflowEventsHandler(SystemEventsHandler): def __init__( self, formatter_resolver: Callable[[Any], Callable[[Any], Report]], - output: Optional[PathOrFileLike] = None, + output: PathOrFileLike | None = None, ) -> None: """Init event handler.""" output_format = resolve_output_format(output) self.reporter = Reporter(formatter_resolver, output_format) self.output = output - self.advice: List[Advice] = [] + self.advice: list[Advice] = [] def on_execution_started(self, event: ExecutionStartedEvent) -> None: """Handle ExecutionStarted event.""" diff --git a/src/mlia/core/helpers.py b/src/mlia/core/helpers.py index d10ea5d..f0c4474 100644 --- a/src/mlia/core/helpers.py +++ b/src/mlia/core/helpers.py @@ -2,34 +2,35 @@ # SPDX-License-Identifier: Apache-2.0 """Module for various helper classes.""" # pylint: disable=no-self-use, unused-argument +from __future__ import annotations + from typing import Any -from typing import List class ActionResolver: """Helper class for generating actions (e.g. commands with parameters).""" - def apply_optimizations(self, **kwargs: Any) -> List[str]: + def apply_optimizations(self, **kwargs: Any) -> list[str]: """Return action details for applying optimizations.""" return [] - def supported_operators_info(self) -> List[str]: + def supported_operators_info(self) -> list[str]: """Return action details for generating supported ops report.""" return [] - def check_performance(self) -> List[str]: + def check_performance(self) -> list[str]: """Return action details for checking performance.""" return [] - def check_operator_compatibility(self) -> List[str]: + def check_operator_compatibility(self) -> list[str]: """Return action details for checking op compatibility.""" return [] - def operator_compatibility_details(self) -> List[str]: + def operator_compatibility_details(self) -> list[str]: """Return action details for getting more information about op compatibility.""" return [] - def optimization_details(self) -> List[str]: + def optimization_details(self) -> list[str]: """Return action detail for getting information about optimizations.""" return [] diff --git a/src/mlia/core/mixins.py b/src/mlia/core/mixins.py index ee03100..5ef9d66 100644 --- a/src/mlia/core/mixins.py +++ b/src/mlia/core/mixins.py @@ -1,8 +1,9 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Mixins module.""" +from __future__ import annotations + from typing import Any -from typing import Optional from mlia.core.context import Context @@ -27,8 +28,8 @@ class ParameterResolverMixin: section: str, name: str, expected: bool = True, - expected_type: Optional[type] = None, - context: Optional[Context] = None, + expected_type: type | None = None, + context: Context | None = None, ) -> Any: """Get parameter value.""" ctx = context or self.context diff --git a/src/mlia/core/performance.py b/src/mlia/core/performance.py index 5433d5c..cb12918 100644 --- a/src/mlia/core/performance.py +++ b/src/mlia/core/performance.py @@ -1,30 +1,31 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for performance estimation.""" +from __future__ import annotations + from abc import abstractmethod from typing import Callable from typing import Generic -from typing import List from typing import TypeVar -ModelType = TypeVar("ModelType") # pylint: disable=invalid-name -PerfMetricsType = TypeVar("PerfMetricsType") # pylint: disable=invalid-name +M = TypeVar("M") # model type +P = TypeVar("P") # performance metrics -class PerformanceEstimator(Generic[ModelType, PerfMetricsType]): +class PerformanceEstimator(Generic[M, P]): """Base class for the performance estimation.""" @abstractmethod - def estimate(self, model: ModelType) -> PerfMetricsType: + def estimate(self, model: M) -> P: """Estimate performance.""" def estimate_performance( - original_model: ModelType, - estimator: PerformanceEstimator[ModelType, PerfMetricsType], - model_transformations: List[Callable[[ModelType], ModelType]], -) -> List[PerfMetricsType]: + original_model: M, + estimator: PerformanceEstimator[M, P], + model_transformations: list[Callable[[M], M]], +) -> list[P]: """Estimate performance impact. This function estimates performance impact on model performance after diff --git a/src/mlia/core/reporting.py b/src/mlia/core/reporting.py index 58a41d3..0c8fabc 100644 --- a/src/mlia/core/reporting.py +++ b/src/mlia/core/reporting.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Reporting module.""" +from __future__ import annotations + import csv import json import logging @@ -19,19 +21,14 @@ from typing import Any from typing import Callable from typing import cast from typing import Collection -from typing import Dict from typing import Generator from typing import Iterable -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union import numpy as np -from mlia.core._typing import FileLike -from mlia.core._typing import OutputFormat -from mlia.core._typing import PathOrFileLike +from mlia.core.typing import FileLike +from mlia.core.typing import OutputFormat +from mlia.core.typing import PathOrFileLike from mlia.utils.console import apply_style from mlia.utils.console import produce_table from mlia.utils.logging import LoggerWriter @@ -48,7 +45,7 @@ class Report(ABC): """Convert to json serializible format.""" @abstractmethod - def to_csv(self, **kwargs: Any) -> List[Any]: + def to_csv(self, **kwargs: Any) -> list[Any]: """Convert to csv serializible format.""" @abstractmethod @@ -62,9 +59,9 @@ class ReportItem: def __init__( self, name: str, - alias: Optional[str] = None, - value: Optional[Union[str, int, "Cell"]] = None, - nested_items: Optional[List["ReportItem"]] = None, + alias: str | None = None, + value: str | int | Cell | None = None, + nested_items: list[ReportItem] | None = None, ) -> None: """Init the report item.""" self.name = name @@ -98,9 +95,9 @@ class Format: :param style: text style """ - wrap_width: Optional[int] = None - str_fmt: Optional[Union[str, Callable[[Any], str]]] = None - style: Optional[str] = None + wrap_width: int | None = None + str_fmt: str | Callable[[Any], str] | None = None + style: str | None = None @dataclass @@ -112,7 +109,7 @@ class Cell: """ value: Any - fmt: Optional[Format] = None + fmt: Format | None = None def _apply_style(self, value: str) -> str: """Apply style to the value.""" @@ -151,7 +148,7 @@ class CountAwareCell(Cell): def __init__( self, - value: Optional[Union[int, float]], + value: int | float | None, singular: str, plural: str, format_string: str = ",d", @@ -159,7 +156,7 @@ class CountAwareCell(Cell): """Init cell instance.""" self.unit = singular if value == 1 else plural - def format_value(val: Optional[Union[int, float]]) -> str: + def format_value(val: int | float | None) -> str: """Provide string representation for the value.""" if val is None: return "" @@ -183,7 +180,7 @@ class CountAwareCell(Cell): class BytesCell(CountAwareCell): """Cell that represents memory size.""" - def __init__(self, value: Optional[int]) -> None: + def __init__(self, value: int | None) -> None: """Init cell instance.""" super().__init__(value, "byte", "bytes") @@ -191,7 +188,7 @@ class BytesCell(CountAwareCell): class CyclesCell(CountAwareCell): """Cell that represents cycles.""" - def __init__(self, value: Optional[Union[int, float]]) -> None: + def __init__(self, value: int | float | None) -> None: """Init cell instance.""" super().__init__(value, "cycle", "cycles", ",.0f") @@ -199,7 +196,7 @@ class CyclesCell(CountAwareCell): class ClockCell(CountAwareCell): """Cell that represents clock value.""" - def __init__(self, value: Optional[Union[int, float]]) -> None: + def __init__(self, value: int | float | None) -> None: """Init cell instance.""" super().__init__(value, "Hz", "Hz", ",.0f") @@ -210,9 +207,9 @@ class Column: def __init__( self, header: str, - alias: Optional[str] = None, - fmt: Optional[Format] = None, - only_for: Optional[List[str]] = None, + alias: str | None = None, + fmt: Format | None = None, + only_for: list[str] | None = None, ) -> None: """Init column definition. @@ -228,7 +225,7 @@ class Column: self.fmt = fmt self.only_for = only_for - def supports_format(self, fmt: str) -> bool: + def supports_format(self, fmt: OutputFormat) -> bool: """Return true if column should be shown.""" return not self.only_for or fmt in self.only_for @@ -236,20 +233,20 @@ class Column: class NestedReport(Report): """Report with nested items.""" - def __init__(self, name: str, alias: str, items: List[ReportItem]) -> None: + def __init__(self, name: str, alias: str, items: list[ReportItem]) -> None: """Init nested report.""" self.name = name self.alias = alias self.items = items - def to_csv(self, **kwargs: Any) -> List[Any]: + def to_csv(self, **kwargs: Any) -> list[Any]: """Convert to csv serializible format.""" result = {} def collect_item_values( item: ReportItem, - _parent: Optional[ReportItem], - _prev: Optional[ReportItem], + _parent: ReportItem | None, + _prev: ReportItem | None, _level: int, ) -> None: """Collect item values into a dictionary..""" @@ -279,13 +276,13 @@ class NestedReport(Report): def to_json(self, **kwargs: Any) -> Any: """Convert to json serializible format.""" - per_parent: Dict[Optional[ReportItem], Dict] = defaultdict(dict) + per_parent: dict[ReportItem | None, dict] = defaultdict(dict) result = per_parent[None] def collect_as_dicts( item: ReportItem, - parent: Optional[ReportItem], - _prev: Optional[ReportItem], + parent: ReportItem | None, + _prev: ReportItem | None, _level: int, ) -> None: """Collect item values as nested dictionaries.""" @@ -313,8 +310,8 @@ class NestedReport(Report): def convert_to_text( item: ReportItem, - _parent: Optional[ReportItem], - prev: Optional[ReportItem], + _parent: ReportItem | None, + prev: ReportItem | None, level: int, ) -> None: """Convert item to text representation.""" @@ -345,12 +342,12 @@ class NestedReport(Report): def _traverse( self, - items: List[ReportItem], + items: list[ReportItem], visit_item: Callable[ - [ReportItem, Optional[ReportItem], Optional[ReportItem], int], None + [ReportItem, ReportItem | None, ReportItem | None, int], None ], level: int = 1, - parent: Optional[ReportItem] = None, + parent: ReportItem | None = None, ) -> None: """Traverse through items.""" prev = None @@ -369,11 +366,11 @@ class Table(Report): def __init__( self, - columns: List[Column], + columns: list[Column], rows: Collection, name: str, - alias: Optional[str] = None, - notes: Optional[str] = None, + alias: str | None = None, + notes: str | None = None, ) -> None: """Init table definition. @@ -477,7 +474,7 @@ class Table(Report): return title + formatted_table + footer - def to_csv(self, **kwargs: Any) -> List[Any]: + def to_csv(self, **kwargs: Any) -> list[Any]: """Convert table to csv format.""" headers = [[c.header for c in self.columns if c.supports_format("csv")]] @@ -528,7 +525,7 @@ class CompoundReport(Report): This class could be used for producing multiple reports at once. """ - def __init__(self, reports: List[Report]) -> None: + def __init__(self, reports: list[Report]) -> None: """Init compound report instance.""" self.reports = reports @@ -538,13 +535,13 @@ class CompoundReport(Report): Method attempts to create compound dictionary based on provided parts. """ - result: Dict[str, Any] = {} + result: dict[str, Any] = {} for item in self.reports: result.update(item.to_json(**kwargs)) return result - def to_csv(self, **kwargs: Any) -> List[Any]: + def to_csv(self, **kwargs: Any) -> list[Any]: """Convert to csv serializible format. CSV format does support only one table. In order to be able to export @@ -592,7 +589,7 @@ class CompoundReport(Report): class CompoundFormatter: """Compound data formatter.""" - def __init__(self, formatters: List[Callable]) -> None: + def __init__(self, formatters: list[Callable]) -> None: """Init compound formatter.""" self.formatters = formatters @@ -637,7 +634,7 @@ def produce_report( data: Any, formatter: Callable[[Any], Report], fmt: OutputFormat = "plain_text", - output: Optional[PathOrFileLike] = None, + output: PathOrFileLike | None = None, **kwargs: Any, ) -> None: """Produce report based on provided data.""" @@ -679,8 +676,8 @@ class Reporter: self.output_format = output_format self.print_as_submitted = print_as_submitted - self.data: List[Tuple[Any, Callable[[Any], Report]]] = [] - self.delayed: List[Tuple[Any, Callable[[Any], Report]]] = [] + self.data: list[tuple[Any, Callable[[Any], Report]]] = [] + self.delayed: list[tuple[Any, Callable[[Any], Report]]] = [] def submit(self, data_item: Any, delay_print: bool = False, **kwargs: Any) -> None: """Submit data for the report.""" @@ -713,7 +710,7 @@ class Reporter: ) self.delayed = [] - def generate_report(self, output: Optional[PathOrFileLike]) -> None: + def generate_report(self, output: PathOrFileLike | None) -> None: """Generate report.""" already_printed = ( self.print_as_submitted @@ -735,7 +732,7 @@ class Reporter: @contextmanager def get_reporter( output_format: OutputFormat, - output: Optional[PathOrFileLike], + output: PathOrFileLike | None, formatter_resolver: Callable[[Any], Callable[[Any], Report]], ) -> Generator[Reporter, None, None]: """Get reporter and generate report.""" @@ -762,7 +759,7 @@ def _apply_format_parameters( return wrapper -def resolve_output_format(output: Optional[PathOrFileLike]) -> OutputFormat: +def resolve_output_format(output: PathOrFileLike | None) -> OutputFormat: """Resolve output format based on the output name.""" if isinstance(output, (str, Path)): format_from_filename = Path(output).suffix.lstrip(".") diff --git a/src/mlia/core/typing.py b/src/mlia/core/typing.py new file mode 100644 index 0000000..bda995c --- /dev/null +++ b/src/mlia/core/typing.py @@ -0,0 +1,12 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for custom type hints.""" +from pathlib import Path +from typing import Literal +from typing import TextIO +from typing import Union + + +FileLike = TextIO +PathOrFileLike = Union[str, Path, FileLike] +OutputFormat = Literal["plain_text", "csv", "json"] diff --git a/src/mlia/core/workflow.py b/src/mlia/core/workflow.py index 03f3d1c..d862a86 100644 --- a/src/mlia/core/workflow.py +++ b/src/mlia/core/workflow.py @@ -5,16 +5,15 @@ This module contains implementation of the workflow executors. """ +from __future__ import annotations + import itertools from abc import ABC from abc import abstractmethod from functools import wraps from typing import Any from typing import Callable -from typing import List -from typing import Optional from typing import Sequence -from typing import Tuple from mlia.core.advice_generation import Advice from mlia.core.advice_generation import AdviceEvent @@ -57,7 +56,7 @@ STAGE_ANALYSIS = (DataAnalysisStageStartedEvent(), DataAnalysisStageFinishedEven STAGE_ADVICE = (AdviceStageStartedEvent(), AdviceStageFinishedEvent()) -def on_stage(stage_events: Tuple[Event, Event]) -> Callable: +def on_stage(stage_events: tuple[Event, Event]) -> Callable: """Mark start/finish of the stage with appropriate events.""" def wrapper(method: Callable) -> Callable: @@ -87,7 +86,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor): collectors: Sequence[DataCollector], analyzers: Sequence[DataAnalyzer], producers: Sequence[AdviceProducer], - startup_events: Optional[Sequence[Event]] = None, + startup_events: Sequence[Event] | None = None, ): """Init default workflow executor. @@ -130,7 +129,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor): self.publish(event) @on_stage(STAGE_COLLECTION) - def collect_data(self) -> List[DataItem]: + def collect_data(self) -> list[DataItem]: """Collect data. Run each of data collector components and return list of @@ -148,7 +147,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor): return collected_data @on_stage(STAGE_ANALYSIS) - def analyze_data(self, collected_data: List[DataItem]) -> List[DataItem]: + def analyze_data(self, collected_data: list[DataItem]) -> list[DataItem]: """Analyze data. Pass each collected data item into each data analyzer and @@ -168,7 +167,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor): return analyzed_data @on_stage(STAGE_ADVICE) - def produce_advice(self, analyzed_data: List[DataItem]) -> None: + def produce_advice(self, analyzed_data: list[DataItem]) -> None: """Produce advice. Pass each analyzed data item into each advice producer and diff --git a/src/mlia/devices/ethosu/advice_generation.py b/src/mlia/devices/ethosu/advice_generation.py index 0b1352b..dee1650 100644 --- a/src/mlia/devices/ethosu/advice_generation.py +++ b/src/mlia/devices/ethosu/advice_generation.py @@ -1,9 +1,9 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Ethos-U advice generation.""" +from __future__ import annotations + from functools import singledispatchmethod -from typing import List -from typing import Union from mlia.core.advice_generation import Advice from mlia.core.advice_generation import advice_category @@ -146,8 +146,8 @@ class EthosUAdviceProducer(FactBasedAdviceProducer): @staticmethod def get_next_optimization_targets( - opt_type: List[OptimizationSettings], - ) -> List[OptimizationSettings]: + opt_type: list[OptimizationSettings], + ) -> list[OptimizationSettings]: """Get next optimization targets.""" next_targets = (item.next_target() for item in opt_type) @@ -173,7 +173,7 @@ class EthosUStaticAdviceProducer(ContextAwareAdviceProducer): def produce_advice(self, data_item: DataItem) -> None: """Do not process passed data items.""" - def get_advice(self) -> Union[Advice, List[Advice]]: + def get_advice(self) -> Advice | list[Advice]: """Return predefined advice based on category.""" advice_per_category = { AdviceCategory.PERFORMANCE: [ diff --git a/src/mlia/devices/ethosu/advisor.py b/src/mlia/devices/ethosu/advisor.py index b7b8305..be58de7 100644 --- a/src/mlia/devices/ethosu/advisor.py +++ b/src/mlia/devices/ethosu/advisor.py @@ -1,14 +1,11 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Ethos-U MLIA module.""" +from __future__ import annotations + from pathlib import Path from typing import Any -from typing import Dict -from typing import List -from typing import Optional -from typing import Union -from mlia.core._typing import PathOrFileLike from mlia.core.advice_generation import AdviceProducer from mlia.core.advisor import DefaultInferenceAdvisor from mlia.core.advisor import InferenceAdvisor @@ -18,6 +15,7 @@ from mlia.core.context import ExecutionContext from mlia.core.data_analysis import DataAnalyzer from mlia.core.data_collection import DataCollector from mlia.core.events import Event +from mlia.core.typing import PathOrFileLike from mlia.devices.ethosu.advice_generation import EthosUAdviceProducer from mlia.devices.ethosu.advice_generation import EthosUStaticAdviceProducer from mlia.devices.ethosu.config import EthosUConfiguration @@ -40,13 +38,13 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor): """Return name of the advisor.""" return "ethos_u_inference_advisor" - def get_collectors(self, context: Context) -> List[DataCollector]: + def get_collectors(self, context: Context) -> list[DataCollector]: """Return list of the data collectors.""" model = self.get_model(context) device = self._get_device(context) backends = self._get_backends(context) - collectors: List[DataCollector] = [] + collectors: list[DataCollector] = [] if AdviceCategory.OPERATORS in context.advice_category: collectors.append(EthosUOperatorCompatibility(model, device)) @@ -75,20 +73,20 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor): return collectors - def get_analyzers(self, context: Context) -> List[DataAnalyzer]: + def get_analyzers(self, context: Context) -> list[DataAnalyzer]: """Return list of the data analyzers.""" return [ EthosUDataAnalyzer(), ] - def get_producers(self, context: Context) -> List[AdviceProducer]: + def get_producers(self, context: Context) -> list[AdviceProducer]: """Return list of the advice producers.""" return [ EthosUAdviceProducer(), EthosUStaticAdviceProducer(), ] - def get_events(self, context: Context) -> List[Event]: + def get_events(self, context: Context) -> list[Event]: """Return list of the startup events.""" model = self.get_model(context) device = self._get_device(context) @@ -103,7 +101,7 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor): return get_target(target_profile) - def _get_optimization_settings(self, context: Context) -> List[List[dict]]: + def _get_optimization_settings(self, context: Context) -> list[list[dict]]: """Get optimization settings.""" return self.get_parameter( # type: ignore EthosUOptimizationPerformance.name(), @@ -113,7 +111,7 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor): context=context, ) - def _get_backends(self, context: Context) -> Optional[List[str]]: + def _get_backends(self, context: Context) -> list[str] | None: """Get list of backends.""" return self.get_parameter( # type: ignore self.name(), @@ -127,8 +125,8 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor): def configure_and_get_ethosu_advisor( context: ExecutionContext, target_profile: str, - model: Union[Path, str], - output: Optional[PathOrFileLike] = None, + model: str | Path, + output: PathOrFileLike | None = None, **extra_args: Any, ) -> InferenceAdvisor: """Create and configure Ethos-U advisor.""" @@ -158,12 +156,12 @@ _DEFAULT_OPTIMIZATION_TARGETS = [ def _get_config_parameters( - model: Union[Path, str], + model: str | Path, target_profile: str, **extra_args: Any, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Get configuration parameters for the advisor.""" - advisor_parameters: Dict[str, Any] = { + advisor_parameters: dict[str, Any] = { "ethos_u_inference_advisor": { "model": model, "target_profile": target_profile, diff --git a/src/mlia/devices/ethosu/config.py b/src/mlia/devices/ethosu/config.py index cecbb27..e44dcdc 100644 --- a/src/mlia/devices/ethosu/config.py +++ b/src/mlia/devices/ethosu/config.py @@ -1,9 +1,10 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Ethos-U configuration.""" +from __future__ import annotations + import logging from typing import Any -from typing import Dict from mlia.devices.config import IPConfiguration from mlia.tools.vela_wrapper import resolve_compiler_config @@ -38,7 +39,7 @@ class EthosUConfiguration(IPConfiguration): ) @property - def resolved_compiler_config(self) -> Dict[str, Any]: + def resolved_compiler_config(self) -> dict[str, Any]: """Resolve compiler configuration.""" return resolve_compiler_config(self.compiler_options) @@ -63,7 +64,7 @@ def get_target(target_profile: str) -> EthosUConfiguration: return EthosUConfiguration(target_profile) -def _check_target_data_complete(target_data: Dict[str, Any]) -> None: +def _check_target_data_complete(target_data: dict[str, Any]) -> None: """Check if profile contains all needed data.""" mandatory_keys = {"target", "mac", "system_config", "memory_mode"} missing_keys = sorted(mandatory_keys - target_data.keys()) diff --git a/src/mlia/devices/ethosu/data_analysis.py b/src/mlia/devices/ethosu/data_analysis.py index 9ed32ff..8d88cf7 100644 --- a/src/mlia/devices/ethosu/data_analysis.py +++ b/src/mlia/devices/ethosu/data_analysis.py @@ -1,11 +1,10 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Ethos-U data analysis module.""" +from __future__ import annotations + from dataclasses import dataclass from functools import singledispatchmethod -from typing import Dict -from typing import List -from typing import Union from mlia.core.common import DataItem from mlia.core.data_analysis import Fact @@ -19,7 +18,7 @@ from mlia.tools.vela_wrapper import Operators class HasCPUOnlyOperators(Fact): """Model has CPU only operators.""" - cpu_only_ops: List[str] + cpu_only_ops: list[str] @dataclass @@ -38,8 +37,8 @@ class AllOperatorsSupportedOnNPU(Fact): class PerfMetricDiff: """Performance metric difference.""" - original_value: Union[int, float] - optimized_value: Union[int, float] + original_value: int | float + optimized_value: int | float @property def diff(self) -> float: @@ -69,15 +68,15 @@ class PerfMetricDiff: class OptimizationDiff: """Optimization performance impact.""" - opt_type: List[OptimizationSettings] - opt_diffs: Dict[str, PerfMetricDiff] + opt_type: list[OptimizationSettings] + opt_diffs: dict[str, PerfMetricDiff] @dataclass class OptimizationResults(Fact): """Optimization results.""" - diffs: List[OptimizationDiff] + diffs: list[OptimizationDiff] class EthosUDataAnalyzer(FactExtractor): @@ -113,13 +112,13 @@ class EthosUDataAnalyzer(FactExtractor): orig_memory = orig.memory_usage orig_cycles = orig.npu_cycles - diffs: List[OptimizationDiff] = [] + diffs: list[OptimizationDiff] = [] for opt_type, opt_perf_metrics in optimizations: opt = opt_perf_metrics.in_kilobytes() opt_memory = opt.memory_usage opt_cycles = opt.npu_cycles - opt_diffs: Dict[str, PerfMetricDiff] = {} + opt_diffs: dict[str, PerfMetricDiff] = {} if orig_memory and opt_memory: opt_diffs.update( diff --git a/src/mlia/devices/ethosu/data_collection.py b/src/mlia/devices/ethosu/data_collection.py index 291f1b8..6ddebac 100644 --- a/src/mlia/devices/ethosu/data_collection.py +++ b/src/mlia/devices/ethosu/data_collection.py @@ -1,10 +1,10 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Data collection module for Ethos-U.""" +from __future__ import annotations + import logging from pathlib import Path -from typing import List -from typing import Optional from mlia.core.context import Context from mlia.core.data_collection import ContextAwareDataCollector @@ -59,7 +59,7 @@ class EthosUPerformance(ContextAwareDataCollector): self, model: Path, device: EthosUConfiguration, - backends: Optional[List[str]] = None, + backends: list[str] | None = None, ) -> None: """Init performance data collector.""" self.model = model @@ -87,7 +87,7 @@ class OptimizeModel: """Helper class for model optimization.""" def __init__( - self, context: Context, opt_settings: List[OptimizationSettings] + self, context: Context, opt_settings: list[OptimizationSettings] ) -> None: """Init helper.""" self.context = context @@ -115,8 +115,8 @@ class EthosUOptimizationPerformance(ContextAwareDataCollector): self, model: Path, device: EthosUConfiguration, - optimizations: List[List[dict]], - backends: Optional[List[str]] = None, + optimizations: list[list[dict]], + backends: list[str] | None = None, ) -> None: """Init performance optimizations data collector.""" self.model = model @@ -124,7 +124,7 @@ class EthosUOptimizationPerformance(ContextAwareDataCollector): self.optimizations = optimizations self.backends = backends - def collect_data(self) -> Optional[OptimizationPerformanceMetrics]: + def collect_data(self) -> OptimizationPerformanceMetrics | None: """Collect performance metrics for the optimizations.""" logger.info("Estimate performance ...") @@ -164,8 +164,8 @@ class EthosUOptimizationPerformance(ContextAwareDataCollector): @staticmethod def _parse_optimization_params( - optimizations: List[List[dict]], - ) -> List[List[OptimizationSettings]]: + optimizations: list[list[dict]], + ) -> list[list[OptimizationSettings]]: """Parse optimization parameters.""" if not is_list_of(optimizations, list): raise Exception("Optimization parameters expected to be a list") diff --git a/src/mlia/devices/ethosu/handlers.py b/src/mlia/devices/ethosu/handlers.py index ee0b809..48f9a2e 100644 --- a/src/mlia/devices/ethosu/handlers.py +++ b/src/mlia/devices/ethosu/handlers.py @@ -1,12 +1,13 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Event handler.""" +from __future__ import annotations + import logging -from typing import Optional -from mlia.core._typing import PathOrFileLike from mlia.core.events import CollectedDataEvent from mlia.core.handlers import WorkflowEventsHandler +from mlia.core.typing import PathOrFileLike from mlia.devices.ethosu.events import EthosUAdvisorEventHandler from mlia.devices.ethosu.events import EthosUAdvisorStartedEvent from mlia.devices.ethosu.performance import OptimizationPerformanceMetrics @@ -20,7 +21,7 @@ logger = logging.getLogger(__name__) class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler): """CLI event handler.""" - def __init__(self, output: Optional[PathOrFileLike] = None) -> None: + def __init__(self, output: PathOrFileLike | None = None) -> None: """Init event handler.""" super().__init__(ethos_u_formatters, output) diff --git a/src/mlia/devices/ethosu/performance.py b/src/mlia/devices/ethosu/performance.py index a73045a..e89a65a 100644 --- a/src/mlia/devices/ethosu/performance.py +++ b/src/mlia/devices/ethosu/performance.py @@ -1,13 +1,12 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Performance estimation.""" +from __future__ import annotations + import logging from dataclasses import dataclass from enum import Enum from pathlib import Path -from typing import List -from typing import Optional -from typing import Tuple from typing import Union import mlia.backend.manager as backend_manager @@ -49,11 +48,11 @@ class MemorySizeType(Enum): class MemoryUsage: """Memory usage metrics.""" - sram_memory_area_size: Union[int, float] - dram_memory_area_size: Union[int, float] - unknown_memory_area_size: Union[int, float] - on_chip_flash_memory_area_size: Union[int, float] - off_chip_flash_memory_area_size: Union[int, float] + sram_memory_area_size: int | float + dram_memory_area_size: int | float + unknown_memory_area_size: int | float + on_chip_flash_memory_area_size: int | float + off_chip_flash_memory_area_size: int | float memory_size_type: MemorySizeType = MemorySizeType.BYTES _default_columns = [ @@ -64,7 +63,7 @@ class MemoryUsage: "Off chip flash used", ] - def in_kilobytes(self) -> "MemoryUsage": + def in_kilobytes(self) -> MemoryUsage: """Return memory usage with values in kilobytes.""" if self.memory_size_type == MemorySizeType.KILOBYTES: return self @@ -91,10 +90,10 @@ class PerformanceMetrics: """Performance metrics.""" device: EthosUConfiguration - npu_cycles: Optional[NPUCycles] - memory_usage: Optional[MemoryUsage] + npu_cycles: NPUCycles | None + memory_usage: MemoryUsage | None - def in_kilobytes(self) -> "PerformanceMetrics": + def in_kilobytes(self) -> PerformanceMetrics: """Return metrics with memory usage in KiB.""" if self.memory_usage is None: return PerformanceMetrics(self.device, self.npu_cycles, self.memory_usage) @@ -109,8 +108,8 @@ class OptimizationPerformanceMetrics: """Optimization performance metrics.""" original_perf_metrics: PerformanceMetrics - optimizations_perf_metrics: List[ - Tuple[List[OptimizationSettings], PerformanceMetrics] + optimizations_perf_metrics: list[ + tuple[list[OptimizationSettings], PerformanceMetrics] ] @@ -124,7 +123,7 @@ class VelaPerformanceEstimator( self.context = context self.device = device - def estimate(self, model: Union[Path, ModelConfiguration]) -> MemoryUsage: + def estimate(self, model: Path | ModelConfiguration) -> MemoryUsage: """Estimate performance.""" logger.info("Getting the memory usage metrics ...") @@ -160,7 +159,7 @@ class CorstonePerformanceEstimator( self.device = device self.backend = backend - def estimate(self, model: Union[Path, ModelConfiguration]) -> NPUCycles: + def estimate(self, model: Path | ModelConfiguration) -> NPUCycles: """Estimate performance.""" logger.info("Getting the performance metrics for '%s' ...", self.backend) logger.info( @@ -212,7 +211,7 @@ class EthosUPerformanceEstimator( self, context: Context, device: EthosUConfiguration, - backends: Optional[List[str]] = None, + backends: list[str] | None = None, ) -> None: """Init performance estimator.""" self.context = context @@ -228,7 +227,7 @@ class EthosUPerformanceEstimator( ) self.backends = set(backends) - def estimate(self, model: Union[Path, ModelConfiguration]) -> PerformanceMetrics: + def estimate(self, model: Path | ModelConfiguration) -> PerformanceMetrics: """Estimate performance.""" model_path = ( Path(model.model_path) if isinstance(model, ModelConfiguration) else model diff --git a/src/mlia/devices/ethosu/reporters.py b/src/mlia/devices/ethosu/reporters.py index b3aea24..f11430c 100644 --- a/src/mlia/devices/ethosu/reporters.py +++ b/src/mlia/devices/ethosu/reporters.py @@ -1,12 +1,11 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Reports module.""" +from __future__ import annotations + from collections import defaultdict from typing import Any from typing import Callable -from typing import List -from typing import Tuple -from typing import Union from mlia.core.advice_generation import Advice from mlia.core.reporting import BytesCell @@ -52,7 +51,7 @@ def report_operators_stat(operators: Operators) -> Report: ) -def report_operators(ops: List[Operator]) -> Report: +def report_operators(ops: list[Operator]) -> Report: """Return table representation for the list of operators.""" columns = [ Column("#", only_for=["plain_text"]), @@ -235,11 +234,11 @@ def report_device_details(device: EthosUConfiguration) -> Report: ) -def metrics_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]: +def metrics_as_records(perf_metrics: list[PerformanceMetrics]) -> list[tuple]: """Convert perf metrics object into list of records.""" perf_metrics = [item.in_kilobytes() for item in perf_metrics] - def _cycles_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]: + def _cycles_as_records(perf_metrics: list[PerformanceMetrics]) -> list[tuple]: metric_map = defaultdict(list) for metrics in perf_metrics: if not metrics.npu_cycles: @@ -253,7 +252,7 @@ def metrics_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]: for name, values in metric_map.items() ] - def _memory_usage_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]: + def _memory_usage_as_records(perf_metrics: list[PerformanceMetrics]) -> list[tuple]: metric_map = defaultdict(list) for metrics in perf_metrics: if not metrics.memory_usage: @@ -276,7 +275,7 @@ def metrics_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]: if all(val > 0 for val in values) ] - def _data_beats_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]: + def _data_beats_as_records(perf_metrics: list[PerformanceMetrics]) -> list[tuple]: metric_map = defaultdict(list) for metrics in perf_metrics: if not metrics.npu_cycles: @@ -308,7 +307,7 @@ def metrics_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]: def report_perf_metrics( - perf_metrics: Union[PerformanceMetrics, List[PerformanceMetrics]] + perf_metrics: PerformanceMetrics | list[PerformanceMetrics], ) -> Report: """Return comparison table for the performance metrics.""" if isinstance(perf_metrics, PerformanceMetrics): @@ -361,7 +360,7 @@ def report_perf_metrics( ) -def report_advice(advice: List[Advice]) -> Report: +def report_advice(advice: list[Advice]) -> Report: """Generate report for the advice.""" return Table( columns=[ diff --git a/src/mlia/devices/tosa/advisor.py b/src/mlia/devices/tosa/advisor.py index 6a32b94..53dfa87 100644 --- a/src/mlia/devices/tosa/advisor.py +++ b/src/mlia/devices/tosa/advisor.py @@ -1,14 +1,11 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """TOSA advisor.""" +from __future__ import annotations + from pathlib import Path from typing import Any -from typing import Dict -from typing import List -from typing import Optional -from typing import Union -from mlia.core._typing import PathOrFileLike from mlia.core.advice_generation import AdviceCategory from mlia.core.advice_generation import AdviceProducer from mlia.core.advisor import DefaultInferenceAdvisor @@ -18,6 +15,7 @@ from mlia.core.context import ExecutionContext from mlia.core.data_analysis import DataAnalyzer from mlia.core.data_collection import DataCollector from mlia.core.events import Event +from mlia.core.typing import PathOrFileLike from mlia.devices.tosa.advice_generation import TOSAAdviceProducer from mlia.devices.tosa.config import TOSAConfiguration from mlia.devices.tosa.data_analysis import TOSADataAnalyzer @@ -34,30 +32,30 @@ class TOSAInferenceAdvisor(DefaultInferenceAdvisor): """Return name of the advisor.""" return "tosa_inference_advisor" - def get_collectors(self, context: Context) -> List[DataCollector]: + def get_collectors(self, context: Context) -> list[DataCollector]: """Return list of the data collectors.""" model = self.get_model(context) - collectors: List[DataCollector] = [] + collectors: list[DataCollector] = [] if AdviceCategory.OPERATORS in context.advice_category: collectors.append(TOSAOperatorCompatibility(model)) return collectors - def get_analyzers(self, context: Context) -> List[DataAnalyzer]: + def get_analyzers(self, context: Context) -> list[DataAnalyzer]: """Return list of the data analyzers.""" return [ TOSADataAnalyzer(), ] - def get_producers(self, context: Context) -> List[AdviceProducer]: + def get_producers(self, context: Context) -> list[AdviceProducer]: """Return list of the advice producers.""" return [ TOSAAdviceProducer(), ] - def get_events(self, context: Context) -> List[Event]: + def get_events(self, context: Context) -> list[Event]: """Return list of the startup events.""" model = self.get_model(context) target_profile = self.get_target_profile(context) @@ -70,9 +68,9 @@ class TOSAInferenceAdvisor(DefaultInferenceAdvisor): def configure_and_get_tosa_advisor( context: ExecutionContext, target_profile: str, - model: Union[Path, str], - output: Optional[PathOrFileLike] = None, - **_extra_args: Any + model: str | Path, + output: PathOrFileLike | None = None, + **_extra_args: Any, ) -> InferenceAdvisor: """Create and configure TOSA advisor.""" if context.event_handlers is None: @@ -84,11 +82,9 @@ def configure_and_get_tosa_advisor( return TOSAInferenceAdvisor() -def _get_config_parameters( - model: Union[Path, str], target_profile: str -) -> Dict[str, Any]: +def _get_config_parameters(model: str | Path, target_profile: str) -> dict[str, Any]: """Get configuration parameters for the advisor.""" - advisor_parameters: Dict[str, Any] = { + advisor_parameters: dict[str, Any] = { "tosa_inference_advisor": { "model": str(model), "target_profile": target_profile, diff --git a/src/mlia/devices/tosa/handlers.py b/src/mlia/devices/tosa/handlers.py index 00c18c5..5f015c4 100644 --- a/src/mlia/devices/tosa/handlers.py +++ b/src/mlia/devices/tosa/handlers.py @@ -2,12 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 """TOSA Advisor event handlers.""" # pylint: disable=R0801 +from __future__ import annotations + import logging -from typing import Optional -from mlia.core._typing import PathOrFileLike from mlia.core.events import CollectedDataEvent from mlia.core.handlers import WorkflowEventsHandler +from mlia.core.typing import PathOrFileLike from mlia.devices.tosa.events import TOSAAdvisorEventHandler from mlia.devices.tosa.events import TOSAAdvisorStartedEvent from mlia.devices.tosa.operators import TOSACompatibilityInfo @@ -19,7 +20,7 @@ logger = logging.getLogger(__name__) class TOSAEventHandler(WorkflowEventsHandler, TOSAAdvisorEventHandler): """Event handler for TOSA advisor.""" - def __init__(self, output: Optional[PathOrFileLike] = None) -> None: + def __init__(self, output: PathOrFileLike | None = None) -> None: """Init event handler.""" super().__init__(tosa_formatters, output) diff --git a/src/mlia/devices/tosa/operators.py b/src/mlia/devices/tosa/operators.py index 4f3df10..6cfb87f 100644 --- a/src/mlia/devices/tosa/operators.py +++ b/src/mlia/devices/tosa/operators.py @@ -1,14 +1,14 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Operators module.""" +from __future__ import annotations + from dataclasses import dataclass from typing import Any from typing import cast -from typing import List -from typing import Optional from typing import Protocol -from mlia.core._typing import PathOrFileLike +from mlia.core.typing import PathOrFileLike class TOSAChecker(Protocol): @@ -17,7 +17,7 @@ class TOSAChecker(Protocol): def is_tosa_compatible(self) -> bool: """Return true if model is TOSA compatible.""" - def _get_tosa_compatibility_for_ops(self) -> List[Any]: + def _get_tosa_compatibility_for_ops(self) -> list[Any]: """Return list of operators.""" @@ -35,7 +35,7 @@ class TOSACompatibilityInfo: """Models' TOSA compatibility information.""" tosa_compatible: bool - operators: List[Operator] + operators: list[Operator] def get_tosa_compatibility_info( @@ -59,7 +59,7 @@ def get_tosa_compatibility_info( return TOSACompatibilityInfo(checker.is_tosa_compatible(), ops) -def get_tosa_checker(tflite_model_path: PathOrFileLike) -> Optional[TOSAChecker]: +def get_tosa_checker(tflite_model_path: PathOrFileLike) -> TOSAChecker | None: """Return instance of the TOSA checker.""" try: import tosa_checker as tc # pylint: disable=import-outside-toplevel diff --git a/src/mlia/devices/tosa/reporters.py b/src/mlia/devices/tosa/reporters.py index 8fba95c..4363793 100644 --- a/src/mlia/devices/tosa/reporters.py +++ b/src/mlia/devices/tosa/reporters.py @@ -1,9 +1,10 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Reports module.""" +from __future__ import annotations + from typing import Any from typing import Callable -from typing import List from mlia.core.advice_generation import Advice from mlia.core.reporting import Cell @@ -30,7 +31,7 @@ def report_device(device: TOSAConfiguration) -> Report: ) -def report_advice(advice: List[Advice]) -> Report: +def report_advice(advice: list[Advice]) -> Report: """Generate report for the advice.""" return Table( columns=[ @@ -43,7 +44,7 @@ def report_advice(advice: List[Advice]) -> Report: ) -def report_tosa_operators(ops: List[Operator]) -> Report: +def report_tosa_operators(ops: list[Operator]) -> Report: """Generate report for the operators.""" return Table( [ diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py index d3235d7..6ee32e7 100644 --- a/src/mlia/nn/tensorflow/config.py +++ b/src/mlia/nn/tensorflow/config.py @@ -1,12 +1,12 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Model configuration.""" +from __future__ import annotations + import logging from pathlib import Path from typing import cast -from typing import Dict from typing import List -from typing import Union import tensorflow as tf @@ -24,17 +24,17 @@ logger = logging.getLogger(__name__) class ModelConfiguration: """Base class for model configuration.""" - def __init__(self, model_path: Union[str, Path]) -> None: + def __init__(self, model_path: str | Path) -> None: """Init model configuration instance.""" self.model_path = str(model_path) def convert_to_tflite( - self, tflite_model_path: Union[str, Path], quantized: bool = False - ) -> "TFLiteModel": + self, tflite_model_path: str | Path, quantized: bool = False + ) -> TFLiteModel: """Convert model to TFLite format.""" raise NotImplementedError() - def convert_to_keras(self, keras_model_path: Union[str, Path]) -> "KerasModel": + def convert_to_keras(self, keras_model_path: str | Path) -> KerasModel: """Convert model to Keras format.""" raise NotImplementedError() @@ -50,8 +50,8 @@ class KerasModel(ModelConfiguration): return tf.keras.models.load_model(self.model_path) def convert_to_tflite( - self, tflite_model_path: Union[str, Path], quantized: bool = False - ) -> "TFLiteModel": + self, tflite_model_path: str | Path, quantized: bool = False + ) -> TFLiteModel: """Convert model to TFLite format.""" logger.info("Converting Keras to TFLite ...") @@ -65,7 +65,7 @@ class KerasModel(ModelConfiguration): return TFLiteModel(tflite_model_path) - def convert_to_keras(self, keras_model_path: Union[str, Path]) -> "KerasModel": + def convert_to_keras(self, keras_model_path: str | Path) -> KerasModel: """Convert model to Keras format.""" return self @@ -73,14 +73,14 @@ class KerasModel(ModelConfiguration): class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method """TFLite model configuration.""" - def input_details(self) -> List[Dict]: + def input_details(self) -> list[dict]: """Get model's input details.""" interpreter = tf.lite.Interpreter(model_path=self.model_path) - return cast(List[Dict], interpreter.get_input_details()) + return cast(List[dict], interpreter.get_input_details()) def convert_to_tflite( - self, tflite_model_path: Union[str, Path], quantized: bool = False - ) -> "TFLiteModel": + self, tflite_model_path: str | Path, quantized: bool = False + ) -> TFLiteModel: """Convert model to TFLite format.""" return self @@ -92,8 +92,8 @@ class TfModel(ModelConfiguration): # pylint: disable=abstract-method """ def convert_to_tflite( - self, tflite_model_path: Union[str, Path], quantized: bool = False - ) -> "TFLiteModel": + self, tflite_model_path: str | Path, quantized: bool = False + ) -> TFLiteModel: """Convert model to TFLite format.""" converted_model = convert_tf_to_tflite(self.model_path, quantized) save_tflite_model(converted_model, tflite_model_path) @@ -101,7 +101,7 @@ class TfModel(ModelConfiguration): # pylint: disable=abstract-method return TFLiteModel(tflite_model_path) -def get_model(model: Union[Path, str]) -> "ModelConfiguration": +def get_model(model: str | Path) -> ModelConfiguration: """Return the model object.""" if is_tflite_model(model): return TFLiteModel(model) @@ -118,7 +118,7 @@ def get_model(model: Union[Path, str]) -> "ModelConfiguration": ) -def get_tflite_model(model: Union[str, Path], ctx: Context) -> "TFLiteModel": +def get_tflite_model(model: str | Path, ctx: Context) -> TFLiteModel: """Convert input model to TFLite and returns TFLiteModel object.""" tflite_model_path = ctx.get_model_path("converted_model.tflite") converted_model = get_model(model) @@ -126,7 +126,7 @@ def get_tflite_model(model: Union[str, Path], ctx: Context) -> "TFLiteModel": return converted_model.convert_to_tflite(tflite_model_path, True) -def get_keras_model(model: Union[str, Path], ctx: Context) -> "KerasModel": +def get_keras_model(model: str | Path, ctx: Context) -> KerasModel: """Convert input model to Keras and returns KerasModel object.""" keras_model_path = ctx.get_model_path("converted_model.h5") converted_model = get_model(model) diff --git a/src/mlia/nn/tensorflow/optimizations/clustering.py b/src/mlia/nn/tensorflow/optimizations/clustering.py index 16d9e4b..4aaa33e 100644 --- a/src/mlia/nn/tensorflow/optimizations/clustering.py +++ b/src/mlia/nn/tensorflow/optimizations/clustering.py @@ -7,11 +7,10 @@ In order to do this, we need to have a base model and corresponding training dat We also have to specify a subset of layers we want to cluster. For more details, please refer to the documentation for TensorFlow Model Optimization Toolkit. """ +from __future__ import annotations + from dataclasses import dataclass from typing import Any -from typing import Dict -from typing import List -from typing import Optional import tensorflow as tf import tensorflow_model_optimization as tfmot @@ -28,7 +27,7 @@ class ClusteringConfiguration(OptimizerConfiguration): """Clustering configuration.""" optimization_target: int - layers_to_optimize: Optional[List[str]] = None + layers_to_optimize: list[str] | None = None def __str__(self) -> str: """Return string representation of the configuration.""" @@ -61,7 +60,7 @@ class Clusterer(Optimizer): """Return string representation of the optimization config.""" return str(self.optimizer_configuration) - def _setup_clustering_params(self) -> Dict[str, Any]: + def _setup_clustering_params(self) -> dict[str, Any]: CentroidInitialization = tfmot.clustering.keras.CentroidInitialization return { "number_of_clusters": self.optimizer_configuration.optimization_target, diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py index 0a3fda5..41954b9 100644 --- a/src/mlia/nn/tensorflow/optimizations/pruning.py +++ b/src/mlia/nn/tensorflow/optimizations/pruning.py @@ -7,11 +7,10 @@ In order to do this, we need to have a base model and corresponding training dat We also have to specify a subset of layers we want to prune. For more details, please refer to the documentation for TensorFlow Model Optimization Toolkit. """ +from __future__ import annotations + import typing from dataclasses import dataclass -from typing import List -from typing import Optional -from typing import Tuple import numpy as np import tensorflow as tf @@ -29,9 +28,9 @@ class PruningConfiguration(OptimizerConfiguration): """Pruning configuration.""" optimization_target: float - layers_to_optimize: Optional[List[str]] = None - x_train: Optional[np.ndarray] = None - y_train: Optional[np.ndarray] = None + layers_to_optimize: list[str] | None = None + x_train: np.ndarray | None = None + y_train: np.ndarray | None = None batch_size: int = 1 num_epochs: int = 1 @@ -74,7 +73,7 @@ class Pruner(Optimizer): """Return string representation of the optimization config.""" return str(self.optimizer_configuration) - def _mock_train_data(self) -> Tuple[np.ndarray, np.ndarray]: + def _mock_train_data(self) -> tuple[np.ndarray, np.ndarray]: # get rid of the batch_size dimension in input and output shape input_shape = tuple(x for x in self.model.input_shape if x is not None) output_shape = tuple(x for x in self.model.output_shape if x is not None) diff --git a/src/mlia/nn/tensorflow/optimizations/select.py b/src/mlia/nn/tensorflow/optimizations/select.py index 1b0c755..d4a8ea4 100644 --- a/src/mlia/nn/tensorflow/optimizations/select.py +++ b/src/mlia/nn/tensorflow/optimizations/select.py @@ -1,12 +1,10 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for optimization selection.""" +from __future__ import annotations + import math -from typing import List from typing import NamedTuple -from typing import Optional -from typing import Tuple -from typing import Union import tensorflow as tf @@ -25,14 +23,14 @@ class OptimizationSettings(NamedTuple): """Optimization settings.""" optimization_type: str - optimization_target: Union[int, float] - layers_to_optimize: Optional[List[str]] + optimization_target: int | float + layers_to_optimize: list[str] | None @staticmethod def create_from( - optimizer_params: List[Tuple[str, float]], - layers_to_optimize: Optional[List[str]] = None, - ) -> List["OptimizationSettings"]: + optimizer_params: list[tuple[str, float]], + layers_to_optimize: list[str] | None = None, + ) -> list[OptimizationSettings]: """Create optimization settings from the provided parameters.""" return [ OptimizationSettings( @@ -47,7 +45,7 @@ class OptimizationSettings(NamedTuple): """Return string representation.""" return f"{self.optimization_type}: {self.optimization_target}" - def next_target(self) -> "OptimizationSettings": + def next_target(self) -> OptimizationSettings: """Return next optimization target.""" if self.optimization_type == "pruning": next_target = round(min(self.optimization_target + 0.1, 0.9), 2) @@ -75,7 +73,7 @@ class MultiStageOptimizer(Optimizer): def __init__( self, model: tf.keras.Model, - optimizations: List[OptimizerConfiguration], + optimizations: list[OptimizerConfiguration], ) -> None: """Init MultiStageOptimizer instance.""" self.model = model @@ -98,10 +96,8 @@ class MultiStageOptimizer(Optimizer): def get_optimizer( - model: Union[tf.keras.Model, KerasModel], - config: Union[ - OptimizerConfiguration, OptimizationSettings, List[OptimizationSettings] - ], + model: tf.keras.Model | KerasModel, + config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings], ) -> Optimizer: """Get optimizer for provided configuration.""" if isinstance(model, KerasModel): @@ -123,7 +119,7 @@ def get_optimizer( def _get_optimizer( model: tf.keras.Model, - optimization_settings: Union[OptimizationSettings, List[OptimizationSettings]], + optimization_settings: OptimizationSettings | list[OptimizationSettings], ) -> Optimizer: if isinstance(optimization_settings, OptimizationSettings): optimization_settings = [optimization_settings] @@ -145,8 +141,8 @@ def _get_optimizer( def _get_optimizer_configuration( optimization_type: str, - optimization_target: Union[int, float], - layers_to_optimize: Optional[List[str]] = None, + optimization_target: int | float, + layers_to_optimize: list[str] | None = None, ) -> OptimizerConfiguration: """Get optimizer configuration for provided parameters.""" _check_optimizer_params(optimization_type, optimization_target) @@ -169,7 +165,7 @@ def _get_optimizer_configuration( def _check_optimizer_params( - optimization_type: str, optimization_target: Union[int, float] + optimization_type: str, optimization_target: int | float ) -> None: """Check optimizer params.""" if not optimization_target: diff --git a/src/mlia/nn/tensorflow/tflite_metrics.py b/src/mlia/nn/tensorflow/tflite_metrics.py index 3f41487..0af7500 100644 --- a/src/mlia/nn/tensorflow/tflite_metrics.py +++ b/src/mlia/nn/tensorflow/tflite_metrics.py @@ -8,13 +8,13 @@ These metrics include: * Unique weights (clusters) (per layer) * gzip compression ratio """ +from __future__ import annotations + import os import typing from enum import Enum from pprint import pprint from typing import Any -from typing import List -from typing import Optional import numpy as np import tensorflow as tf @@ -42,7 +42,7 @@ def calculate_num_unique_weights(weights: np.ndarray) -> int: return num_unique_weights -def calculate_num_unique_weights_per_axis(weights: np.ndarray, axis: int) -> List[int]: +def calculate_num_unique_weights_per_axis(weights: np.ndarray, axis: int) -> list[int]: """Calculate unique weights per quantization axis.""" # Make quantized dimension the first dimension weights_trans = np.swapaxes(weights, 0, axis) @@ -74,7 +74,7 @@ class SparsityAccumulator: def calculate_sparsity( - weights: np.ndarray, accumulator: Optional[SparsityAccumulator] = None + weights: np.ndarray, accumulator: SparsityAccumulator | None = None ) -> float: """ Calculate the sparsity for the given weights. @@ -110,9 +110,7 @@ class TFLiteMetrics: * File compression via gzip """ - def __init__( - self, tflite_file: str, ignore_list: Optional[List[str]] = None - ) -> None: + def __init__(self, tflite_file: str, ignore_list: list[str] | None = None) -> None: """Load the TFLite file and filter layers.""" self.tflite_file = tflite_file if ignore_list is None: @@ -159,7 +157,7 @@ class TFLiteMetrics: acc(self.get_tensor(details)) return acc.sparsity() - def calc_num_clusters_per_axis(self, details: dict) -> List[int]: + def calc_num_clusters_per_axis(self, details: dict) -> list[int]: """Calculate number of clusters per axis.""" quant_params = details["quantization_parameters"] per_axis = len(quant_params["zero_points"]) > 1 @@ -178,14 +176,14 @@ class TFLiteMetrics: aggregation_func = self.calc_num_clusters_per_axis elif mode == ReportClusterMode.NUM_CLUSTERS_MIN_MAX: - def cluster_min_max(details: dict) -> List[int]: + def cluster_min_max(details: dict) -> list[int]: num_clusters = self.calc_num_clusters_per_axis(details) return [min(num_clusters), max(num_clusters)] aggregation_func = cluster_min_max elif mode == ReportClusterMode.NUM_CLUSTERS_HISTOGRAM: - def cluster_hist(details: dict) -> List[int]: + def cluster_hist(details: dict) -> list[int]: num_clusters = self.calc_num_clusters_per_axis(details) max_num = max(num_clusters) hist = [0] * (max_num) @@ -289,7 +287,7 @@ class TFLiteMetrics: print(f"- {self._prettify_name(name)}: {nums}") @staticmethod - def _print_in_outs(ios: List[dict], verbose: bool = False) -> None: + def _print_in_outs(ios: list[dict], verbose: bool = False) -> None: for item in ios: if verbose: pprint(item) diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py index b1034d9..6250f56 100644 --- a/src/mlia/nn/tensorflow/utils.py +++ b/src/mlia/nn/tensorflow/utils.py @@ -2,11 +2,12 @@ # SPDX-FileCopyrightText: Copyright The TensorFlow Authors. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 """Collection of useful functions for optimizations.""" +from __future__ import annotations + import logging from pathlib import Path from typing import Callable from typing import Iterable -from typing import Union import numpy as np import tensorflow as tf @@ -101,21 +102,19 @@ def convert_tf_to_tflite(model: str, quantized: bool = False) -> Interpreter: return tflite_model -def save_keras_model(model: tf.keras.Model, save_path: Union[str, Path]) -> None: +def save_keras_model(model: tf.keras.Model, save_path: str | Path) -> None: """Save Keras model at provided path.""" # Checkpoint: saving the optimizer is necessary. model.save(save_path, include_optimizer=True) -def save_tflite_model( - model: tf.lite.TFLiteConverter, save_path: Union[str, Path] -) -> None: +def save_tflite_model(model: tf.lite.TFLiteConverter, save_path: str | Path) -> None: """Save TFLite model at provided path.""" with open(save_path, "wb") as file: file.write(model) -def is_tflite_model(model: Union[Path, str]) -> bool: +def is_tflite_model(model: str | Path) -> bool: """Check if model type is supported by TFLite API. TFLite model is indicated by the model file extension .tflite @@ -124,7 +123,7 @@ def is_tflite_model(model: Union[Path, str]) -> bool: return model_path.suffix == ".tflite" -def is_keras_model(model: Union[Path, str]) -> bool: +def is_keras_model(model: str | Path) -> bool: """Check if model type is supported by Keras API. Keras model is indicated by: @@ -139,7 +138,7 @@ def is_keras_model(model: Union[Path, str]) -> bool: return model_path.suffix in (".h5", ".hdf5") -def is_tf_model(model: Union[Path, str]) -> bool: +def is_tf_model(model: str | Path) -> bool: """Check if model type is supported by TensorFlow API. TensorFlow model is indicated if its directory (meaning saved model) diff --git a/src/mlia/tools/metadata/common.py b/src/mlia/tools/metadata/common.py index 924e870..32da4a4 100644 --- a/src/mlia/tools/metadata/common.py +++ b/src/mlia/tools/metadata/common.py @@ -1,14 +1,14 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for installation process.""" +from __future__ import annotations + import logging from abc import ABC from abc import abstractmethod from dataclasses import dataclass from pathlib import Path from typing import Callable -from typing import List -from typing import Optional from typing import Union from mlia.utils.misc import yes @@ -100,7 +100,7 @@ class SupportsInstallTypeFilter: class SearchByNameFilter: """Filter installation by name.""" - def __init__(self, backend_name: Optional[str]) -> None: + def __init__(self, backend_name: str | None) -> None: """Init filter.""" self.backend_name = backend_name @@ -113,12 +113,12 @@ class InstallationManager(ABC): """Helper class for managing installations.""" @abstractmethod - def install_from(self, backend_path: Path, backend_name: Optional[str]) -> None: + def install_from(self, backend_path: Path, backend_name: str | None) -> None: """Install backend from the local directory.""" @abstractmethod def download_and_install( - self, backend_name: Optional[str], eula_agreement: bool + self, backend_name: str | None, eula_agreement: bool ) -> None: """Download and install backends.""" @@ -134,9 +134,9 @@ class InstallationManager(ABC): class InstallationFiltersMixin: """Mixin for filtering installation based on different conditions.""" - installations: List[Installation] + installations: list[Installation] - def filter_by(self, *filters: InstallationFilter) -> List[Installation]: + def filter_by(self, *filters: InstallationFilter) -> list[Installation]: """Filter installations.""" return [ installation @@ -145,8 +145,8 @@ class InstallationFiltersMixin: ] def could_be_installed_from( - self, backend_path: Path, backend_name: Optional[str] - ) -> List[Installation]: + self, backend_path: Path, backend_name: str | None + ) -> list[Installation]: """Return installations that could be installed from provided directory.""" return self.filter_by( SupportsInstallTypeFilter(InstallFromPath(backend_path)), @@ -154,8 +154,8 @@ class InstallationFiltersMixin: ) def could_be_downloaded_and_installed( - self, backend_name: Optional[str] = None - ) -> List[Installation]: + self, backend_name: str | None = None + ) -> list[Installation]: """Return installations that could be downloaded and installed.""" return self.filter_by( SupportsInstallTypeFilter(DownloadAndInstall()), @@ -163,15 +163,13 @@ class InstallationFiltersMixin: ReadyForInstallationFilter(), ) - def already_installed( - self, backend_name: Optional[str] = None - ) -> List[Installation]: + def already_installed(self, backend_name: str | None = None) -> list[Installation]: """Return list of backends that are already installed.""" return self.filter_by( AlreadyInstalledFilter(), SearchByNameFilter(backend_name) ) - def ready_for_installation(self) -> List[Installation]: + def ready_for_installation(self) -> list[Installation]: """Return list of the backends that could be installed.""" return self.filter_by(ReadyForInstallationFilter()) @@ -180,15 +178,15 @@ class DefaultInstallationManager(InstallationManager, InstallationFiltersMixin): """Interactive installation manager.""" def __init__( - self, installations: List[Installation], noninteractive: bool = False + self, installations: list[Installation], noninteractive: bool = False ) -> None: """Init the manager.""" self.installations = installations self.noninteractive = noninteractive def choose_installation_for_path( - self, backend_path: Path, backend_name: Optional[str] - ) -> Optional[Installation]: + self, backend_path: Path, backend_name: str | None + ) -> Installation | None: """Check available installation and select one if possible.""" installs = self.could_be_installed_from(backend_path, backend_name) @@ -220,7 +218,7 @@ class DefaultInstallationManager(InstallationManager, InstallationFiltersMixin): return installation - def install_from(self, backend_path: Path, backend_name: Optional[str]) -> None: + def install_from(self, backend_path: Path, backend_name: str | None) -> None: """Install from the provided directory.""" installation = self.choose_installation_for_path(backend_path, backend_name) @@ -234,7 +232,7 @@ class DefaultInstallationManager(InstallationManager, InstallationFiltersMixin): self._install(installation, InstallFromPath(backend_path), prompt) def download_and_install( - self, backend_name: Optional[str] = None, eula_agreement: bool = True + self, backend_name: str | None = None, eula_agreement: bool = True ) -> None: """Download and install available backends.""" installations = self.could_be_downloaded_and_installed(backend_name) @@ -269,7 +267,7 @@ class DefaultInstallationManager(InstallationManager, InstallationFiltersMixin): @staticmethod def _print_installation_list( - header: str, installations: List[Installation], new_section: bool = False + header: str, installations: list[Installation], new_section: bool = False ) -> None: """Print list of the installations.""" logger.info("%s%s\n", "\n" if new_section else "", header) diff --git a/src/mlia/tools/metadata/corstone.py b/src/mlia/tools/metadata/corstone.py index 023369c..feef7ad 100644 --- a/src/mlia/tools/metadata/corstone.py +++ b/src/mlia/tools/metadata/corstone.py @@ -6,6 +6,8 @@ The import of subprocess module raises a B404 bandit error. MLIA usage of subprocess is needed and can be considered safe hence disabling the security check. """ +from __future__ import annotations + import logging import platform import subprocess # nosec @@ -14,7 +16,6 @@ from dataclasses import dataclass from pathlib import Path from typing import Callable from typing import Iterable -from typing import List from typing import Optional import mlia.backend.manager as backend_manager @@ -40,7 +41,7 @@ class BackendInfo: backend_path: Path copy_source: bool = True - system_config: Optional[str] = None + system_config: str | None = None PathChecker = Callable[[Path], Optional[BackendInfo]] @@ -55,10 +56,10 @@ class BackendMetadata: name: str, description: str, system_config: str, - apps_resources: List[str], + apps_resources: list[str], fvp_dir_name: str, - download_artifact: Optional[DownloadArtifact], - supported_platforms: Optional[List[str]] = None, + download_artifact: DownloadArtifact | None, + supported_platforms: list[str] | None = None, ) -> None: """ Initialize BackendMetadata. @@ -100,7 +101,7 @@ class BackendInstallation(Installation): backend_runner: backend_manager.BackendRunner, metadata: BackendMetadata, path_checker: PathChecker, - backend_installer: Optional[BackendInstaller], + backend_installer: BackendInstaller | None, ) -> None: """Init the backend installation.""" self.backend_runner = backend_runner @@ -209,13 +210,13 @@ class PackagePathChecker: """Package path checker.""" def __init__( - self, expected_files: List[str], backend_subfolder: Optional[str] = None + self, expected_files: list[str], backend_subfolder: str | None = None ) -> None: """Init the path checker.""" self.expected_files = expected_files self.backend_subfolder = backend_subfolder - def __call__(self, backend_path: Path) -> Optional[BackendInfo]: + def __call__(self, backend_path: Path) -> BackendInfo | None: """Check if directory contains all expected files.""" resolved_paths = (backend_path / file for file in self.expected_files) if not all_files_exist(resolved_paths): @@ -238,9 +239,9 @@ class StaticPathChecker: def __init__( self, static_backend_path: Path, - expected_files: List[str], + expected_files: list[str], copy_source: bool = False, - system_config: Optional[str] = None, + system_config: str | None = None, ) -> None: """Init static path checker.""" self.static_backend_path = static_backend_path @@ -248,7 +249,7 @@ class StaticPathChecker: self.copy_source = copy_source self.system_config = system_config - def __call__(self, backend_path: Path) -> Optional[BackendInfo]: + def __call__(self, backend_path: Path) -> BackendInfo | None: """Check if directory equals static backend path with all expected files.""" if backend_path != self.static_backend_path: return None @@ -271,7 +272,7 @@ class CompoundPathChecker: """Init compound path checker.""" self.path_checkers = path_checkers - def __call__(self, backend_path: Path) -> Optional[BackendInfo]: + def __call__(self, backend_path: Path) -> BackendInfo | None: """Iterate over checkers and return first non empty backend info.""" first_resolved_backend_info = ( backend_info @@ -401,7 +402,7 @@ def get_corstone_310_installation() -> Installation: return corstone_310 -def get_corstone_installations() -> List[Installation]: +def get_corstone_installations() -> list[Installation]: """Get Corstone installations.""" return [ get_corstone_300_installation(), diff --git a/src/mlia/tools/vela_wrapper.py b/src/mlia/tools/vela_wrapper.py index 7225797..47c15e9 100644 --- a/src/mlia/tools/vela_wrapper.py +++ b/src/mlia/tools/vela_wrapper.py @@ -1,18 +1,15 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Vela wrapper module.""" +from __future__ import annotations + import itertools import logging import sys from dataclasses import dataclass from pathlib import Path from typing import Any -from typing import Dict -from typing import List from typing import Literal -from typing import Optional -from typing import Tuple -from typing import Union import numpy as np from ethosu.vela.architecture_features import ArchitectureFeatures @@ -70,7 +67,7 @@ class NpuSupported: """Operator's npu supported attribute.""" supported: bool - reasons: List[Tuple[str, str]] + reasons: list[tuple[str, str]] @dataclass @@ -95,7 +92,7 @@ class Operator: class Operators: """Model's operators.""" - ops: List[Operator] + ops: list[Operator] @property def npu_supported_ratio(self) -> float: @@ -150,7 +147,7 @@ class OptimizedModel: compiler_options: CompilerOptions scheduler_options: SchedulerOptions - def save(self, output_filename: Union[str, Path]) -> None: + def save(self, output_filename: str | Path) -> None: """Save instance of the optimized model to the file.""" write_tflite(self.nng, output_filename) @@ -173,16 +170,16 @@ OptimizationStrategyType = Literal["Performance", "Size"] class VelaCompilerOptions: # pylint: disable=too-many-instance-attributes """Vela compiler options.""" - config_files: Optional[Union[str, List[str]]] = None + config_files: str | list[str] | None = None system_config: str = ArchitectureFeatures.DEFAULT_CONFIG memory_mode: str = ArchitectureFeatures.DEFAULT_CONFIG - accelerator_config: Optional[AcceleratorConfigType] = None + accelerator_config: AcceleratorConfigType | None = None max_block_dependency: int = ArchitectureFeatures.MAX_BLOCKDEP - arena_cache_size: Optional[int] = None + arena_cache_size: int | None = None tensor_allocator: TensorAllocatorType = "HillClimb" cpu_tensor_alignment: int = Tensor.AllocationQuantum optimization_strategy: OptimizationStrategyType = "Performance" - output_dir: Optional[str] = None + output_dir: str | None = None recursion_limit: int = 1000 @@ -207,14 +204,14 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes sys.setrecursionlimit(self.recursion_limit) - def read_model(self, model: Union[str, Path]) -> Model: + def read_model(self, model: str | Path) -> Model: """Read model.""" logger.debug("Read model %s", model) nng, network_type = self._read_model(model) return Model(nng, network_type) - def compile_model(self, model: Union[str, Path, Model]) -> OptimizedModel: + def compile_model(self, model: str | Path | Model) -> OptimizedModel: """Compile the model.""" if isinstance(model, (str, Path)): nng, network_type = self._read_model(model) @@ -240,7 +237,7 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes except (SystemExit, Exception) as err: raise Exception("Model could not be optimized with Vela compiler") from err - def get_config(self) -> Dict[str, Any]: + def get_config(self) -> dict[str, Any]: """Get compiler configuration.""" arch = self._architecture_features() @@ -277,7 +274,7 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes } @staticmethod - def _read_model(model: Union[str, Path]) -> Tuple[Graph, NetworkType]: + def _read_model(model: str | Path) -> tuple[Graph, NetworkType]: """Read TFLite model.""" try: model_path = str(model) if isinstance(model, Path) else model @@ -334,7 +331,7 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes def resolve_compiler_config( vela_compiler_options: VelaCompilerOptions, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Resolve passed compiler options. Vela has number of configuration parameters that being @@ -397,7 +394,7 @@ def _performance_metrics(optimized_model: OptimizedModel) -> PerformanceMetrics: def memory_usage(mem_area: MemArea) -> int: """Get memory usage for the proviced memory area type.""" - memory_used: Dict[MemArea, int] = optimized_model.nng.memory_used + memory_used: dict[MemArea, int] = optimized_model.nng.memory_used bandwidths = optimized_model.nng.bandwidths return memory_used.get(mem_area, 0) if np.sum(bandwidths[mem_area]) > 0 else 0 diff --git a/src/mlia/utils/console.py b/src/mlia/utils/console.py index 7cb3d83..1f428a7 100644 --- a/src/mlia/utils/console.py +++ b/src/mlia/utils/console.py @@ -1,9 +1,9 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Console output utility functions.""" +from __future__ import annotations + from typing import Iterable -from typing import List -from typing import Optional from rich.console import Console from rich.console import RenderableType @@ -13,7 +13,7 @@ from rich.text import Text def create_section_header( - section_name: Optional[str] = None, length: int = 80, sep: str = "-" + section_name: str | None = None, length: int = 80, sep: str = "-" ) -> str: """Return section header.""" if not section_name: @@ -41,7 +41,7 @@ def style_improvement(result: bool) -> str: def produce_table( rows: Iterable, - headers: Optional[List[str]] = None, + headers: list[str] | None = None, table_style: str = "default", ) -> str: """Represent data in tabular form.""" diff --git a/src/mlia/utils/download.py b/src/mlia/utils/download.py index 4658738..9ef2d9e 100644 --- a/src/mlia/utils/download.py +++ b/src/mlia/utils/download.py @@ -1,11 +1,11 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Utils for files downloading.""" +from __future__ import annotations + from dataclasses import dataclass from pathlib import Path from typing import Iterable -from typing import List -from typing import Optional import requests from rich.progress import BarColumn @@ -20,10 +20,10 @@ from mlia.utils.types import parse_int def download_progress( - content_chunks: Iterable[bytes], content_length: Optional[int], label: Optional[str] + content_chunks: Iterable[bytes], content_length: int | None, label: str | None ) -> Iterable[bytes]: """Show progress info while reading content.""" - columns: List[ProgressColumn] = [TextColumn("{task.description}")] + columns: list[ProgressColumn] = [TextColumn("{task.description}")] if content_length is None: total = float("inf") @@ -44,7 +44,7 @@ def download( url: str, dest: Path, show_progress: bool = False, - label: Optional[str] = None, + label: str | None = None, chunk_size: int = 8192, ) -> None: """Download the file.""" diff --git a/src/mlia/utils/filesystem.py b/src/mlia/utils/filesystem.py index 0c28d35..25619c5 100644 --- a/src/mlia/utils/filesystem.py +++ b/src/mlia/utils/filesystem.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Utils related to file management.""" +from __future__ import annotations + import hashlib import importlib.resources as pkg_resources import json @@ -12,12 +14,8 @@ from tempfile import mkstemp from tempfile import TemporaryDirectory from typing import Any from typing import cast -from typing import Dict from typing import Generator from typing import Iterable -from typing import List -from typing import Optional -from typing import Union def get_mlia_resources() -> Path: @@ -37,7 +35,7 @@ def get_profiles_file() -> Path: return get_mlia_resources() / "profiles.json" -def get_profiles_data() -> Dict[str, Dict[str, Any]]: +def get_profiles_data() -> dict[str, dict[str, Any]]: """Get the profile values as a dictionary.""" with open(get_profiles_file(), encoding="utf-8") as json_file: profiles = json.load(json_file) @@ -48,7 +46,7 @@ def get_profiles_data() -> Dict[str, Dict[str, Any]]: return profiles -def get_profile(target_profile: str) -> Dict[str, Any]: +def get_profile(target_profile: str) -> dict[str, Any]: """Get settings for the provided target profile.""" if not target_profile: raise Exception("Target profile is not provided") @@ -61,7 +59,7 @@ def get_profile(target_profile: str) -> Dict[str, Any]: raise Exception(f"Unable to find target profile {target_profile}") from err -def get_supported_profile_names() -> List[str]: +def get_supported_profile_names() -> list[str]: """Get the supported Ethos-U profile names.""" return list(get_profiles_data().keys()) @@ -73,7 +71,7 @@ def get_target(target_profile: str) -> str: @contextmanager -def temp_file(suffix: Optional[str] = None) -> Generator[Path, None, None]: +def temp_file(suffix: str | None = None) -> Generator[Path, None, None]: """Create temp file and remove it after.""" _, tmp_file = mkstemp(suffix=suffix) @@ -84,14 +82,14 @@ def temp_file(suffix: Optional[str] = None) -> Generator[Path, None, None]: @contextmanager -def temp_directory(suffix: Optional[str] = None) -> Generator[Path, None, None]: +def temp_directory(suffix: str | None = None) -> Generator[Path, None, None]: """Create temp directory and remove it after.""" with TemporaryDirectory(suffix=suffix) as tmpdir: yield Path(tmpdir) def file_chunks( - filepath: Union[Path, str], chunk_size: int = 4096 + filepath: str | Path, chunk_size: int = 4096 ) -> Generator[bytes, None, None]: """Return sequence of the file chunks.""" with open(filepath, "rb") as file: @@ -99,7 +97,7 @@ def file_chunks( yield data -def hexdigest(filepath: Union[Path, str], hash_obj: "hashlib._Hash") -> str: +def hexdigest(filepath: str | Path, hash_obj: "hashlib._Hash") -> str: """Return hex digest of the file.""" for chunk in file_chunks(filepath): hash_obj.update(chunk) diff --git a/src/mlia/utils/logging.py b/src/mlia/utils/logging.py index 86d7567..793500a 100644 --- a/src/mlia/utils/logging.py +++ b/src/mlia/utils/logging.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Logging utility functions.""" +from __future__ import annotations + import logging from contextlib import contextmanager from contextlib import ExitStack @@ -10,8 +12,6 @@ from pathlib import Path from typing import Any from typing import Callable from typing import Generator -from typing import List -from typing import Optional class LoggerWriter: @@ -61,7 +61,7 @@ class LogFilter(logging.Filter): return self.log_record_filter(record) @classmethod - def equals(cls, log_level: int) -> "LogFilter": + def equals(cls, log_level: int) -> LogFilter: """Return log filter that filters messages by log level.""" def filter_by_level(log_record: logging.LogRecord) -> bool: @@ -70,7 +70,7 @@ class LogFilter(logging.Filter): return cls(filter_by_level) @classmethod - def skip(cls, log_level: int) -> "LogFilter": + def skip(cls, log_level: int) -> LogFilter: """Return log filter that skips messages with particular level.""" def skip_by_level(log_record: logging.LogRecord) -> bool: @@ -81,15 +81,15 @@ class LogFilter(logging.Filter): def create_log_handler( *, - file_path: Optional[Path] = None, - stream: Optional[Any] = None, - log_level: Optional[int] = None, - log_format: Optional[str] = None, - log_filter: Optional[logging.Filter] = None, + file_path: Path | None = None, + stream: Any | None = None, + log_level: int | None = None, + log_format: str | None = None, + log_filter: logging.Filter | None = None, delay: bool = True, ) -> logging.Handler: """Create logger handler.""" - handler: Optional[logging.Handler] = None + handler: logging.Handler | None = None if file_path is not None: handler = logging.FileHandler(file_path, delay=delay) @@ -112,7 +112,7 @@ def create_log_handler( def attach_handlers( - handlers: List[logging.Handler], loggers: List[logging.Logger] + handlers: list[logging.Handler], loggers: list[logging.Logger] ) -> None: """Attach handlers to the loggers.""" for handler in handlers: diff --git a/src/mlia/utils/types.py b/src/mlia/utils/types.py index 9b63928..ea067b8 100644 --- a/src/mlia/utils/types.py +++ b/src/mlia/utils/types.py @@ -1,11 +1,12 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Types related utility functions.""" +from __future__ import annotations + from typing import Any -from typing import Optional -def is_list_of(data: Any, cls: type, elem_num: Optional[int] = None) -> bool: +def is_list_of(data: Any, cls: type, elem_num: int | None = None) -> bool: """Check if data is a list of object of the same class.""" return ( isinstance(data, (tuple, list)) @@ -24,7 +25,7 @@ def is_number(value: str) -> bool: return True -def parse_int(value: Any, default: Optional[int] = None) -> Optional[int]: +def parse_int(value: Any, default: int | None = None) -> int | None: """Parse integer value.""" try: return int(value) -- cgit v1.2.1