author     Dmitrii Agibov <dmitrii.agibov@arm.com>    2022-09-08 14:24:39 +0100
committer  Dmitrii Agibov <dmitrii.agibov@arm.com>    2022-09-09 17:21:48 +0100
commit     f5b293d0927506c2a979a091bf0d07ecc78fa181 (patch)
tree       4de585b7cb6ed34da8237063752270189a730a41 /src/mlia/tools/vela_wrapper.py
parent     cde0c6ee140bd108849bff40467d8f18ffc332ef (diff)
download   mlia-f5b293d0927506c2a979a091bf0d07ecc78fa181.tar.gz
MLIA-386 Simplify typing in the source code
- Enable deferred annotations evaluation
- Use builtin types for type hints whenever possible
- Use | syntax for union types
- Rename mlia.core._typing into mlia.core.typing

Change-Id: I3f6ffc02fa069c589bdd9e8bddbccd504285427a
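For context, a minimal sketch of the typing style this commit adopts (the class and function below are hypothetical and not part of MLIA): with deferred annotation evaluation enabled via `from __future__ import annotations` (PEP 563), builtin generics such as `list`/`dict`/`tuple` (PEP 585) and the `|` union syntax (PEP 604) can be written in annotations even on Python versions that predate their runtime support, because annotations are stored as strings instead of being evaluated at import time.

    # Hypothetical example, not taken from the MLIA codebase.
    from __future__ import annotations  # defer annotation evaluation (PEP 563)

    from dataclasses import dataclass, field
    from pathlib import Path


    @dataclass
    class ExampleOptions:
        """Illustrative options container using the new-style annotations."""

        # Previously: Optional[Union[str, List[str]]]
        config_files: str | list[str] | None = None
        # Previously: Dict[str, int]
        limits: dict[str, int] = field(default_factory=dict)


    def read_model(model: str | Path) -> str:
        """Previously annotated as Union[str, Path]."""
        return str(model)

This is also why the diff below drops the `Dict`, `List`, `Optional`, `Tuple` and `Union` imports from `typing` while keeping `Any` and `Literal`, which still have to come from `typing`.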
Diffstat (limited to 'src/mlia/tools/vela_wrapper.py')
-rw-r--r--  src/mlia/tools/vela_wrapper.py  33
1 file changed, 15 insertions, 18 deletions
diff --git a/src/mlia/tools/vela_wrapper.py b/src/mlia/tools/vela_wrapper.py
index 7225797..47c15e9 100644
--- a/src/mlia/tools/vela_wrapper.py
+++ b/src/mlia/tools/vela_wrapper.py
@@ -1,18 +1,15 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Vela wrapper module."""
+from __future__ import annotations
+
import itertools
import logging
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any
-from typing import Dict
-from typing import List
from typing import Literal
-from typing import Optional
-from typing import Tuple
-from typing import Union
import numpy as np
from ethosu.vela.architecture_features import ArchitectureFeatures
@@ -70,7 +67,7 @@ class NpuSupported:
"""Operator's npu supported attribute."""
supported: bool
- reasons: List[Tuple[str, str]]
+ reasons: list[tuple[str, str]]
@dataclass
@@ -95,7 +92,7 @@ class Operator:
class Operators:
"""Model's operators."""
- ops: List[Operator]
+ ops: list[Operator]
@property
def npu_supported_ratio(self) -> float:
@@ -150,7 +147,7 @@ class OptimizedModel:
compiler_options: CompilerOptions
scheduler_options: SchedulerOptions
- def save(self, output_filename: Union[str, Path]) -> None:
+ def save(self, output_filename: str | Path) -> None:
"""Save instance of the optimized model to the file."""
write_tflite(self.nng, output_filename)
@@ -173,16 +170,16 @@ OptimizationStrategyType = Literal["Performance", "Size"]
class VelaCompilerOptions: # pylint: disable=too-many-instance-attributes
"""Vela compiler options."""
- config_files: Optional[Union[str, List[str]]] = None
+ config_files: str | list[str] | None = None
system_config: str = ArchitectureFeatures.DEFAULT_CONFIG
memory_mode: str = ArchitectureFeatures.DEFAULT_CONFIG
- accelerator_config: Optional[AcceleratorConfigType] = None
+ accelerator_config: AcceleratorConfigType | None = None
max_block_dependency: int = ArchitectureFeatures.MAX_BLOCKDEP
- arena_cache_size: Optional[int] = None
+ arena_cache_size: int | None = None
tensor_allocator: TensorAllocatorType = "HillClimb"
cpu_tensor_alignment: int = Tensor.AllocationQuantum
optimization_strategy: OptimizationStrategyType = "Performance"
- output_dir: Optional[str] = None
+ output_dir: str | None = None
recursion_limit: int = 1000
@@ -207,14 +204,14 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes
sys.setrecursionlimit(self.recursion_limit)
- def read_model(self, model: Union[str, Path]) -> Model:
+ def read_model(self, model: str | Path) -> Model:
"""Read model."""
logger.debug("Read model %s", model)
nng, network_type = self._read_model(model)
return Model(nng, network_type)
- def compile_model(self, model: Union[str, Path, Model]) -> OptimizedModel:
+ def compile_model(self, model: str | Path | Model) -> OptimizedModel:
"""Compile the model."""
if isinstance(model, (str, Path)):
nng, network_type = self._read_model(model)
@@ -240,7 +237,7 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes
except (SystemExit, Exception) as err:
raise Exception("Model could not be optimized with Vela compiler") from err
- def get_config(self) -> Dict[str, Any]:
+ def get_config(self) -> dict[str, Any]:
"""Get compiler configuration."""
arch = self._architecture_features()
@@ -277,7 +274,7 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes
}
@staticmethod
- def _read_model(model: Union[str, Path]) -> Tuple[Graph, NetworkType]:
+ def _read_model(model: str | Path) -> tuple[Graph, NetworkType]:
"""Read TFLite model."""
try:
model_path = str(model) if isinstance(model, Path) else model
@@ -334,7 +331,7 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes
def resolve_compiler_config(
vela_compiler_options: VelaCompilerOptions,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
"""Resolve passed compiler options.
Vela has number of configuration parameters that being
@@ -397,7 +394,7 @@ def _performance_metrics(optimized_model: OptimizedModel) -> PerformanceMetrics:
def memory_usage(mem_area: MemArea) -> int:
"""Get memory usage for the proviced memory area type."""
- memory_used: Dict[MemArea, int] = optimized_model.nng.memory_used
+ memory_used: dict[MemArea, int] = optimized_model.nng.memory_used
bandwidths = optimized_model.nng.bandwidths
return memory_used.get(mem_area, 0) if np.sum(bandwidths[mem_area]) > 0 else 0