From f5b293d0927506c2a979a091bf0d07ecc78fa181 Mon Sep 17 00:00:00 2001 From: Dmitrii Agibov Date: Thu, 8 Sep 2022 14:24:39 +0100 Subject: MLIA-386 Simplify typing in the source code - Enable deferred annotations evaluation - Use builtin types for type hints whenever possible - Use | syntax for union types - Rename mlia.core._typing into mlia.core.typing Change-Id: I3f6ffc02fa069c589bdd9e8bddbccd504285427a --- src/mlia/core/_typing.py | 12 ----- src/mlia/core/advice_generation.py | 14 +++--- src/mlia/core/advisor.py | 11 +++-- src/mlia/core/common.py | 4 +- src/mlia/core/context.py | 31 ++++++------ src/mlia/core/data_analysis.py | 9 ++-- src/mlia/core/events.py | 24 +++++---- src/mlia/core/handlers.py | 10 ++-- src/mlia/core/helpers.py | 15 +++--- src/mlia/core/mixins.py | 7 +-- src/mlia/core/performance.py | 19 ++++---- src/mlia/core/reporting.py | 99 ++++++++++++++++++-------------------- src/mlia/core/typing.py | 12 +++++ src/mlia/core/workflow.py | 15 +++--- 14 files changed, 141 insertions(+), 141 deletions(-) delete mode 100644 src/mlia/core/_typing.py create mode 100644 src/mlia/core/typing.py (limited to 'src/mlia/core') diff --git a/src/mlia/core/_typing.py b/src/mlia/core/_typing.py deleted file mode 100644 index bda995c..0000000 --- a/src/mlia/core/_typing.py +++ /dev/null @@ -1,12 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Module for custom type hints.""" -from pathlib import Path -from typing import Literal -from typing import TextIO -from typing import Union - - -FileLike = TextIO -PathOrFileLike = Union[str, Path, FileLike] -OutputFormat = Literal["plain_text", "csv", "json"] diff --git a/src/mlia/core/advice_generation.py b/src/mlia/core/advice_generation.py index 76cc1f2..86285fe 100644 --- a/src/mlia/core/advice_generation.py +++ b/src/mlia/core/advice_generation.py @@ -1,14 +1,14 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for advice generation.""" +from __future__ import annotations + from abc import ABC from abc import abstractmethod from dataclasses import dataclass from functools import wraps from typing import Any from typing import Callable -from typing import List -from typing import Union from mlia.core.common import AdviceCategory from mlia.core.common import DataItem @@ -20,7 +20,7 @@ from mlia.core.mixins import ContextMixin class Advice: """Base class for the advice.""" - messages: List[str] + messages: list[str] @dataclass @@ -56,7 +56,7 @@ class AdviceProducer(ABC): """ @abstractmethod - def get_advice(self) -> Union[Advice, List[Advice]]: + def get_advice(self) -> Advice | list[Advice]: """Get produced advice.""" @@ -76,13 +76,13 @@ class FactBasedAdviceProducer(ContextAwareAdviceProducer): def __init__(self) -> None: """Init advice producer.""" - self.advice: List[Advice] = [] + self.advice: list[Advice] = [] - def get_advice(self) -> Union[Advice, List[Advice]]: + def get_advice(self) -> Advice | list[Advice]: """Get produced advice.""" return self.advice - def add_advice(self, messages: List[str]) -> None: + def add_advice(self, messages: list[str]) -> None: """Add advice.""" self.advice.append(Advice(messages)) diff --git a/src/mlia/core/advisor.py b/src/mlia/core/advisor.py index 13689fa..d684241 100644 --- a/src/mlia/core/advisor.py +++ b/src/mlia/core/advisor.py @@ -1,10 +1,11 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Inference advisor module.""" +from __future__ import annotations + from abc import abstractmethod from pathlib import Path from typing import cast -from typing import List from mlia.core.advice_generation import AdviceProducer from mlia.core.common import NamedEntity @@ -44,19 +45,19 @@ class DefaultInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin): ) @abstractmethod - def get_collectors(self, context: Context) -> List[DataCollector]: + def get_collectors(self, context: Context) -> list[DataCollector]: """Return list of the data collectors.""" @abstractmethod - def get_analyzers(self, context: Context) -> List[DataAnalyzer]: + def get_analyzers(self, context: Context) -> list[DataAnalyzer]: """Return list of the data analyzers.""" @abstractmethod - def get_producers(self, context: Context) -> List[AdviceProducer]: + def get_producers(self, context: Context) -> list[AdviceProducer]: """Return list of the advice producers.""" @abstractmethod - def get_events(self, context: Context) -> List[Event]: + def get_events(self, context: Context) -> list[Event]: """Return list of the startup events.""" def get_string_parameter(self, context: Context, param: str) -> str: diff --git a/src/mlia/core/common.py b/src/mlia/core/common.py index a11bf9a..63fb324 100644 --- a/src/mlia/core/common.py +++ b/src/mlia/core/common.py @@ -5,6 +5,8 @@ This module contains common interfaces/classess shared across core module. """ +from __future__ import annotations + from abc import ABC from abc import abstractmethod from enum import auto @@ -30,7 +32,7 @@ class AdviceCategory(Flag): ALL = OPERATORS | PERFORMANCE | OPTIMIZATION @classmethod - def from_string(cls, value: str) -> "AdviceCategory": + def from_string(cls, value: str) -> AdviceCategory: """Resolve enum value from string value.""" category_names = [item.name for item in AdviceCategory] if not value or value.upper() not in category_names: diff --git a/src/mlia/core/context.py b/src/mlia/core/context.py index 83d2f7c..a4737bb 100644 --- a/src/mlia/core/context.py +++ b/src/mlia/core/context.py @@ -7,15 +7,14 @@ Context is an object that describes advisor working environment and requested behavior (advice categories, input configuration parameters). """ +from __future__ import annotations + import logging from abc import ABC from abc import abstractmethod from pathlib import Path from typing import Any -from typing import List from typing import Mapping -from typing import Optional -from typing import Union from mlia.core.common import AdviceCategory from mlia.core.events import DefaultEventPublisher @@ -50,7 +49,7 @@ class Context(ABC): @property @abstractmethod - def event_handlers(self) -> Optional[List[EventHandler]]: + def event_handlers(self) -> list[EventHandler] | None: """Return list of the event_handlers.""" @property @@ -60,7 +59,7 @@ class Context(ABC): @property @abstractmethod - def config_parameters(self) -> Optional[Mapping[str, Any]]: + def config_parameters(self) -> Mapping[str, Any] | None: """Return configuration parameters.""" @property @@ -73,7 +72,7 @@ class Context(ABC): self, *, advice_category: AdviceCategory, - event_handlers: List[EventHandler], + event_handlers: list[EventHandler], config_parameters: Mapping[str, Any], ) -> None: """Update context parameters.""" @@ -98,14 +97,14 @@ class ExecutionContext(Context): self, *, advice_category: AdviceCategory = AdviceCategory.ALL, - config_parameters: Optional[Mapping[str, Any]] = None, - working_dir: Optional[Union[str, Path]] = None, - event_handlers: Optional[List[EventHandler]] = None, - event_publisher: Optional[EventPublisher] = None, + config_parameters: Mapping[str, Any] | None = None, + working_dir: str | Path | None = None, + event_handlers: list[EventHandler] | None = None, + event_publisher: EventPublisher | None = None, verbose: bool = False, logs_dir: str = "logs", models_dir: str = "models", - action_resolver: Optional[ActionResolver] = None, + action_resolver: ActionResolver | None = None, ) -> None: """Init execution context. @@ -151,22 +150,22 @@ class ExecutionContext(Context): self._advice_category = advice_category @property - def config_parameters(self) -> Optional[Mapping[str, Any]]: + def config_parameters(self) -> Mapping[str, Any] | None: """Return configuration parameters.""" return self._config_parameters @config_parameters.setter - def config_parameters(self, config_parameters: Optional[Mapping[str, Any]]) -> None: + def config_parameters(self, config_parameters: Mapping[str, Any] | None) -> None: """Setter for the configuration parameters.""" self._config_parameters = config_parameters @property - def event_handlers(self) -> Optional[List[EventHandler]]: + def event_handlers(self) -> list[EventHandler] | None: """Return list of the event handlers.""" return self._event_handlers @event_handlers.setter - def event_handlers(self, event_handlers: List[EventHandler]) -> None: + def event_handlers(self, event_handlers: list[EventHandler]) -> None: """Setter for the event handlers.""" self._event_handlers = event_handlers @@ -196,7 +195,7 @@ class ExecutionContext(Context): self, *, advice_category: AdviceCategory, - event_handlers: List[EventHandler], + event_handlers: list[EventHandler], config_parameters: Mapping[str, Any], ) -> None: """Update context parameters.""" diff --git a/src/mlia/core/data_analysis.py b/src/mlia/core/data_analysis.py index 6adb41e..0603425 100644 --- a/src/mlia/core/data_analysis.py +++ b/src/mlia/core/data_analysis.py @@ -1,10 +1,11 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for data analysis.""" +from __future__ import annotations + from abc import ABC from abc import abstractmethod from dataclasses import dataclass -from typing import List from mlia.core.common import DataItem from mlia.core.mixins import ContextMixin @@ -29,7 +30,7 @@ class DataAnalyzer(ABC): """ @abstractmethod - def get_analyzed_data(self) -> List[DataItem]: + def get_analyzed_data(self) -> list[DataItem]: """Get analyzed data.""" @@ -59,9 +60,9 @@ class FactExtractor(ContextAwareDataAnalyzer): def __init__(self) -> None: """Init fact extractor.""" - self.facts: List[Fact] = [] + self.facts: list[Fact] = [] - def get_analyzed_data(self) -> List[DataItem]: + def get_analyzed_data(self) -> list[DataItem]: """Return list of the collected facts.""" return self.facts diff --git a/src/mlia/core/events.py b/src/mlia/core/events.py index 0b8461b..71c86e2 100644 --- a/src/mlia/core/events.py +++ b/src/mlia/core/events.py @@ -9,6 +9,8 @@ calling application. Each component of the workflow can generate events of specific type. Application can subscribe and react to those events. """ +from __future__ import annotations + import traceback import uuid from abc import ABC @@ -19,11 +21,7 @@ from dataclasses import dataclass from dataclasses import field from functools import singledispatchmethod from typing import Any -from typing import Dict from typing import Generator -from typing import List -from typing import Optional -from typing import Tuple from mlia.core.common import DataItem @@ -41,7 +39,7 @@ class Event: """Generate unique ID for the event.""" self.event_id = str(uuid.uuid4()) - def compare_without_id(self, other: "Event") -> bool: + def compare_without_id(self, other: Event) -> bool: """Compare two events without event_id field.""" if not isinstance(other, Event) or self.__class__ != other.__class__: return False @@ -73,7 +71,7 @@ class ActionStartedEvent(Event): """ action_type: str - params: Optional[Dict] = None + params: dict | None = None @dataclass @@ -84,7 +82,7 @@ class SubActionEvent(ChildEvent): """ action_type: str - params: Optional[Dict] = None + params: dict | None = None @dataclass @@ -271,8 +269,8 @@ class EventDispatcherMetaclass(type): def __new__( cls, clsname: str, - bases: Tuple, - namespace: Dict[str, Any], + bases: tuple[type, ...], + namespace: dict[str, Any], event_handler_method_prefix: str = "on_", ) -> Any: """Create event dispatcher and link event handlers.""" @@ -321,7 +319,7 @@ class EventPublisher(ABC): """ def register_event_handlers( - self, event_handlers: Optional[List[EventHandler]] + self, event_handlers: list[EventHandler] | None ) -> None: """Register event handlers. @@ -354,7 +352,7 @@ class DefaultEventPublisher(EventPublisher): def __init__(self) -> None: """Init the event publisher.""" - self.handlers: List[EventHandler] = [] + self.handlers: list[EventHandler] = [] def register_event_handler(self, event_handler: EventHandler) -> None: """Register the event handler. @@ -374,7 +372,7 @@ class DefaultEventPublisher(EventPublisher): @contextmanager def stage( - publisher: EventPublisher, events: Tuple[Event, Event] + publisher: EventPublisher, events: tuple[Event, Event] ) -> Generator[None, None, None]: """Generate events before and after stage. @@ -390,7 +388,7 @@ def stage( @contextmanager def action( - publisher: EventPublisher, action_type: str, params: Optional[Dict] = None + publisher: EventPublisher, action_type: str, params: dict | None = None ) -> Generator[None, None, None]: """Generate events before and after action.""" action_started = ActionStartedEvent(action_type, params) diff --git a/src/mlia/core/handlers.py b/src/mlia/core/handlers.py index e576f74..a3255ae 100644 --- a/src/mlia/core/handlers.py +++ b/src/mlia/core/handlers.py @@ -1,13 +1,12 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Event handlers module.""" +from __future__ import annotations + import logging from typing import Any from typing import Callable -from typing import List -from typing import Optional -from mlia.core._typing import PathOrFileLike from mlia.core.advice_generation import Advice from mlia.core.advice_generation import AdviceEvent from mlia.core.events import ActionFinishedEvent @@ -28,6 +27,7 @@ from mlia.core.events import ExecutionStartedEvent from mlia.core.reporting import Report from mlia.core.reporting import Reporter from mlia.core.reporting import resolve_output_format +from mlia.core.typing import PathOrFileLike from mlia.utils.console import create_section_header @@ -101,14 +101,14 @@ class WorkflowEventsHandler(SystemEventsHandler): def __init__( self, formatter_resolver: Callable[[Any], Callable[[Any], Report]], - output: Optional[PathOrFileLike] = None, + output: PathOrFileLike | None = None, ) -> None: """Init event handler.""" output_format = resolve_output_format(output) self.reporter = Reporter(formatter_resolver, output_format) self.output = output - self.advice: List[Advice] = [] + self.advice: list[Advice] = [] def on_execution_started(self, event: ExecutionStartedEvent) -> None: """Handle ExecutionStarted event.""" diff --git a/src/mlia/core/helpers.py b/src/mlia/core/helpers.py index d10ea5d..f0c4474 100644 --- a/src/mlia/core/helpers.py +++ b/src/mlia/core/helpers.py @@ -2,34 +2,35 @@ # SPDX-License-Identifier: Apache-2.0 """Module for various helper classes.""" # pylint: disable=no-self-use, unused-argument +from __future__ import annotations + from typing import Any -from typing import List class ActionResolver: """Helper class for generating actions (e.g. commands with parameters).""" - def apply_optimizations(self, **kwargs: Any) -> List[str]: + def apply_optimizations(self, **kwargs: Any) -> list[str]: """Return action details for applying optimizations.""" return [] - def supported_operators_info(self) -> List[str]: + def supported_operators_info(self) -> list[str]: """Return action details for generating supported ops report.""" return [] - def check_performance(self) -> List[str]: + def check_performance(self) -> list[str]: """Return action details for checking performance.""" return [] - def check_operator_compatibility(self) -> List[str]: + def check_operator_compatibility(self) -> list[str]: """Return action details for checking op compatibility.""" return [] - def operator_compatibility_details(self) -> List[str]: + def operator_compatibility_details(self) -> list[str]: """Return action details for getting more information about op compatibility.""" return [] - def optimization_details(self) -> List[str]: + def optimization_details(self) -> list[str]: """Return action detail for getting information about optimizations.""" return [] diff --git a/src/mlia/core/mixins.py b/src/mlia/core/mixins.py index ee03100..5ef9d66 100644 --- a/src/mlia/core/mixins.py +++ b/src/mlia/core/mixins.py @@ -1,8 +1,9 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Mixins module.""" +from __future__ import annotations + from typing import Any -from typing import Optional from mlia.core.context import Context @@ -27,8 +28,8 @@ class ParameterResolverMixin: section: str, name: str, expected: bool = True, - expected_type: Optional[type] = None, - context: Optional[Context] = None, + expected_type: type | None = None, + context: Context | None = None, ) -> Any: """Get parameter value.""" ctx = context or self.context diff --git a/src/mlia/core/performance.py b/src/mlia/core/performance.py index 5433d5c..cb12918 100644 --- a/src/mlia/core/performance.py +++ b/src/mlia/core/performance.py @@ -1,30 +1,31 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for performance estimation.""" +from __future__ import annotations + from abc import abstractmethod from typing import Callable from typing import Generic -from typing import List from typing import TypeVar -ModelType = TypeVar("ModelType") # pylint: disable=invalid-name -PerfMetricsType = TypeVar("PerfMetricsType") # pylint: disable=invalid-name +M = TypeVar("M") # model type +P = TypeVar("P") # performance metrics -class PerformanceEstimator(Generic[ModelType, PerfMetricsType]): +class PerformanceEstimator(Generic[M, P]): """Base class for the performance estimation.""" @abstractmethod - def estimate(self, model: ModelType) -> PerfMetricsType: + def estimate(self, model: M) -> P: """Estimate performance.""" def estimate_performance( - original_model: ModelType, - estimator: PerformanceEstimator[ModelType, PerfMetricsType], - model_transformations: List[Callable[[ModelType], ModelType]], -) -> List[PerfMetricsType]: + original_model: M, + estimator: PerformanceEstimator[M, P], + model_transformations: list[Callable[[M], M]], +) -> list[P]: """Estimate performance impact. This function estimates performance impact on model performance after diff --git a/src/mlia/core/reporting.py b/src/mlia/core/reporting.py index 58a41d3..0c8fabc 100644 --- a/src/mlia/core/reporting.py +++ b/src/mlia/core/reporting.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Reporting module.""" +from __future__ import annotations + import csv import json import logging @@ -19,19 +21,14 @@ from typing import Any from typing import Callable from typing import cast from typing import Collection -from typing import Dict from typing import Generator from typing import Iterable -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union import numpy as np -from mlia.core._typing import FileLike -from mlia.core._typing import OutputFormat -from mlia.core._typing import PathOrFileLike +from mlia.core.typing import FileLike +from mlia.core.typing import OutputFormat +from mlia.core.typing import PathOrFileLike from mlia.utils.console import apply_style from mlia.utils.console import produce_table from mlia.utils.logging import LoggerWriter @@ -48,7 +45,7 @@ class Report(ABC): """Convert to json serializible format.""" @abstractmethod - def to_csv(self, **kwargs: Any) -> List[Any]: + def to_csv(self, **kwargs: Any) -> list[Any]: """Convert to csv serializible format.""" @abstractmethod @@ -62,9 +59,9 @@ class ReportItem: def __init__( self, name: str, - alias: Optional[str] = None, - value: Optional[Union[str, int, "Cell"]] = None, - nested_items: Optional[List["ReportItem"]] = None, + alias: str | None = None, + value: str | int | Cell | None = None, + nested_items: list[ReportItem] | None = None, ) -> None: """Init the report item.""" self.name = name @@ -98,9 +95,9 @@ class Format: :param style: text style """ - wrap_width: Optional[int] = None - str_fmt: Optional[Union[str, Callable[[Any], str]]] = None - style: Optional[str] = None + wrap_width: int | None = None + str_fmt: str | Callable[[Any], str] | None = None + style: str | None = None @dataclass @@ -112,7 +109,7 @@ class Cell: """ value: Any - fmt: Optional[Format] = None + fmt: Format | None = None def _apply_style(self, value: str) -> str: """Apply style to the value.""" @@ -151,7 +148,7 @@ class CountAwareCell(Cell): def __init__( self, - value: Optional[Union[int, float]], + value: int | float | None, singular: str, plural: str, format_string: str = ",d", @@ -159,7 +156,7 @@ class CountAwareCell(Cell): """Init cell instance.""" self.unit = singular if value == 1 else plural - def format_value(val: Optional[Union[int, float]]) -> str: + def format_value(val: int | float | None) -> str: """Provide string representation for the value.""" if val is None: return "" @@ -183,7 +180,7 @@ class CountAwareCell(Cell): class BytesCell(CountAwareCell): """Cell that represents memory size.""" - def __init__(self, value: Optional[int]) -> None: + def __init__(self, value: int | None) -> None: """Init cell instance.""" super().__init__(value, "byte", "bytes") @@ -191,7 +188,7 @@ class BytesCell(CountAwareCell): class CyclesCell(CountAwareCell): """Cell that represents cycles.""" - def __init__(self, value: Optional[Union[int, float]]) -> None: + def __init__(self, value: int | float | None) -> None: """Init cell instance.""" super().__init__(value, "cycle", "cycles", ",.0f") @@ -199,7 +196,7 @@ class CyclesCell(CountAwareCell): class ClockCell(CountAwareCell): """Cell that represents clock value.""" - def __init__(self, value: Optional[Union[int, float]]) -> None: + def __init__(self, value: int | float | None) -> None: """Init cell instance.""" super().__init__(value, "Hz", "Hz", ",.0f") @@ -210,9 +207,9 @@ class Column: def __init__( self, header: str, - alias: Optional[str] = None, - fmt: Optional[Format] = None, - only_for: Optional[List[str]] = None, + alias: str | None = None, + fmt: Format | None = None, + only_for: list[str] | None = None, ) -> None: """Init column definition. @@ -228,7 +225,7 @@ class Column: self.fmt = fmt self.only_for = only_for - def supports_format(self, fmt: str) -> bool: + def supports_format(self, fmt: OutputFormat) -> bool: """Return true if column should be shown.""" return not self.only_for or fmt in self.only_for @@ -236,20 +233,20 @@ class Column: class NestedReport(Report): """Report with nested items.""" - def __init__(self, name: str, alias: str, items: List[ReportItem]) -> None: + def __init__(self, name: str, alias: str, items: list[ReportItem]) -> None: """Init nested report.""" self.name = name self.alias = alias self.items = items - def to_csv(self, **kwargs: Any) -> List[Any]: + def to_csv(self, **kwargs: Any) -> list[Any]: """Convert to csv serializible format.""" result = {} def collect_item_values( item: ReportItem, - _parent: Optional[ReportItem], - _prev: Optional[ReportItem], + _parent: ReportItem | None, + _prev: ReportItem | None, _level: int, ) -> None: """Collect item values into a dictionary..""" @@ -279,13 +276,13 @@ class NestedReport(Report): def to_json(self, **kwargs: Any) -> Any: """Convert to json serializible format.""" - per_parent: Dict[Optional[ReportItem], Dict] = defaultdict(dict) + per_parent: dict[ReportItem | None, dict] = defaultdict(dict) result = per_parent[None] def collect_as_dicts( item: ReportItem, - parent: Optional[ReportItem], - _prev: Optional[ReportItem], + parent: ReportItem | None, + _prev: ReportItem | None, _level: int, ) -> None: """Collect item values as nested dictionaries.""" @@ -313,8 +310,8 @@ class NestedReport(Report): def convert_to_text( item: ReportItem, - _parent: Optional[ReportItem], - prev: Optional[ReportItem], + _parent: ReportItem | None, + prev: ReportItem | None, level: int, ) -> None: """Convert item to text representation.""" @@ -345,12 +342,12 @@ class NestedReport(Report): def _traverse( self, - items: List[ReportItem], + items: list[ReportItem], visit_item: Callable[ - [ReportItem, Optional[ReportItem], Optional[ReportItem], int], None + [ReportItem, ReportItem | None, ReportItem | None, int], None ], level: int = 1, - parent: Optional[ReportItem] = None, + parent: ReportItem | None = None, ) -> None: """Traverse through items.""" prev = None @@ -369,11 +366,11 @@ class Table(Report): def __init__( self, - columns: List[Column], + columns: list[Column], rows: Collection, name: str, - alias: Optional[str] = None, - notes: Optional[str] = None, + alias: str | None = None, + notes: str | None = None, ) -> None: """Init table definition. @@ -477,7 +474,7 @@ class Table(Report): return title + formatted_table + footer - def to_csv(self, **kwargs: Any) -> List[Any]: + def to_csv(self, **kwargs: Any) -> list[Any]: """Convert table to csv format.""" headers = [[c.header for c in self.columns if c.supports_format("csv")]] @@ -528,7 +525,7 @@ class CompoundReport(Report): This class could be used for producing multiple reports at once. """ - def __init__(self, reports: List[Report]) -> None: + def __init__(self, reports: list[Report]) -> None: """Init compound report instance.""" self.reports = reports @@ -538,13 +535,13 @@ class CompoundReport(Report): Method attempts to create compound dictionary based on provided parts. """ - result: Dict[str, Any] = {} + result: dict[str, Any] = {} for item in self.reports: result.update(item.to_json(**kwargs)) return result - def to_csv(self, **kwargs: Any) -> List[Any]: + def to_csv(self, **kwargs: Any) -> list[Any]: """Convert to csv serializible format. CSV format does support only one table. In order to be able to export @@ -592,7 +589,7 @@ class CompoundReport(Report): class CompoundFormatter: """Compound data formatter.""" - def __init__(self, formatters: List[Callable]) -> None: + def __init__(self, formatters: list[Callable]) -> None: """Init compound formatter.""" self.formatters = formatters @@ -637,7 +634,7 @@ def produce_report( data: Any, formatter: Callable[[Any], Report], fmt: OutputFormat = "plain_text", - output: Optional[PathOrFileLike] = None, + output: PathOrFileLike | None = None, **kwargs: Any, ) -> None: """Produce report based on provided data.""" @@ -679,8 +676,8 @@ class Reporter: self.output_format = output_format self.print_as_submitted = print_as_submitted - self.data: List[Tuple[Any, Callable[[Any], Report]]] = [] - self.delayed: List[Tuple[Any, Callable[[Any], Report]]] = [] + self.data: list[tuple[Any, Callable[[Any], Report]]] = [] + self.delayed: list[tuple[Any, Callable[[Any], Report]]] = [] def submit(self, data_item: Any, delay_print: bool = False, **kwargs: Any) -> None: """Submit data for the report.""" @@ -713,7 +710,7 @@ class Reporter: ) self.delayed = [] - def generate_report(self, output: Optional[PathOrFileLike]) -> None: + def generate_report(self, output: PathOrFileLike | None) -> None: """Generate report.""" already_printed = ( self.print_as_submitted @@ -735,7 +732,7 @@ class Reporter: @contextmanager def get_reporter( output_format: OutputFormat, - output: Optional[PathOrFileLike], + output: PathOrFileLike | None, formatter_resolver: Callable[[Any], Callable[[Any], Report]], ) -> Generator[Reporter, None, None]: """Get reporter and generate report.""" @@ -762,7 +759,7 @@ def _apply_format_parameters( return wrapper -def resolve_output_format(output: Optional[PathOrFileLike]) -> OutputFormat: +def resolve_output_format(output: PathOrFileLike | None) -> OutputFormat: """Resolve output format based on the output name.""" if isinstance(output, (str, Path)): format_from_filename = Path(output).suffix.lstrip(".") diff --git a/src/mlia/core/typing.py b/src/mlia/core/typing.py new file mode 100644 index 0000000..bda995c --- /dev/null +++ b/src/mlia/core/typing.py @@ -0,0 +1,12 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for custom type hints.""" +from pathlib import Path +from typing import Literal +from typing import TextIO +from typing import Union + + +FileLike = TextIO +PathOrFileLike = Union[str, Path, FileLike] +OutputFormat = Literal["plain_text", "csv", "json"] diff --git a/src/mlia/core/workflow.py b/src/mlia/core/workflow.py index 03f3d1c..d862a86 100644 --- a/src/mlia/core/workflow.py +++ b/src/mlia/core/workflow.py @@ -5,16 +5,15 @@ This module contains implementation of the workflow executors. """ +from __future__ import annotations + import itertools from abc import ABC from abc import abstractmethod from functools import wraps from typing import Any from typing import Callable -from typing import List -from typing import Optional from typing import Sequence -from typing import Tuple from mlia.core.advice_generation import Advice from mlia.core.advice_generation import AdviceEvent @@ -57,7 +56,7 @@ STAGE_ANALYSIS = (DataAnalysisStageStartedEvent(), DataAnalysisStageFinishedEven STAGE_ADVICE = (AdviceStageStartedEvent(), AdviceStageFinishedEvent()) -def on_stage(stage_events: Tuple[Event, Event]) -> Callable: +def on_stage(stage_events: tuple[Event, Event]) -> Callable: """Mark start/finish of the stage with appropriate events.""" def wrapper(method: Callable) -> Callable: @@ -87,7 +86,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor): collectors: Sequence[DataCollector], analyzers: Sequence[DataAnalyzer], producers: Sequence[AdviceProducer], - startup_events: Optional[Sequence[Event]] = None, + startup_events: Sequence[Event] | None = None, ): """Init default workflow executor. @@ -130,7 +129,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor): self.publish(event) @on_stage(STAGE_COLLECTION) - def collect_data(self) -> List[DataItem]: + def collect_data(self) -> list[DataItem]: """Collect data. Run each of data collector components and return list of @@ -148,7 +147,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor): return collected_data @on_stage(STAGE_ANALYSIS) - def analyze_data(self, collected_data: List[DataItem]) -> List[DataItem]: + def analyze_data(self, collected_data: list[DataItem]) -> list[DataItem]: """Analyze data. Pass each collected data item into each data analyzer and @@ -168,7 +167,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor): return analyzed_data @on_stage(STAGE_ADVICE) - def produce_advice(self, analyzed_data: List[DataItem]) -> None: + def produce_advice(self, analyzed_data: list[DataItem]) -> None: """Produce advice. Pass each analyzed data item into each advice producer and -- cgit v1.2.1