aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/backend/manager.py
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2022-09-08 14:24:39 +0100
committerDmitrii Agibov <dmitrii.agibov@arm.com>2022-09-09 17:21:48 +0100
commitf5b293d0927506c2a979a091bf0d07ecc78fa181 (patch)
tree4de585b7cb6ed34da8237063752270189a730a41 /src/mlia/backend/manager.py
parentcde0c6ee140bd108849bff40467d8f18ffc332ef (diff)
downloadmlia-f5b293d0927506c2a979a091bf0d07ecc78fa181.tar.gz
MLIA-386 Simplify typing in the source code
- Enable deferred annotations evaluation - Use builtin types for type hints whenever possible - Use | syntax for union types - Rename mlia.core._typing into mlia.core.typing Change-Id: I3f6ffc02fa069c589bdd9e8bddbccd504285427a
Diffstat (limited to 'src/mlia/backend/manager.py')
-rw-r--r--src/mlia/backend/manager.py41
1 files changed, 19 insertions, 22 deletions
diff --git a/src/mlia/backend/manager.py b/src/mlia/backend/manager.py
index 8d8246d..c8fe0f7 100644
--- a/src/mlia/backend/manager.py
+++ b/src/mlia/backend/manager.py
@@ -1,17 +1,14 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for backend integration."""
+from __future__ import annotations
+
import logging
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from pathlib import Path
-from typing import Dict
-from typing import List
from typing import Literal
-from typing import Optional
-from typing import Set
-from typing import Tuple
from mlia.backend.application import get_available_applications
from mlia.backend.application import install_application
@@ -58,7 +55,7 @@ def get_system_name(backend: str, device_type: str) -> str:
return _SUPPORTED_SYSTEMS[backend][device_type]
-def is_supported(backend: str, device_type: Optional[str] = None) -> bool:
+def is_supported(backend: str, device_type: str | None = None) -> bool:
"""Check if the backend (and optionally device type) is supported."""
if device_type is None:
return backend in _SUPPORTED_SYSTEMS
@@ -70,17 +67,17 @@ def is_supported(backend: str, device_type: Optional[str] = None) -> bool:
return False
-def supported_backends() -> List[str]:
+def supported_backends() -> list[str]:
"""Get a list of all backends supported by the backend manager."""
return list(_SUPPORTED_SYSTEMS.keys())
-def get_all_system_names(backend: str) -> List[str]:
+def get_all_system_names(backend: str) -> list[str]:
"""Get all systems supported by the backend."""
return list(_SUPPORTED_SYSTEMS.get(backend, {}).values())
-def get_all_application_names(backend: str) -> List[str]:
+def get_all_application_names(backend: str) -> list[str]:
"""Get all applications supported by the backend."""
app_set = {
app
@@ -124,8 +121,8 @@ class ExecutionParams:
application: str
system: str
- application_params: List[str]
- system_params: List[str]
+ application_params: list[str]
+ system_params: list[str]
class LogWriter(OutputConsumer):
@@ -153,7 +150,7 @@ class GenericInferenceOutputParser(Base64OutputConsumer):
}
@property
- def result(self) -> Dict:
+ def result(self) -> dict:
"""Merge the raw results and map the names to the right output names."""
merged_result = {}
for raw_result in self.parsed_output:
@@ -172,7 +169,7 @@ class GenericInferenceOutputParser(Base64OutputConsumer):
"""Return true if all expected data has been parsed."""
return set(self.result.keys()) == set(self._map.values())
- def missed_keys(self) -> Set[str]:
+ def missed_keys(self) -> set[str]:
"""Return a set of the keys that have not been found in the output."""
return set(self._map.values()) - set(self.result.keys())
@@ -184,12 +181,12 @@ class BackendRunner:
"""Init BackendRunner instance."""
@staticmethod
- def get_installed_systems() -> List[str]:
+ def get_installed_systems() -> list[str]:
"""Get list of the installed systems."""
return [system.name for system in get_available_systems()]
@staticmethod
- def get_installed_applications(system: Optional[str] = None) -> List[str]:
+ def get_installed_applications(system: str | None = None) -> list[str]:
"""Get list of the installed application."""
return [
app.name
@@ -205,7 +202,7 @@ class BackendRunner:
"""Return true if requested system installed."""
return system in self.get_installed_systems()
- def systems_installed(self, systems: List[str]) -> bool:
+ def systems_installed(self, systems: list[str]) -> bool:
"""Check if all provided systems are installed."""
if not systems:
return False
@@ -213,7 +210,7 @@ class BackendRunner:
installed_systems = self.get_installed_systems()
return all(system in installed_systems for system in systems)
- def applications_installed(self, applications: List[str]) -> bool:
+ def applications_installed(self, applications: list[str]) -> bool:
"""Check if all provided applications are installed."""
if not applications:
return False
@@ -221,7 +218,7 @@ class BackendRunner:
installed_apps = self.get_installed_applications()
return all(app in installed_apps for app in applications)
- def all_installed(self, systems: List[str], apps: List[str]) -> bool:
+ def all_installed(self, systems: list[str], apps: list[str]) -> bool:
"""Check if all provided artifacts are installed."""
return self.systems_installed(systems) and self.applications_installed(apps)
@@ -247,7 +244,7 @@ class BackendRunner:
return ctx
@staticmethod
- def _params(name: str, params: List[str]) -> List[str]:
+ def _params(name: str, params: list[str]) -> list[str]:
return [p for item in [(name, param) for param in params] for p in item]
@@ -259,7 +256,7 @@ class GenericInferenceRunner(ABC):
self.backend_runner = backend_runner
def run(
- self, model_info: ModelInfo, output_consumers: List[OutputConsumer]
+ self, model_info: ModelInfo, output_consumers: list[OutputConsumer]
) -> None:
"""Run generic inference for the provided device/model."""
execution_params = self.get_execution_params(model_info)
@@ -284,7 +281,7 @@ class GenericInferenceRunner(ABC):
)
@staticmethod
- def consume_output(output: bytearray, consumers: List[OutputConsumer]) -> bytearray:
+ def consume_output(output: bytearray, consumers: list[OutputConsumer]) -> bytearray:
"""
Pass program's output to the consumers and filter it.
@@ -320,7 +317,7 @@ class GenericInferenceRunnerEthosU(GenericInferenceRunner):
@staticmethod
def resolve_system_and_app(
device_info: DeviceInfo, backend: str
- ) -> Tuple[str, str]:
+ ) -> tuple[str, str]:
"""Find appropriate system and application for the provided device/backend."""
try:
system_name = get_system_name(backend, device_info.device_type)