aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/core
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2022-09-08 14:24:39 +0100
committerDmitrii Agibov <dmitrii.agibov@arm.com>2022-09-09 17:21:48 +0100
commitf5b293d0927506c2a979a091bf0d07ecc78fa181 (patch)
tree4de585b7cb6ed34da8237063752270189a730a41 /src/mlia/core
parentcde0c6ee140bd108849bff40467d8f18ffc332ef (diff)
downloadmlia-f5b293d0927506c2a979a091bf0d07ecc78fa181.tar.gz
MLIA-386 Simplify typing in the source code
- Enable deferred annotations evaluation - Use builtin types for type hints whenever possible - Use | syntax for union types - Rename mlia.core._typing into mlia.core.typing Change-Id: I3f6ffc02fa069c589bdd9e8bddbccd504285427a
Diffstat (limited to 'src/mlia/core')
-rw-r--r--src/mlia/core/advice_generation.py14
-rw-r--r--src/mlia/core/advisor.py11
-rw-r--r--src/mlia/core/common.py4
-rw-r--r--src/mlia/core/context.py31
-rw-r--r--src/mlia/core/data_analysis.py9
-rw-r--r--src/mlia/core/events.py24
-rw-r--r--src/mlia/core/handlers.py10
-rw-r--r--src/mlia/core/helpers.py15
-rw-r--r--src/mlia/core/mixins.py7
-rw-r--r--src/mlia/core/performance.py19
-rw-r--r--src/mlia/core/reporting.py99
-rw-r--r--src/mlia/core/typing.py (renamed from src/mlia/core/_typing.py)0
-rw-r--r--src/mlia/core/workflow.py15
13 files changed, 129 insertions, 129 deletions
diff --git a/src/mlia/core/advice_generation.py b/src/mlia/core/advice_generation.py
index 76cc1f2..86285fe 100644
--- a/src/mlia/core/advice_generation.py
+++ b/src/mlia/core/advice_generation.py
@@ -1,14 +1,14 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for advice generation."""
+from __future__ import annotations
+
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from functools import wraps
from typing import Any
from typing import Callable
-from typing import List
-from typing import Union
from mlia.core.common import AdviceCategory
from mlia.core.common import DataItem
@@ -20,7 +20,7 @@ from mlia.core.mixins import ContextMixin
class Advice:
"""Base class for the advice."""
- messages: List[str]
+ messages: list[str]
@dataclass
@@ -56,7 +56,7 @@ class AdviceProducer(ABC):
"""
@abstractmethod
- def get_advice(self) -> Union[Advice, List[Advice]]:
+ def get_advice(self) -> Advice | list[Advice]:
"""Get produced advice."""
@@ -76,13 +76,13 @@ class FactBasedAdviceProducer(ContextAwareAdviceProducer):
def __init__(self) -> None:
"""Init advice producer."""
- self.advice: List[Advice] = []
+ self.advice: list[Advice] = []
- def get_advice(self) -> Union[Advice, List[Advice]]:
+ def get_advice(self) -> Advice | list[Advice]:
"""Get produced advice."""
return self.advice
- def add_advice(self, messages: List[str]) -> None:
+ def add_advice(self, messages: list[str]) -> None:
"""Add advice."""
self.advice.append(Advice(messages))
diff --git a/src/mlia/core/advisor.py b/src/mlia/core/advisor.py
index 13689fa..d684241 100644
--- a/src/mlia/core/advisor.py
+++ b/src/mlia/core/advisor.py
@@ -1,10 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Inference advisor module."""
+from __future__ import annotations
+
from abc import abstractmethod
from pathlib import Path
from typing import cast
-from typing import List
from mlia.core.advice_generation import AdviceProducer
from mlia.core.common import NamedEntity
@@ -44,19 +45,19 @@ class DefaultInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin):
)
@abstractmethod
- def get_collectors(self, context: Context) -> List[DataCollector]:
+ def get_collectors(self, context: Context) -> list[DataCollector]:
"""Return list of the data collectors."""
@abstractmethod
- def get_analyzers(self, context: Context) -> List[DataAnalyzer]:
+ def get_analyzers(self, context: Context) -> list[DataAnalyzer]:
"""Return list of the data analyzers."""
@abstractmethod
- def get_producers(self, context: Context) -> List[AdviceProducer]:
+ def get_producers(self, context: Context) -> list[AdviceProducer]:
"""Return list of the advice producers."""
@abstractmethod
- def get_events(self, context: Context) -> List[Event]:
+ def get_events(self, context: Context) -> list[Event]:
"""Return list of the startup events."""
def get_string_parameter(self, context: Context, param: str) -> str:
diff --git a/src/mlia/core/common.py b/src/mlia/core/common.py
index a11bf9a..63fb324 100644
--- a/src/mlia/core/common.py
+++ b/src/mlia/core/common.py
@@ -5,6 +5,8 @@
This module contains common interfaces/classess shared across
core module.
"""
+from __future__ import annotations
+
from abc import ABC
from abc import abstractmethod
from enum import auto
@@ -30,7 +32,7 @@ class AdviceCategory(Flag):
ALL = OPERATORS | PERFORMANCE | OPTIMIZATION
@classmethod
- def from_string(cls, value: str) -> "AdviceCategory":
+ def from_string(cls, value: str) -> AdviceCategory:
"""Resolve enum value from string value."""
category_names = [item.name for item in AdviceCategory]
if not value or value.upper() not in category_names:
diff --git a/src/mlia/core/context.py b/src/mlia/core/context.py
index 83d2f7c..a4737bb 100644
--- a/src/mlia/core/context.py
+++ b/src/mlia/core/context.py
@@ -7,15 +7,14 @@ Context is an object that describes advisor working environment
and requested behavior (advice categories, input configuration
parameters).
"""
+from __future__ import annotations
+
import logging
from abc import ABC
from abc import abstractmethod
from pathlib import Path
from typing import Any
-from typing import List
from typing import Mapping
-from typing import Optional
-from typing import Union
from mlia.core.common import AdviceCategory
from mlia.core.events import DefaultEventPublisher
@@ -50,7 +49,7 @@ class Context(ABC):
@property
@abstractmethod
- def event_handlers(self) -> Optional[List[EventHandler]]:
+ def event_handlers(self) -> list[EventHandler] | None:
"""Return list of the event_handlers."""
@property
@@ -60,7 +59,7 @@ class Context(ABC):
@property
@abstractmethod
- def config_parameters(self) -> Optional[Mapping[str, Any]]:
+ def config_parameters(self) -> Mapping[str, Any] | None:
"""Return configuration parameters."""
@property
@@ -73,7 +72,7 @@ class Context(ABC):
self,
*,
advice_category: AdviceCategory,
- event_handlers: List[EventHandler],
+ event_handlers: list[EventHandler],
config_parameters: Mapping[str, Any],
) -> None:
"""Update context parameters."""
@@ -98,14 +97,14 @@ class ExecutionContext(Context):
self,
*,
advice_category: AdviceCategory = AdviceCategory.ALL,
- config_parameters: Optional[Mapping[str, Any]] = None,
- working_dir: Optional[Union[str, Path]] = None,
- event_handlers: Optional[List[EventHandler]] = None,
- event_publisher: Optional[EventPublisher] = None,
+ config_parameters: Mapping[str, Any] | None = None,
+ working_dir: str | Path | None = None,
+ event_handlers: list[EventHandler] | None = None,
+ event_publisher: EventPublisher | None = None,
verbose: bool = False,
logs_dir: str = "logs",
models_dir: str = "models",
- action_resolver: Optional[ActionResolver] = None,
+ action_resolver: ActionResolver | None = None,
) -> None:
"""Init execution context.
@@ -151,22 +150,22 @@ class ExecutionContext(Context):
self._advice_category = advice_category
@property
- def config_parameters(self) -> Optional[Mapping[str, Any]]:
+ def config_parameters(self) -> Mapping[str, Any] | None:
"""Return configuration parameters."""
return self._config_parameters
@config_parameters.setter
- def config_parameters(self, config_parameters: Optional[Mapping[str, Any]]) -> None:
+ def config_parameters(self, config_parameters: Mapping[str, Any] | None) -> None:
"""Setter for the configuration parameters."""
self._config_parameters = config_parameters
@property
- def event_handlers(self) -> Optional[List[EventHandler]]:
+ def event_handlers(self) -> list[EventHandler] | None:
"""Return list of the event handlers."""
return self._event_handlers
@event_handlers.setter
- def event_handlers(self, event_handlers: List[EventHandler]) -> None:
+ def event_handlers(self, event_handlers: list[EventHandler]) -> None:
"""Setter for the event handlers."""
self._event_handlers = event_handlers
@@ -196,7 +195,7 @@ class ExecutionContext(Context):
self,
*,
advice_category: AdviceCategory,
- event_handlers: List[EventHandler],
+ event_handlers: list[EventHandler],
config_parameters: Mapping[str, Any],
) -> None:
"""Update context parameters."""
diff --git a/src/mlia/core/data_analysis.py b/src/mlia/core/data_analysis.py
index 6adb41e..0603425 100644
--- a/src/mlia/core/data_analysis.py
+++ b/src/mlia/core/data_analysis.py
@@ -1,10 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for data analysis."""
+from __future__ import annotations
+
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
-from typing import List
from mlia.core.common import DataItem
from mlia.core.mixins import ContextMixin
@@ -29,7 +30,7 @@ class DataAnalyzer(ABC):
"""
@abstractmethod
- def get_analyzed_data(self) -> List[DataItem]:
+ def get_analyzed_data(self) -> list[DataItem]:
"""Get analyzed data."""
@@ -59,9 +60,9 @@ class FactExtractor(ContextAwareDataAnalyzer):
def __init__(self) -> None:
"""Init fact extractor."""
- self.facts: List[Fact] = []
+ self.facts: list[Fact] = []
- def get_analyzed_data(self) -> List[DataItem]:
+ def get_analyzed_data(self) -> list[DataItem]:
"""Return list of the collected facts."""
return self.facts
diff --git a/src/mlia/core/events.py b/src/mlia/core/events.py
index 0b8461b..71c86e2 100644
--- a/src/mlia/core/events.py
+++ b/src/mlia/core/events.py
@@ -9,6 +9,8 @@ calling application.
Each component of the workflow can generate events of specific type.
Application can subscribe and react to those events.
"""
+from __future__ import annotations
+
import traceback
import uuid
from abc import ABC
@@ -19,11 +21,7 @@ from dataclasses import dataclass
from dataclasses import field
from functools import singledispatchmethod
from typing import Any
-from typing import Dict
from typing import Generator
-from typing import List
-from typing import Optional
-from typing import Tuple
from mlia.core.common import DataItem
@@ -41,7 +39,7 @@ class Event:
"""Generate unique ID for the event."""
self.event_id = str(uuid.uuid4())
- def compare_without_id(self, other: "Event") -> bool:
+ def compare_without_id(self, other: Event) -> bool:
"""Compare two events without event_id field."""
if not isinstance(other, Event) or self.__class__ != other.__class__:
return False
@@ -73,7 +71,7 @@ class ActionStartedEvent(Event):
"""
action_type: str
- params: Optional[Dict] = None
+ params: dict | None = None
@dataclass
@@ -84,7 +82,7 @@ class SubActionEvent(ChildEvent):
"""
action_type: str
- params: Optional[Dict] = None
+ params: dict | None = None
@dataclass
@@ -271,8 +269,8 @@ class EventDispatcherMetaclass(type):
def __new__(
cls,
clsname: str,
- bases: Tuple,
- namespace: Dict[str, Any],
+ bases: tuple[type, ...],
+ namespace: dict[str, Any],
event_handler_method_prefix: str = "on_",
) -> Any:
"""Create event dispatcher and link event handlers."""
@@ -321,7 +319,7 @@ class EventPublisher(ABC):
"""
def register_event_handlers(
- self, event_handlers: Optional[List[EventHandler]]
+ self, event_handlers: list[EventHandler] | None
) -> None:
"""Register event handlers.
@@ -354,7 +352,7 @@ class DefaultEventPublisher(EventPublisher):
def __init__(self) -> None:
"""Init the event publisher."""
- self.handlers: List[EventHandler] = []
+ self.handlers: list[EventHandler] = []
def register_event_handler(self, event_handler: EventHandler) -> None:
"""Register the event handler.
@@ -374,7 +372,7 @@ class DefaultEventPublisher(EventPublisher):
@contextmanager
def stage(
- publisher: EventPublisher, events: Tuple[Event, Event]
+ publisher: EventPublisher, events: tuple[Event, Event]
) -> Generator[None, None, None]:
"""Generate events before and after stage.
@@ -390,7 +388,7 @@ def stage(
@contextmanager
def action(
- publisher: EventPublisher, action_type: str, params: Optional[Dict] = None
+ publisher: EventPublisher, action_type: str, params: dict | None = None
) -> Generator[None, None, None]:
"""Generate events before and after action."""
action_started = ActionStartedEvent(action_type, params)
diff --git a/src/mlia/core/handlers.py b/src/mlia/core/handlers.py
index e576f74..a3255ae 100644
--- a/src/mlia/core/handlers.py
+++ b/src/mlia/core/handlers.py
@@ -1,13 +1,12 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Event handlers module."""
+from __future__ import annotations
+
import logging
from typing import Any
from typing import Callable
-from typing import List
-from typing import Optional
-from mlia.core._typing import PathOrFileLike
from mlia.core.advice_generation import Advice
from mlia.core.advice_generation import AdviceEvent
from mlia.core.events import ActionFinishedEvent
@@ -28,6 +27,7 @@ from mlia.core.events import ExecutionStartedEvent
from mlia.core.reporting import Report
from mlia.core.reporting import Reporter
from mlia.core.reporting import resolve_output_format
+from mlia.core.typing import PathOrFileLike
from mlia.utils.console import create_section_header
@@ -101,14 +101,14 @@ class WorkflowEventsHandler(SystemEventsHandler):
def __init__(
self,
formatter_resolver: Callable[[Any], Callable[[Any], Report]],
- output: Optional[PathOrFileLike] = None,
+ output: PathOrFileLike | None = None,
) -> None:
"""Init event handler."""
output_format = resolve_output_format(output)
self.reporter = Reporter(formatter_resolver, output_format)
self.output = output
- self.advice: List[Advice] = []
+ self.advice: list[Advice] = []
def on_execution_started(self, event: ExecutionStartedEvent) -> None:
"""Handle ExecutionStarted event."""
diff --git a/src/mlia/core/helpers.py b/src/mlia/core/helpers.py
index d10ea5d..f0c4474 100644
--- a/src/mlia/core/helpers.py
+++ b/src/mlia/core/helpers.py
@@ -2,34 +2,35 @@
# SPDX-License-Identifier: Apache-2.0
"""Module for various helper classes."""
# pylint: disable=no-self-use, unused-argument
+from __future__ import annotations
+
from typing import Any
-from typing import List
class ActionResolver:
"""Helper class for generating actions (e.g. commands with parameters)."""
- def apply_optimizations(self, **kwargs: Any) -> List[str]:
+ def apply_optimizations(self, **kwargs: Any) -> list[str]:
"""Return action details for applying optimizations."""
return []
- def supported_operators_info(self) -> List[str]:
+ def supported_operators_info(self) -> list[str]:
"""Return action details for generating supported ops report."""
return []
- def check_performance(self) -> List[str]:
+ def check_performance(self) -> list[str]:
"""Return action details for checking performance."""
return []
- def check_operator_compatibility(self) -> List[str]:
+ def check_operator_compatibility(self) -> list[str]:
"""Return action details for checking op compatibility."""
return []
- def operator_compatibility_details(self) -> List[str]:
+ def operator_compatibility_details(self) -> list[str]:
"""Return action details for getting more information about op compatibility."""
return []
- def optimization_details(self) -> List[str]:
+ def optimization_details(self) -> list[str]:
"""Return action detail for getting information about optimizations."""
return []
diff --git a/src/mlia/core/mixins.py b/src/mlia/core/mixins.py
index ee03100..5ef9d66 100644
--- a/src/mlia/core/mixins.py
+++ b/src/mlia/core/mixins.py
@@ -1,8 +1,9 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Mixins module."""
+from __future__ import annotations
+
from typing import Any
-from typing import Optional
from mlia.core.context import Context
@@ -27,8 +28,8 @@ class ParameterResolverMixin:
section: str,
name: str,
expected: bool = True,
- expected_type: Optional[type] = None,
- context: Optional[Context] = None,
+ expected_type: type | None = None,
+ context: Context | None = None,
) -> Any:
"""Get parameter value."""
ctx = context or self.context
diff --git a/src/mlia/core/performance.py b/src/mlia/core/performance.py
index 5433d5c..cb12918 100644
--- a/src/mlia/core/performance.py
+++ b/src/mlia/core/performance.py
@@ -1,30 +1,31 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for performance estimation."""
+from __future__ import annotations
+
from abc import abstractmethod
from typing import Callable
from typing import Generic
-from typing import List
from typing import TypeVar
-ModelType = TypeVar("ModelType") # pylint: disable=invalid-name
-PerfMetricsType = TypeVar("PerfMetricsType") # pylint: disable=invalid-name
+M = TypeVar("M") # model type
+P = TypeVar("P") # performance metrics
-class PerformanceEstimator(Generic[ModelType, PerfMetricsType]):
+class PerformanceEstimator(Generic[M, P]):
"""Base class for the performance estimation."""
@abstractmethod
- def estimate(self, model: ModelType) -> PerfMetricsType:
+ def estimate(self, model: M) -> P:
"""Estimate performance."""
def estimate_performance(
- original_model: ModelType,
- estimator: PerformanceEstimator[ModelType, PerfMetricsType],
- model_transformations: List[Callable[[ModelType], ModelType]],
-) -> List[PerfMetricsType]:
+ original_model: M,
+ estimator: PerformanceEstimator[M, P],
+ model_transformations: list[Callable[[M], M]],
+) -> list[P]:
"""Estimate performance impact.
This function estimates performance impact on model performance after
diff --git a/src/mlia/core/reporting.py b/src/mlia/core/reporting.py
index 58a41d3..0c8fabc 100644
--- a/src/mlia/core/reporting.py
+++ b/src/mlia/core/reporting.py
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Reporting module."""
+from __future__ import annotations
+
import csv
import json
import logging
@@ -19,19 +21,14 @@ from typing import Any
from typing import Callable
from typing import cast
from typing import Collection
-from typing import Dict
from typing import Generator
from typing import Iterable
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
import numpy as np
-from mlia.core._typing import FileLike
-from mlia.core._typing import OutputFormat
-from mlia.core._typing import PathOrFileLike
+from mlia.core.typing import FileLike
+from mlia.core.typing import OutputFormat
+from mlia.core.typing import PathOrFileLike
from mlia.utils.console import apply_style
from mlia.utils.console import produce_table
from mlia.utils.logging import LoggerWriter
@@ -48,7 +45,7 @@ class Report(ABC):
"""Convert to json serializible format."""
@abstractmethod
- def to_csv(self, **kwargs: Any) -> List[Any]:
+ def to_csv(self, **kwargs: Any) -> list[Any]:
"""Convert to csv serializible format."""
@abstractmethod
@@ -62,9 +59,9 @@ class ReportItem:
def __init__(
self,
name: str,
- alias: Optional[str] = None,
- value: Optional[Union[str, int, "Cell"]] = None,
- nested_items: Optional[List["ReportItem"]] = None,
+ alias: str | None = None,
+ value: str | int | Cell | None = None,
+ nested_items: list[ReportItem] | None = None,
) -> None:
"""Init the report item."""
self.name = name
@@ -98,9 +95,9 @@ class Format:
:param style: text style
"""
- wrap_width: Optional[int] = None
- str_fmt: Optional[Union[str, Callable[[Any], str]]] = None
- style: Optional[str] = None
+ wrap_width: int | None = None
+ str_fmt: str | Callable[[Any], str] | None = None
+ style: str | None = None
@dataclass
@@ -112,7 +109,7 @@ class Cell:
"""
value: Any
- fmt: Optional[Format] = None
+ fmt: Format | None = None
def _apply_style(self, value: str) -> str:
"""Apply style to the value."""
@@ -151,7 +148,7 @@ class CountAwareCell(Cell):
def __init__(
self,
- value: Optional[Union[int, float]],
+ value: int | float | None,
singular: str,
plural: str,
format_string: str = ",d",
@@ -159,7 +156,7 @@ class CountAwareCell(Cell):
"""Init cell instance."""
self.unit = singular if value == 1 else plural
- def format_value(val: Optional[Union[int, float]]) -> str:
+ def format_value(val: int | float | None) -> str:
"""Provide string representation for the value."""
if val is None:
return ""
@@ -183,7 +180,7 @@ class CountAwareCell(Cell):
class BytesCell(CountAwareCell):
"""Cell that represents memory size."""
- def __init__(self, value: Optional[int]) -> None:
+ def __init__(self, value: int | None) -> None:
"""Init cell instance."""
super().__init__(value, "byte", "bytes")
@@ -191,7 +188,7 @@ class BytesCell(CountAwareCell):
class CyclesCell(CountAwareCell):
"""Cell that represents cycles."""
- def __init__(self, value: Optional[Union[int, float]]) -> None:
+ def __init__(self, value: int | float | None) -> None:
"""Init cell instance."""
super().__init__(value, "cycle", "cycles", ",.0f")
@@ -199,7 +196,7 @@ class CyclesCell(CountAwareCell):
class ClockCell(CountAwareCell):
"""Cell that represents clock value."""
- def __init__(self, value: Optional[Union[int, float]]) -> None:
+ def __init__(self, value: int | float | None) -> None:
"""Init cell instance."""
super().__init__(value, "Hz", "Hz", ",.0f")
@@ -210,9 +207,9 @@ class Column:
def __init__(
self,
header: str,
- alias: Optional[str] = None,
- fmt: Optional[Format] = None,
- only_for: Optional[List[str]] = None,
+ alias: str | None = None,
+ fmt: Format | None = None,
+ only_for: list[str] | None = None,
) -> None:
"""Init column definition.
@@ -228,7 +225,7 @@ class Column:
self.fmt = fmt
self.only_for = only_for
- def supports_format(self, fmt: str) -> bool:
+ def supports_format(self, fmt: OutputFormat) -> bool:
"""Return true if column should be shown."""
return not self.only_for or fmt in self.only_for
@@ -236,20 +233,20 @@ class Column:
class NestedReport(Report):
"""Report with nested items."""
- def __init__(self, name: str, alias: str, items: List[ReportItem]) -> None:
+ def __init__(self, name: str, alias: str, items: list[ReportItem]) -> None:
"""Init nested report."""
self.name = name
self.alias = alias
self.items = items
- def to_csv(self, **kwargs: Any) -> List[Any]:
+ def to_csv(self, **kwargs: Any) -> list[Any]:
"""Convert to csv serializible format."""
result = {}
def collect_item_values(
item: ReportItem,
- _parent: Optional[ReportItem],
- _prev: Optional[ReportItem],
+ _parent: ReportItem | None,
+ _prev: ReportItem | None,
_level: int,
) -> None:
"""Collect item values into a dictionary.."""
@@ -279,13 +276,13 @@ class NestedReport(Report):
def to_json(self, **kwargs: Any) -> Any:
"""Convert to json serializible format."""
- per_parent: Dict[Optional[ReportItem], Dict] = defaultdict(dict)
+ per_parent: dict[ReportItem | None, dict] = defaultdict(dict)
result = per_parent[None]
def collect_as_dicts(
item: ReportItem,
- parent: Optional[ReportItem],
- _prev: Optional[ReportItem],
+ parent: ReportItem | None,
+ _prev: ReportItem | None,
_level: int,
) -> None:
"""Collect item values as nested dictionaries."""
@@ -313,8 +310,8 @@ class NestedReport(Report):
def convert_to_text(
item: ReportItem,
- _parent: Optional[ReportItem],
- prev: Optional[ReportItem],
+ _parent: ReportItem | None,
+ prev: ReportItem | None,
level: int,
) -> None:
"""Convert item to text representation."""
@@ -345,12 +342,12 @@ class NestedReport(Report):
def _traverse(
self,
- items: List[ReportItem],
+ items: list[ReportItem],
visit_item: Callable[
- [ReportItem, Optional[ReportItem], Optional[ReportItem], int], None
+ [ReportItem, ReportItem | None, ReportItem | None, int], None
],
level: int = 1,
- parent: Optional[ReportItem] = None,
+ parent: ReportItem | None = None,
) -> None:
"""Traverse through items."""
prev = None
@@ -369,11 +366,11 @@ class Table(Report):
def __init__(
self,
- columns: List[Column],
+ columns: list[Column],
rows: Collection,
name: str,
- alias: Optional[str] = None,
- notes: Optional[str] = None,
+ alias: str | None = None,
+ notes: str | None = None,
) -> None:
"""Init table definition.
@@ -477,7 +474,7 @@ class Table(Report):
return title + formatted_table + footer
- def to_csv(self, **kwargs: Any) -> List[Any]:
+ def to_csv(self, **kwargs: Any) -> list[Any]:
"""Convert table to csv format."""
headers = [[c.header for c in self.columns if c.supports_format("csv")]]
@@ -528,7 +525,7 @@ class CompoundReport(Report):
This class could be used for producing multiple reports at once.
"""
- def __init__(self, reports: List[Report]) -> None:
+ def __init__(self, reports: list[Report]) -> None:
"""Init compound report instance."""
self.reports = reports
@@ -538,13 +535,13 @@ class CompoundReport(Report):
Method attempts to create compound dictionary based on provided
parts.
"""
- result: Dict[str, Any] = {}
+ result: dict[str, Any] = {}
for item in self.reports:
result.update(item.to_json(**kwargs))
return result
- def to_csv(self, **kwargs: Any) -> List[Any]:
+ def to_csv(self, **kwargs: Any) -> list[Any]:
"""Convert to csv serializible format.
CSV format does support only one table. In order to be able to export
@@ -592,7 +589,7 @@ class CompoundReport(Report):
class CompoundFormatter:
"""Compound data formatter."""
- def __init__(self, formatters: List[Callable]) -> None:
+ def __init__(self, formatters: list[Callable]) -> None:
"""Init compound formatter."""
self.formatters = formatters
@@ -637,7 +634,7 @@ def produce_report(
data: Any,
formatter: Callable[[Any], Report],
fmt: OutputFormat = "plain_text",
- output: Optional[PathOrFileLike] = None,
+ output: PathOrFileLike | None = None,
**kwargs: Any,
) -> None:
"""Produce report based on provided data."""
@@ -679,8 +676,8 @@ class Reporter:
self.output_format = output_format
self.print_as_submitted = print_as_submitted
- self.data: List[Tuple[Any, Callable[[Any], Report]]] = []
- self.delayed: List[Tuple[Any, Callable[[Any], Report]]] = []
+ self.data: list[tuple[Any, Callable[[Any], Report]]] = []
+ self.delayed: list[tuple[Any, Callable[[Any], Report]]] = []
def submit(self, data_item: Any, delay_print: bool = False, **kwargs: Any) -> None:
"""Submit data for the report."""
@@ -713,7 +710,7 @@ class Reporter:
)
self.delayed = []
- def generate_report(self, output: Optional[PathOrFileLike]) -> None:
+ def generate_report(self, output: PathOrFileLike | None) -> None:
"""Generate report."""
already_printed = (
self.print_as_submitted
@@ -735,7 +732,7 @@ class Reporter:
@contextmanager
def get_reporter(
output_format: OutputFormat,
- output: Optional[PathOrFileLike],
+ output: PathOrFileLike | None,
formatter_resolver: Callable[[Any], Callable[[Any], Report]],
) -> Generator[Reporter, None, None]:
"""Get reporter and generate report."""
@@ -762,7 +759,7 @@ def _apply_format_parameters(
return wrapper
-def resolve_output_format(output: Optional[PathOrFileLike]) -> OutputFormat:
+def resolve_output_format(output: PathOrFileLike | None) -> OutputFormat:
"""Resolve output format based on the output name."""
if isinstance(output, (str, Path)):
format_from_filename = Path(output).suffix.lstrip(".")
diff --git a/src/mlia/core/_typing.py b/src/mlia/core/typing.py
index bda995c..bda995c 100644
--- a/src/mlia/core/_typing.py
+++ b/src/mlia/core/typing.py
diff --git a/src/mlia/core/workflow.py b/src/mlia/core/workflow.py
index 03f3d1c..d862a86 100644
--- a/src/mlia/core/workflow.py
+++ b/src/mlia/core/workflow.py
@@ -5,16 +5,15 @@
This module contains implementation of the workflow
executors.
"""
+from __future__ import annotations
+
import itertools
from abc import ABC
from abc import abstractmethod
from functools import wraps
from typing import Any
from typing import Callable
-from typing import List
-from typing import Optional
from typing import Sequence
-from typing import Tuple
from mlia.core.advice_generation import Advice
from mlia.core.advice_generation import AdviceEvent
@@ -57,7 +56,7 @@ STAGE_ANALYSIS = (DataAnalysisStageStartedEvent(), DataAnalysisStageFinishedEven
STAGE_ADVICE = (AdviceStageStartedEvent(), AdviceStageFinishedEvent())
-def on_stage(stage_events: Tuple[Event, Event]) -> Callable:
+def on_stage(stage_events: tuple[Event, Event]) -> Callable:
"""Mark start/finish of the stage with appropriate events."""
def wrapper(method: Callable) -> Callable:
@@ -87,7 +86,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor):
collectors: Sequence[DataCollector],
analyzers: Sequence[DataAnalyzer],
producers: Sequence[AdviceProducer],
- startup_events: Optional[Sequence[Event]] = None,
+ startup_events: Sequence[Event] | None = None,
):
"""Init default workflow executor.
@@ -130,7 +129,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor):
self.publish(event)
@on_stage(STAGE_COLLECTION)
- def collect_data(self) -> List[DataItem]:
+ def collect_data(self) -> list[DataItem]:
"""Collect data.
Run each of data collector components and return list of
@@ -148,7 +147,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor):
return collected_data
@on_stage(STAGE_ANALYSIS)
- def analyze_data(self, collected_data: List[DataItem]) -> List[DataItem]:
+ def analyze_data(self, collected_data: list[DataItem]) -> list[DataItem]:
"""Analyze data.
Pass each collected data item into each data analyzer and
@@ -168,7 +167,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor):
return analyzed_data
@on_stage(STAGE_ADVICE)
- def produce_advice(self, analyzed_data: List[DataItem]) -> None:
+ def produce_advice(self, analyzed_data: list[DataItem]) -> None:
"""Produce advice.
Pass each analyzed data item into each advice producer and