From f5b293d0927506c2a979a091bf0d07ecc78fa181 Mon Sep 17 00:00:00 2001 From: Dmitrii Agibov Date: Thu, 8 Sep 2022 14:24:39 +0100 Subject: MLIA-386 Simplify typing in the source code - Enable deferred annotations evaluation - Use builtin types for type hints whenever possible - Use | syntax for union types - Rename mlia.core._typing into mlia.core.typing Change-Id: I3f6ffc02fa069c589bdd9e8bddbccd504285427a --- src/mlia/backend/manager.py | 41 +++++++++++++++++++---------------------- 1 file changed, 19 insertions(+), 22 deletions(-) (limited to 'src/mlia/backend/manager.py') diff --git a/src/mlia/backend/manager.py b/src/mlia/backend/manager.py index 8d8246d..c8fe0f7 100644 --- a/src/mlia/backend/manager.py +++ b/src/mlia/backend/manager.py @@ -1,17 +1,14 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for backend integration.""" +from __future__ import annotations + import logging from abc import ABC from abc import abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import Dict -from typing import List from typing import Literal -from typing import Optional -from typing import Set -from typing import Tuple from mlia.backend.application import get_available_applications from mlia.backend.application import install_application @@ -58,7 +55,7 @@ def get_system_name(backend: str, device_type: str) -> str: return _SUPPORTED_SYSTEMS[backend][device_type] -def is_supported(backend: str, device_type: Optional[str] = None) -> bool: +def is_supported(backend: str, device_type: str | None = None) -> bool: """Check if the backend (and optionally device type) is supported.""" if device_type is None: return backend in _SUPPORTED_SYSTEMS @@ -70,17 +67,17 @@ def is_supported(backend: str, device_type: Optional[str] = None) -> bool: return False -def supported_backends() -> List[str]: +def supported_backends() -> list[str]: """Get a list of all backends supported by the backend manager.""" return list(_SUPPORTED_SYSTEMS.keys()) -def get_all_system_names(backend: str) -> List[str]: +def get_all_system_names(backend: str) -> list[str]: """Get all systems supported by the backend.""" return list(_SUPPORTED_SYSTEMS.get(backend, {}).values()) -def get_all_application_names(backend: str) -> List[str]: +def get_all_application_names(backend: str) -> list[str]: """Get all applications supported by the backend.""" app_set = { app @@ -124,8 +121,8 @@ class ExecutionParams: application: str system: str - application_params: List[str] - system_params: List[str] + application_params: list[str] + system_params: list[str] class LogWriter(OutputConsumer): @@ -153,7 +150,7 @@ class GenericInferenceOutputParser(Base64OutputConsumer): } @property - def result(self) -> Dict: + def result(self) -> dict: """Merge the raw results and map the names to the right output names.""" merged_result = {} for raw_result in self.parsed_output: @@ -172,7 +169,7 @@ class GenericInferenceOutputParser(Base64OutputConsumer): """Return true if all expected data has been parsed.""" return set(self.result.keys()) == set(self._map.values()) - def missed_keys(self) -> Set[str]: + def missed_keys(self) -> set[str]: """Return a set of the keys that have not been found in the output.""" return set(self._map.values()) - set(self.result.keys()) @@ -184,12 +181,12 @@ class BackendRunner: """Init BackendRunner instance.""" @staticmethod - def get_installed_systems() -> List[str]: + def get_installed_systems() -> list[str]: """Get list of the installed systems.""" return [system.name for system in get_available_systems()] @staticmethod - def get_installed_applications(system: Optional[str] = None) -> List[str]: + def get_installed_applications(system: str | None = None) -> list[str]: """Get list of the installed application.""" return [ app.name @@ -205,7 +202,7 @@ class BackendRunner: """Return true if requested system installed.""" return system in self.get_installed_systems() - def systems_installed(self, systems: List[str]) -> bool: + def systems_installed(self, systems: list[str]) -> bool: """Check if all provided systems are installed.""" if not systems: return False @@ -213,7 +210,7 @@ class BackendRunner: installed_systems = self.get_installed_systems() return all(system in installed_systems for system in systems) - def applications_installed(self, applications: List[str]) -> bool: + def applications_installed(self, applications: list[str]) -> bool: """Check if all provided applications are installed.""" if not applications: return False @@ -221,7 +218,7 @@ class BackendRunner: installed_apps = self.get_installed_applications() return all(app in installed_apps for app in applications) - def all_installed(self, systems: List[str], apps: List[str]) -> bool: + def all_installed(self, systems: list[str], apps: list[str]) -> bool: """Check if all provided artifacts are installed.""" return self.systems_installed(systems) and self.applications_installed(apps) @@ -247,7 +244,7 @@ class BackendRunner: return ctx @staticmethod - def _params(name: str, params: List[str]) -> List[str]: + def _params(name: str, params: list[str]) -> list[str]: return [p for item in [(name, param) for param in params] for p in item] @@ -259,7 +256,7 @@ class GenericInferenceRunner(ABC): self.backend_runner = backend_runner def run( - self, model_info: ModelInfo, output_consumers: List[OutputConsumer] + self, model_info: ModelInfo, output_consumers: list[OutputConsumer] ) -> None: """Run generic inference for the provided device/model.""" execution_params = self.get_execution_params(model_info) @@ -284,7 +281,7 @@ class GenericInferenceRunner(ABC): ) @staticmethod - def consume_output(output: bytearray, consumers: List[OutputConsumer]) -> bytearray: + def consume_output(output: bytearray, consumers: list[OutputConsumer]) -> bytearray: """ Pass program's output to the consumers and filter it. @@ -320,7 +317,7 @@ class GenericInferenceRunnerEthosU(GenericInferenceRunner): @staticmethod def resolve_system_and_app( device_info: DeviceInfo, backend: str - ) -> Tuple[str, str]: + ) -> tuple[str, str]: """Find appropriate system and application for the provided device/backend.""" try: system_name = get_system_name(backend, device_info.device_type) -- cgit v1.2.1