aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/tools
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2022-09-08 14:24:39 +0100
committerDmitrii Agibov <dmitrii.agibov@arm.com>2022-09-09 17:21:48 +0100
commitf5b293d0927506c2a979a091bf0d07ecc78fa181 (patch)
tree4de585b7cb6ed34da8237063752270189a730a41 /src/mlia/tools
parentcde0c6ee140bd108849bff40467d8f18ffc332ef (diff)
downloadmlia-f5b293d0927506c2a979a091bf0d07ecc78fa181.tar.gz
MLIA-386 Simplify typing in the source code
- Enable deferred annotations evaluation - Use builtin types for type hints whenever possible - Use | syntax for union types - Rename mlia.core._typing into mlia.core.typing Change-Id: I3f6ffc02fa069c589bdd9e8bddbccd504285427a
Diffstat (limited to 'src/mlia/tools')
-rw-r--r--src/mlia/tools/metadata/common.py40
-rw-r--r--src/mlia/tools/metadata/corstone.py27
-rw-r--r--src/mlia/tools/vela_wrapper.py33
3 files changed, 48 insertions, 52 deletions
diff --git a/src/mlia/tools/metadata/common.py b/src/mlia/tools/metadata/common.py
index 924e870..32da4a4 100644
--- a/src/mlia/tools/metadata/common.py
+++ b/src/mlia/tools/metadata/common.py
@@ -1,14 +1,14 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for installation process."""
+from __future__ import annotations
+
import logging
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Callable
-from typing import List
-from typing import Optional
from typing import Union
from mlia.utils.misc import yes
@@ -100,7 +100,7 @@ class SupportsInstallTypeFilter:
class SearchByNameFilter:
"""Filter installation by name."""
- def __init__(self, backend_name: Optional[str]) -> None:
+ def __init__(self, backend_name: str | None) -> None:
"""Init filter."""
self.backend_name = backend_name
@@ -113,12 +113,12 @@ class InstallationManager(ABC):
"""Helper class for managing installations."""
@abstractmethod
- def install_from(self, backend_path: Path, backend_name: Optional[str]) -> None:
+ def install_from(self, backend_path: Path, backend_name: str | None) -> None:
"""Install backend from the local directory."""
@abstractmethod
def download_and_install(
- self, backend_name: Optional[str], eula_agreement: bool
+ self, backend_name: str | None, eula_agreement: bool
) -> None:
"""Download and install backends."""
@@ -134,9 +134,9 @@ class InstallationManager(ABC):
class InstallationFiltersMixin:
"""Mixin for filtering installation based on different conditions."""
- installations: List[Installation]
+ installations: list[Installation]
- def filter_by(self, *filters: InstallationFilter) -> List[Installation]:
+ def filter_by(self, *filters: InstallationFilter) -> list[Installation]:
"""Filter installations."""
return [
installation
@@ -145,8 +145,8 @@ class InstallationFiltersMixin:
]
def could_be_installed_from(
- self, backend_path: Path, backend_name: Optional[str]
- ) -> List[Installation]:
+ self, backend_path: Path, backend_name: str | None
+ ) -> list[Installation]:
"""Return installations that could be installed from provided directory."""
return self.filter_by(
SupportsInstallTypeFilter(InstallFromPath(backend_path)),
@@ -154,8 +154,8 @@ class InstallationFiltersMixin:
)
def could_be_downloaded_and_installed(
- self, backend_name: Optional[str] = None
- ) -> List[Installation]:
+ self, backend_name: str | None = None
+ ) -> list[Installation]:
"""Return installations that could be downloaded and installed."""
return self.filter_by(
SupportsInstallTypeFilter(DownloadAndInstall()),
@@ -163,15 +163,13 @@ class InstallationFiltersMixin:
ReadyForInstallationFilter(),
)
- def already_installed(
- self, backend_name: Optional[str] = None
- ) -> List[Installation]:
+ def already_installed(self, backend_name: str | None = None) -> list[Installation]:
"""Return list of backends that are already installed."""
return self.filter_by(
AlreadyInstalledFilter(), SearchByNameFilter(backend_name)
)
- def ready_for_installation(self) -> List[Installation]:
+ def ready_for_installation(self) -> list[Installation]:
"""Return list of the backends that could be installed."""
return self.filter_by(ReadyForInstallationFilter())
@@ -180,15 +178,15 @@ class DefaultInstallationManager(InstallationManager, InstallationFiltersMixin):
"""Interactive installation manager."""
def __init__(
- self, installations: List[Installation], noninteractive: bool = False
+ self, installations: list[Installation], noninteractive: bool = False
) -> None:
"""Init the manager."""
self.installations = installations
self.noninteractive = noninteractive
def choose_installation_for_path(
- self, backend_path: Path, backend_name: Optional[str]
- ) -> Optional[Installation]:
+ self, backend_path: Path, backend_name: str | None
+ ) -> Installation | None:
"""Check available installation and select one if possible."""
installs = self.could_be_installed_from(backend_path, backend_name)
@@ -220,7 +218,7 @@ class DefaultInstallationManager(InstallationManager, InstallationFiltersMixin):
return installation
- def install_from(self, backend_path: Path, backend_name: Optional[str]) -> None:
+ def install_from(self, backend_path: Path, backend_name: str | None) -> None:
"""Install from the provided directory."""
installation = self.choose_installation_for_path(backend_path, backend_name)
@@ -234,7 +232,7 @@ class DefaultInstallationManager(InstallationManager, InstallationFiltersMixin):
self._install(installation, InstallFromPath(backend_path), prompt)
def download_and_install(
- self, backend_name: Optional[str] = None, eula_agreement: bool = True
+ self, backend_name: str | None = None, eula_agreement: bool = True
) -> None:
"""Download and install available backends."""
installations = self.could_be_downloaded_and_installed(backend_name)
@@ -269,7 +267,7 @@ class DefaultInstallationManager(InstallationManager, InstallationFiltersMixin):
@staticmethod
def _print_installation_list(
- header: str, installations: List[Installation], new_section: bool = False
+ header: str, installations: list[Installation], new_section: bool = False
) -> None:
"""Print list of the installations."""
logger.info("%s%s\n", "\n" if new_section else "", header)
diff --git a/src/mlia/tools/metadata/corstone.py b/src/mlia/tools/metadata/corstone.py
index 023369c..feef7ad 100644
--- a/src/mlia/tools/metadata/corstone.py
+++ b/src/mlia/tools/metadata/corstone.py
@@ -6,6 +6,8 @@ The import of subprocess module raises a B404 bandit error. MLIA usage of
subprocess is needed and can be considered safe hence disabling the security
check.
"""
+from __future__ import annotations
+
import logging
import platform
import subprocess # nosec
@@ -14,7 +16,6 @@ from dataclasses import dataclass
from pathlib import Path
from typing import Callable
from typing import Iterable
-from typing import List
from typing import Optional
import mlia.backend.manager as backend_manager
@@ -40,7 +41,7 @@ class BackendInfo:
backend_path: Path
copy_source: bool = True
- system_config: Optional[str] = None
+ system_config: str | None = None
PathChecker = Callable[[Path], Optional[BackendInfo]]
@@ -55,10 +56,10 @@ class BackendMetadata:
name: str,
description: str,
system_config: str,
- apps_resources: List[str],
+ apps_resources: list[str],
fvp_dir_name: str,
- download_artifact: Optional[DownloadArtifact],
- supported_platforms: Optional[List[str]] = None,
+ download_artifact: DownloadArtifact | None,
+ supported_platforms: list[str] | None = None,
) -> None:
"""
Initialize BackendMetadata.
@@ -100,7 +101,7 @@ class BackendInstallation(Installation):
backend_runner: backend_manager.BackendRunner,
metadata: BackendMetadata,
path_checker: PathChecker,
- backend_installer: Optional[BackendInstaller],
+ backend_installer: BackendInstaller | None,
) -> None:
"""Init the backend installation."""
self.backend_runner = backend_runner
@@ -209,13 +210,13 @@ class PackagePathChecker:
"""Package path checker."""
def __init__(
- self, expected_files: List[str], backend_subfolder: Optional[str] = None
+ self, expected_files: list[str], backend_subfolder: str | None = None
) -> None:
"""Init the path checker."""
self.expected_files = expected_files
self.backend_subfolder = backend_subfolder
- def __call__(self, backend_path: Path) -> Optional[BackendInfo]:
+ def __call__(self, backend_path: Path) -> BackendInfo | None:
"""Check if directory contains all expected files."""
resolved_paths = (backend_path / file for file in self.expected_files)
if not all_files_exist(resolved_paths):
@@ -238,9 +239,9 @@ class StaticPathChecker:
def __init__(
self,
static_backend_path: Path,
- expected_files: List[str],
+ expected_files: list[str],
copy_source: bool = False,
- system_config: Optional[str] = None,
+ system_config: str | None = None,
) -> None:
"""Init static path checker."""
self.static_backend_path = static_backend_path
@@ -248,7 +249,7 @@ class StaticPathChecker:
self.copy_source = copy_source
self.system_config = system_config
- def __call__(self, backend_path: Path) -> Optional[BackendInfo]:
+ def __call__(self, backend_path: Path) -> BackendInfo | None:
"""Check if directory equals static backend path with all expected files."""
if backend_path != self.static_backend_path:
return None
@@ -271,7 +272,7 @@ class CompoundPathChecker:
"""Init compound path checker."""
self.path_checkers = path_checkers
- def __call__(self, backend_path: Path) -> Optional[BackendInfo]:
+ def __call__(self, backend_path: Path) -> BackendInfo | None:
"""Iterate over checkers and return first non empty backend info."""
first_resolved_backend_info = (
backend_info
@@ -401,7 +402,7 @@ def get_corstone_310_installation() -> Installation:
return corstone_310
-def get_corstone_installations() -> List[Installation]:
+def get_corstone_installations() -> list[Installation]:
"""Get Corstone installations."""
return [
get_corstone_300_installation(),
diff --git a/src/mlia/tools/vela_wrapper.py b/src/mlia/tools/vela_wrapper.py
index 7225797..47c15e9 100644
--- a/src/mlia/tools/vela_wrapper.py
+++ b/src/mlia/tools/vela_wrapper.py
@@ -1,18 +1,15 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Vela wrapper module."""
+from __future__ import annotations
+
import itertools
import logging
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any
-from typing import Dict
-from typing import List
from typing import Literal
-from typing import Optional
-from typing import Tuple
-from typing import Union
import numpy as np
from ethosu.vela.architecture_features import ArchitectureFeatures
@@ -70,7 +67,7 @@ class NpuSupported:
"""Operator's npu supported attribute."""
supported: bool
- reasons: List[Tuple[str, str]]
+ reasons: list[tuple[str, str]]
@dataclass
@@ -95,7 +92,7 @@ class Operator:
class Operators:
"""Model's operators."""
- ops: List[Operator]
+ ops: list[Operator]
@property
def npu_supported_ratio(self) -> float:
@@ -150,7 +147,7 @@ class OptimizedModel:
compiler_options: CompilerOptions
scheduler_options: SchedulerOptions
- def save(self, output_filename: Union[str, Path]) -> None:
+ def save(self, output_filename: str | Path) -> None:
"""Save instance of the optimized model to the file."""
write_tflite(self.nng, output_filename)
@@ -173,16 +170,16 @@ OptimizationStrategyType = Literal["Performance", "Size"]
class VelaCompilerOptions: # pylint: disable=too-many-instance-attributes
"""Vela compiler options."""
- config_files: Optional[Union[str, List[str]]] = None
+ config_files: str | list[str] | None = None
system_config: str = ArchitectureFeatures.DEFAULT_CONFIG
memory_mode: str = ArchitectureFeatures.DEFAULT_CONFIG
- accelerator_config: Optional[AcceleratorConfigType] = None
+ accelerator_config: AcceleratorConfigType | None = None
max_block_dependency: int = ArchitectureFeatures.MAX_BLOCKDEP
- arena_cache_size: Optional[int] = None
+ arena_cache_size: int | None = None
tensor_allocator: TensorAllocatorType = "HillClimb"
cpu_tensor_alignment: int = Tensor.AllocationQuantum
optimization_strategy: OptimizationStrategyType = "Performance"
- output_dir: Optional[str] = None
+ output_dir: str | None = None
recursion_limit: int = 1000
@@ -207,14 +204,14 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes
sys.setrecursionlimit(self.recursion_limit)
- def read_model(self, model: Union[str, Path]) -> Model:
+ def read_model(self, model: str | Path) -> Model:
"""Read model."""
logger.debug("Read model %s", model)
nng, network_type = self._read_model(model)
return Model(nng, network_type)
- def compile_model(self, model: Union[str, Path, Model]) -> OptimizedModel:
+ def compile_model(self, model: str | Path | Model) -> OptimizedModel:
"""Compile the model."""
if isinstance(model, (str, Path)):
nng, network_type = self._read_model(model)
@@ -240,7 +237,7 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes
except (SystemExit, Exception) as err:
raise Exception("Model could not be optimized with Vela compiler") from err
- def get_config(self) -> Dict[str, Any]:
+ def get_config(self) -> dict[str, Any]:
"""Get compiler configuration."""
arch = self._architecture_features()
@@ -277,7 +274,7 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes
}
@staticmethod
- def _read_model(model: Union[str, Path]) -> Tuple[Graph, NetworkType]:
+ def _read_model(model: str | Path) -> tuple[Graph, NetworkType]:
"""Read TFLite model."""
try:
model_path = str(model) if isinstance(model, Path) else model
@@ -334,7 +331,7 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes
def resolve_compiler_config(
vela_compiler_options: VelaCompilerOptions,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
"""Resolve passed compiler options.
Vela has number of configuration parameters that being
@@ -397,7 +394,7 @@ def _performance_metrics(optimized_model: OptimizedModel) -> PerformanceMetrics:
def memory_usage(mem_area: MemArea) -> int:
"""Get memory usage for the proviced memory area type."""
- memory_used: Dict[MemArea, int] = optimized_model.nng.memory_used
+ memory_used: dict[MemArea, int] = optimized_model.nng.memory_used
bandwidths = optimized_model.nng.bandwidths
return memory_used.get(mem_area, 0) if np.sum(bandwidths[mem_area]) > 0 else 0