diff options
author | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2022-09-08 14:24:39 +0100 |
---|---|---|
committer | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2022-09-09 17:21:48 +0100 |
commit | f5b293d0927506c2a979a091bf0d07ecc78fa181 (patch) | |
tree | 4de585b7cb6ed34da8237063752270189a730a41 /src/mlia/tools/vela_wrapper.py | |
parent | cde0c6ee140bd108849bff40467d8f18ffc332ef (diff) | |
download | mlia-f5b293d0927506c2a979a091bf0d07ecc78fa181.tar.gz |
MLIA-386 Simplify typing in the source code
- Enable deferred annotations evaluation
- Use builtin types for type hints whenever possible
- Use | syntax for union types
- Rename mlia.core._typing into mlia.core.typing
Change-Id: I3f6ffc02fa069c589bdd9e8bddbccd504285427a
Diffstat (limited to 'src/mlia/tools/vela_wrapper.py')
-rw-r--r-- | src/mlia/tools/vela_wrapper.py | 33 |
1 files changed, 15 insertions, 18 deletions
diff --git a/src/mlia/tools/vela_wrapper.py b/src/mlia/tools/vela_wrapper.py index 7225797..47c15e9 100644 --- a/src/mlia/tools/vela_wrapper.py +++ b/src/mlia/tools/vela_wrapper.py @@ -1,18 +1,15 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Vela wrapper module.""" +from __future__ import annotations + import itertools import logging import sys from dataclasses import dataclass from pathlib import Path from typing import Any -from typing import Dict -from typing import List from typing import Literal -from typing import Optional -from typing import Tuple -from typing import Union import numpy as np from ethosu.vela.architecture_features import ArchitectureFeatures @@ -70,7 +67,7 @@ class NpuSupported: """Operator's npu supported attribute.""" supported: bool - reasons: List[Tuple[str, str]] + reasons: list[tuple[str, str]] @dataclass @@ -95,7 +92,7 @@ class Operator: class Operators: """Model's operators.""" - ops: List[Operator] + ops: list[Operator] @property def npu_supported_ratio(self) -> float: @@ -150,7 +147,7 @@ class OptimizedModel: compiler_options: CompilerOptions scheduler_options: SchedulerOptions - def save(self, output_filename: Union[str, Path]) -> None: + def save(self, output_filename: str | Path) -> None: """Save instance of the optimized model to the file.""" write_tflite(self.nng, output_filename) @@ -173,16 +170,16 @@ OptimizationStrategyType = Literal["Performance", "Size"] class VelaCompilerOptions: # pylint: disable=too-many-instance-attributes """Vela compiler options.""" - config_files: Optional[Union[str, List[str]]] = None + config_files: str | list[str] | None = None system_config: str = ArchitectureFeatures.DEFAULT_CONFIG memory_mode: str = ArchitectureFeatures.DEFAULT_CONFIG - accelerator_config: Optional[AcceleratorConfigType] = None + accelerator_config: AcceleratorConfigType | None = None max_block_dependency: int = ArchitectureFeatures.MAX_BLOCKDEP - arena_cache_size: Optional[int] = None + arena_cache_size: int | None = None tensor_allocator: TensorAllocatorType = "HillClimb" cpu_tensor_alignment: int = Tensor.AllocationQuantum optimization_strategy: OptimizationStrategyType = "Performance" - output_dir: Optional[str] = None + output_dir: str | None = None recursion_limit: int = 1000 @@ -207,14 +204,14 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes sys.setrecursionlimit(self.recursion_limit) - def read_model(self, model: Union[str, Path]) -> Model: + def read_model(self, model: str | Path) -> Model: """Read model.""" logger.debug("Read model %s", model) nng, network_type = self._read_model(model) return Model(nng, network_type) - def compile_model(self, model: Union[str, Path, Model]) -> OptimizedModel: + def compile_model(self, model: str | Path | Model) -> OptimizedModel: """Compile the model.""" if isinstance(model, (str, Path)): nng, network_type = self._read_model(model) @@ -240,7 +237,7 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes except (SystemExit, Exception) as err: raise Exception("Model could not be optimized with Vela compiler") from err - def get_config(self) -> Dict[str, Any]: + def get_config(self) -> dict[str, Any]: """Get compiler configuration.""" arch = self._architecture_features() @@ -277,7 +274,7 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes } @staticmethod - def _read_model(model: Union[str, Path]) -> Tuple[Graph, NetworkType]: + def _read_model(model: str | Path) -> tuple[Graph, NetworkType]: """Read TFLite model.""" try: model_path = str(model) if isinstance(model, Path) else model @@ -334,7 +331,7 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes def resolve_compiler_config( vela_compiler_options: VelaCompilerOptions, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Resolve passed compiler options. Vela has number of configuration parameters that being @@ -397,7 +394,7 @@ def _performance_metrics(optimized_model: OptimizedModel) -> PerformanceMetrics: def memory_usage(mem_area: MemArea) -> int: """Get memory usage for the proviced memory area type.""" - memory_used: Dict[MemArea, int] = optimized_model.nng.memory_used + memory_used: dict[MemArea, int] = optimized_model.nng.memory_used bandwidths = optimized_model.nng.bandwidths return memory_used.get(mem_area, 0) if np.sum(bandwidths[mem_area]) > 0 else 0 |