diff options
Diffstat (limited to 'src/mlia/core')
-rw-r--r-- | src/mlia/core/__init__.py | 21 | ||||
-rw-r--r-- | src/mlia/core/_typing.py | 12 | ||||
-rw-r--r-- | src/mlia/core/advice_generation.py | 106 | ||||
-rw-r--r-- | src/mlia/core/advisor.py | 21 | ||||
-rw-r--r-- | src/mlia/core/common.py | 47 | ||||
-rw-r--r-- | src/mlia/core/context.py | 218 | ||||
-rw-r--r-- | src/mlia/core/data_analysis.py | 70 | ||||
-rw-r--r-- | src/mlia/core/data_collection.py | 37 | ||||
-rw-r--r-- | src/mlia/core/errors.py | 18 | ||||
-rw-r--r-- | src/mlia/core/events.py | 455 | ||||
-rw-r--r-- | src/mlia/core/helpers.py | 38 | ||||
-rw-r--r-- | src/mlia/core/mixins.py | 54 | ||||
-rw-r--r-- | src/mlia/core/performance.py | 47 | ||||
-rw-r--r-- | src/mlia/core/reporting.py | 762 | ||||
-rw-r--r-- | src/mlia/core/workflow.py | 216 |
15 files changed, 2122 insertions, 0 deletions
diff --git a/src/mlia/core/__init__.py b/src/mlia/core/__init__.py new file mode 100644 index 0000000..49b1830 --- /dev/null +++ b/src/mlia/core/__init__.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Core module. + +Core module contains the main components that are used in the workflow of +ML Inference Advisor: + - data collectors + - data analyzers + - advice producers + - event publishers + - event handlers + +The workflow of ML Inference Advisor consists of 3 stages: + - data collection + - data analysis + - advice generation + +Data is being passed from one stage to another via workflow executor. +Results (collected data, analyzed data, advice, etc) are being published via +publish/subscribe mechanishm. +""" diff --git a/src/mlia/core/_typing.py b/src/mlia/core/_typing.py new file mode 100644 index 0000000..bda995c --- /dev/null +++ b/src/mlia/core/_typing.py @@ -0,0 +1,12 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for custom type hints.""" +from pathlib import Path +from typing import Literal +from typing import TextIO +from typing import Union + + +FileLike = TextIO +PathOrFileLike = Union[str, Path, FileLike] +OutputFormat = Literal["plain_text", "csv", "json"] diff --git a/src/mlia/core/advice_generation.py b/src/mlia/core/advice_generation.py new file mode 100644 index 0000000..76cc1f2 --- /dev/null +++ b/src/mlia/core/advice_generation.py @@ -0,0 +1,106 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+# SPDX-License-Identifier: Apache-2.0 +"""Module for advice generation.""" +from abc import ABC +from abc import abstractmethod +from dataclasses import dataclass +from functools import wraps +from typing import Any +from typing import Callable +from typing import List +from typing import Union + +from mlia.core.common import AdviceCategory +from mlia.core.common import DataItem +from mlia.core.events import SystemEvent +from mlia.core.mixins import ContextMixin + + +@dataclass +class Advice: + """Base class for the advice.""" + + messages: List[str] + + +@dataclass +class AdviceEvent(SystemEvent): + """Advice event. + + This event is published for every produced advice. + + :param advice: Advice instance + """ + + advice: Advice + + +class AdviceProducer(ABC): + """Base class for the advice producer. + + Producer has two methods for advice generation: + - produce_advice - used to generate advice based on provided + data (analyzed data item from analyze stage) + - get_advice - used for getting generated advice + + Advice producers that have predefined advice could skip + implementation of produce_advice method. + """ + + @abstractmethod + def produce_advice(self, data_item: DataItem) -> None: + """Process data item and produce advice. + + :param data_item: piece of data that could be used + for advice generation + """ + + @abstractmethod + def get_advice(self) -> Union[Advice, List[Advice]]: + """Get produced advice.""" + + +class ContextAwareAdviceProducer(AdviceProducer, ContextMixin): + """Context aware advice producer. + + This class makes easier access to the Context object. Context object could + be automatically injected during workflow configuration. + """ + + +class FactBasedAdviceProducer(ContextAwareAdviceProducer): + """Advice producer based on provided facts. + + This is an utility class that maintain list of generated Advice instances. 
+ """ + + def __init__(self) -> None: + """Init advice producer.""" + self.advice: List[Advice] = [] + + def get_advice(self) -> Union[Advice, List[Advice]]: + """Get produced advice.""" + return self.advice + + def add_advice(self, messages: List[str]) -> None: + """Add advice.""" + self.advice.append(Advice(messages)) + + +def advice_category(*categories: AdviceCategory) -> Callable: + """Filter advice generation handler by advice category.""" + + def wrapper(handler: Callable) -> Callable: + """Wrap data handler.""" + + @wraps(handler) + def check_category(self: Any, *args: Any, **kwargs: Any) -> Any: + """Check if handler can produce advice for the requested category.""" + if not self.context.any_category_enabled(*categories): + return + + handler(self, *args, **kwargs) + + return check_category + + return wrapper diff --git a/src/mlia/core/advisor.py b/src/mlia/core/advisor.py new file mode 100644 index 0000000..868d0c7 --- /dev/null +++ b/src/mlia/core/advisor.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Inference advisor module.""" +from abc import abstractmethod + +from mlia.core.common import NamedEntity +from mlia.core.context import Context +from mlia.core.workflow import WorkflowExecutor + + +class InferenceAdvisor(NamedEntity): + """Base class for inference advisors.""" + + @abstractmethod + def configure(self, context: Context) -> WorkflowExecutor: + """Configure advisor execution.""" + + def run(self, context: Context) -> None: + """Run inference advisor.""" + executor = self.configure(context) + executor.run() diff --git a/src/mlia/core/common.py b/src/mlia/core/common.py new file mode 100644 index 0000000..5fbad42 --- /dev/null +++ b/src/mlia/core/common.py @@ -0,0 +1,47 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Common module. 
+ +This module contains common interfaces/classess shared across +core module. +""" +from abc import ABC +from abc import abstractmethod +from enum import Enum +from typing import Any + +# This type is used as type alias for the items which are being passed around +# in advisor workflow. There are no restrictions on the type of the +# object. This alias used only to emphasize the nature of the input/output +# arguments. +DataItem = Any + + +class AdviceCategory(Enum): + """Advice category. + + Enumeration of advice categories supported by ML Inference Advisor. + """ + + OPERATORS = 1 + PERFORMANCE = 2 + OPTIMIZATION = 3 + ALL = 4 + + @classmethod + def from_string(cls, value: str) -> "AdviceCategory": + """Resolve enum value from string value.""" + category_names = [item.name for item in AdviceCategory] + if not value or value.upper() not in category_names: + raise Exception(f"Invalid advice category {value}") + + return AdviceCategory[value.upper()] + + +class NamedEntity(ABC): + """Entity with a name and description.""" + + @classmethod + @abstractmethod + def name(cls) -> str: + """Return name of the entity.""" diff --git a/src/mlia/core/context.py b/src/mlia/core/context.py new file mode 100644 index 0000000..8b3dd2c --- /dev/null +++ b/src/mlia/core/context.py @@ -0,0 +1,218 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Context module. + +This module contains functionality related to the Context. +Context is an object that describes advisor working environment +and requested behavior (advice categories, input configuration +parameters). 
+""" +import logging +from abc import ABC +from abc import abstractmethod +from pathlib import Path +from typing import Any +from typing import List +from typing import Mapping +from typing import Optional +from typing import Union + +from mlia.core.common import AdviceCategory +from mlia.core.events import DefaultEventPublisher +from mlia.core.events import EventHandler +from mlia.core.events import EventPublisher +from mlia.core.helpers import ActionResolver +from mlia.core.helpers import APIActionResolver + +logger = logging.getLogger(__name__) + + +class Context(ABC): + """Abstract class for the execution context.""" + + @abstractmethod + def get_model_path(self, model_filename: str) -> Path: + """Return path for the intermediate/optimized models. + + During workflow execution different parts of the advisor + require creating intermediate files for models. + + This method allows to provide paths where those models + could be saved. + + :param model_filename: filename of the model + """ + + @property + @abstractmethod + def event_publisher(self) -> EventPublisher: + """Return event publisher.""" + + @property + @abstractmethod + def event_handlers(self) -> Optional[List[EventHandler]]: + """Return list of the event_handlers.""" + + @property + @abstractmethod + def advice_category(self) -> Optional[AdviceCategory]: + """Return advice category.""" + + @property + @abstractmethod + def config_parameters(self) -> Optional[Mapping[str, Any]]: + """Return configuration parameters.""" + + @property + @abstractmethod + def action_resolver(self) -> ActionResolver: + """Return action resolver.""" + + @abstractmethod + def update( + self, + *, + advice_category: AdviceCategory, + event_handlers: List[EventHandler], + config_parameters: Mapping[str, Any], + ) -> None: + """Update context parameters.""" + + def category_enabled(self, category: AdviceCategory) -> bool: + """Check if category enabled.""" + return category == self.advice_category + + def 
any_category_enabled(self, *categories: AdviceCategory) -> bool: + """Return true if any category is enabled.""" + return self.advice_category in categories + + def register_event_handlers(self) -> None: + """Register event handlers.""" + self.event_publisher.register_event_handlers(self.event_handlers) + + +class ExecutionContext(Context): + """Execution context.""" + + def __init__( + self, + *, + advice_category: Optional[AdviceCategory] = None, + config_parameters: Optional[Mapping[str, Any]] = None, + working_dir: Optional[Union[str, Path]] = None, + event_handlers: Optional[List[EventHandler]] = None, + event_publisher: Optional[EventPublisher] = None, + verbose: bool = False, + logs_dir: str = "logs", + models_dir: str = "models", + action_resolver: Optional[ActionResolver] = None, + ) -> None: + """Init execution context. + + :param advice_category: requested advice category + :param config_parameters: dictionary like object with input parameters + :param working_dir: path to the directory that will be used as a place + to store temporary files, logs, models. If not provided then + current working directory will be used instead + :param event_handlers: optional list of event handlers + :param event_publisher: optional event publisher instance. 
If not provided + then default implementation of event publisher will be used + :param verbose: enable verbose output + :param logs_dir: name of the directory inside working directory where + log files will be stored + :param models_dir: name of the directory inside working directory where + temporary models will be stored + :param action_resolver: instance of the action resolver that could make + advice actionable + """ + self._advice_category = advice_category + self._config_parameters = config_parameters + + self._working_dir_path = Path.cwd() + if working_dir: + self._working_dir_path = Path(working_dir) + self._working_dir_path.mkdir(exist_ok=True) + + self._event_handlers = event_handlers + self._event_publisher = event_publisher or DefaultEventPublisher() + self.verbose = verbose + self.logs_dir = logs_dir + self.models_dir = models_dir + self._action_resolver = action_resolver or APIActionResolver() + + @property + def advice_category(self) -> Optional[AdviceCategory]: + """Return advice category.""" + return self._advice_category + + @advice_category.setter + def advice_category(self, advice_category: AdviceCategory) -> None: + """Setter for the advice category.""" + self._advice_category = advice_category + + @property + def config_parameters(self) -> Optional[Mapping[str, Any]]: + """Return configuration parameters.""" + return self._config_parameters + + @config_parameters.setter + def config_parameters(self, config_parameters: Optional[Mapping[str, Any]]) -> None: + """Setter for the configuration parameters.""" + self._config_parameters = config_parameters + + @property + def event_handlers(self) -> Optional[List[EventHandler]]: + """Return list of the event handlers.""" + return self._event_handlers + + @event_handlers.setter + def event_handlers(self, event_handlers: List[EventHandler]) -> None: + """Setter for the event handlers.""" + self._event_handlers = event_handlers + + @property + def event_publisher(self) -> EventPublisher: + """Return 
event publisher.""" + return self._event_publisher + + @property + def action_resolver(self) -> ActionResolver: + """Return action resolver.""" + return self._action_resolver + + def get_model_path(self, model_filename: str) -> Path: + """Return path for the model.""" + models_dir_path = self._working_dir_path / self.models_dir + models_dir_path.mkdir(exist_ok=True) + + return models_dir_path / model_filename + + @property + def logs_path(self) -> Path: + """Return path to the logs directory.""" + return self._working_dir_path / self.logs_dir + + def update( + self, + *, + advice_category: AdviceCategory, + event_handlers: List[EventHandler], + config_parameters: Mapping[str, Any], + ) -> None: + """Update context parameters.""" + self._advice_category = advice_category + self._event_handlers = event_handlers + self._config_parameters = config_parameters + + def __str__(self) -> str: + """Return string representation.""" + category = ( + "<not set>" if self.advice_category is None else self.advice_category.name + ) + + return ( + f"ExecutionContext: working_dir={self._working_dir_path}, " + f"advice_category={category}, " + f"config_parameters={self.config_parameters}, " + f"verbose={self.verbose}" + ) diff --git a/src/mlia/core/data_analysis.py b/src/mlia/core/data_analysis.py new file mode 100644 index 0000000..6adb41e --- /dev/null +++ b/src/mlia/core/data_analysis.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for data analysis.""" +from abc import ABC +from abc import abstractmethod +from dataclasses import dataclass +from typing import List + +from mlia.core.common import DataItem +from mlia.core.mixins import ContextMixin + + +class DataAnalyzer(ABC): + """Base class for the data analysis. + + Purpose of this class is to extract valuable data out of + collected data which could be used for advice generation. 
+ + This process consists of two steps: + - analyze every item of the collected data + - get analyzed data + """ + + @abstractmethod + def analyze_data(self, data_item: DataItem) -> None: + """Analyze data. + + :param data_item: item of the collected data + """ + + @abstractmethod + def get_analyzed_data(self) -> List[DataItem]: + """Get analyzed data.""" + + +class ContextAwareDataAnalyzer(DataAnalyzer, ContextMixin): + """Context aware data analyzer. + + This class makes easier access to the Context object. Context object could + be automatically injected during workflow configuration. + """ + + +@dataclass +class Fact: + """Base class for the facts. + + Fact represents some piece of knowledge about collected + data. + """ + + +class FactExtractor(ContextAwareDataAnalyzer): + """Data analyzer based on extracting facts. + + Utility class that makes fact extraction easier. + Class maintains list of the extracted facts. + """ + + def __init__(self) -> None: + """Init fact extractor.""" + self.facts: List[Fact] = [] + + def get_analyzed_data(self) -> List[DataItem]: + """Return list of the collected facts.""" + return self.facts + + def add_fact(self, fact: Fact) -> None: + """Add fact.""" + self.facts.append(fact) diff --git a/src/mlia/core/data_collection.py b/src/mlia/core/data_collection.py new file mode 100644 index 0000000..43b6d1c --- /dev/null +++ b/src/mlia/core/data_collection.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for data collection. + +This module contains base classes for the first stage +of the ML Inference Advisor workflow - data collection. +""" +from abc import abstractmethod + +from mlia.core.common import DataItem +from mlia.core.common import NamedEntity +from mlia.core.mixins import ContextMixin +from mlia.core.mixins import ParameterResolverMixin + + +class DataCollector(NamedEntity): + """Base class for the data collection. 
+ + Data collection is the first step in the process of the advice + generation. + + Different implementations of this class can provide various + information about model or device. This information is being used + at later stages. + """ + + @abstractmethod + def collect_data(self) -> DataItem: + """Collect data.""" + + +class ContextAwareDataCollector(DataCollector, ContextMixin, ParameterResolverMixin): + """Context aware data collector. + + This class makes easier access to the Context object. Context object could + be automatically injected during workflow configuration. + """ diff --git a/src/mlia/core/errors.py b/src/mlia/core/errors.py new file mode 100644 index 0000000..7d6beb1 --- /dev/null +++ b/src/mlia/core/errors.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""MLIA exceptions module.""" + + +class ConfigurationError(Exception): + """Configuration error.""" + + +class FunctionalityNotSupportedError(Exception): + """Functionality is not supported error.""" + + def __init__(self, reason: str, description: str) -> None: + """Init exception.""" + super().__init__(f"{reason}: {description}") + + self.reason = reason + self.description = description diff --git a/src/mlia/core/events.py b/src/mlia/core/events.py new file mode 100644 index 0000000..10aec86 --- /dev/null +++ b/src/mlia/core/events.py @@ -0,0 +1,455 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for the events and related functionality. + +This module represents one of the main component of the workflow - +events publishing and provides a way for delivering results to the +calling application. + +Each component of the workflow can generate events of specific type. +Application can subscribe and react to those events. 
+""" +import traceback +import uuid +from abc import ABC +from abc import abstractmethod +from contextlib import contextmanager +from dataclasses import asdict +from dataclasses import dataclass +from dataclasses import field +from functools import singledispatchmethod +from typing import Any +from typing import Dict +from typing import Generator +from typing import List +from typing import Optional +from typing import Tuple + +from mlia.core.common import DataItem + + +@dataclass +class Event: + """Base class for the events. + + This class is used as a root node of the events class hierarchy. + """ + + event_id: str = field(init=False) + + def __post_init__(self) -> None: + """Generate unique ID for the event.""" + self.event_id = str(uuid.uuid4()) + + def compare_without_id(self, other: "Event") -> bool: + """Compare two events without event_id field.""" + if not isinstance(other, Event) or self.__class__ != other.__class__: + return False + + self_as_dict = asdict(self) + self_as_dict.pop("event_id") + + other_as_dict = asdict(other) + other_as_dict.pop("event_id") + + return self_as_dict == other_as_dict + + +@dataclass +class ChildEvent(Event): + """Child event. + + This class could be used to link event with the parent event. + """ + + parent_event_id: str + + +@dataclass +class ActionStartedEvent(Event): + """Action started event. + + This event is published when some action has been started. + """ + + action_type: str + params: Optional[Dict] = None + + +@dataclass +class SubActionEvent(ChildEvent): + """SubAction event. + + This event could be used to represent some action during parent action. + """ + + action_type: str + params: Optional[Dict] = None + + +@dataclass +class ActionFinishedEvent(ChildEvent): + """Action finished event. + + This event is published when some action has been finished. + """ + + +@dataclass +class SystemEvent(Event): + """System event. + + System event class represents events that published by components + of the core module. 
Most common example is an workflow executor + that publishes number of system events for starting/completion + of different stages/workflow. + + Events that published by components outside of core module should not + use this class as base class. + """ + + +@dataclass +class ExecutionStartedEvent(SystemEvent): + """Execution started event. + + This event is published when workflow execution started. + """ + + +@dataclass +class ExecutionFinishedEvent(SystemEvent): + """Execution finished event. + + This event is published when workflow execution finished. + """ + + +@dataclass +class ExecutionFailedEvent(SystemEvent): + """Execution failed event.""" + + err: Exception + + +@dataclass +class DataCollectionStageStartedEvent(SystemEvent): + """Data collection stage started. + + This event is published when data collection stage started. + """ + + +@dataclass +class DataCollectorSkippedEvent(SystemEvent): + """Data collector skipped event. + + This event is published when particular data collector can + not provide data for the provided parameters. + """ + + data_collector: str + reason: str + + +@dataclass +class DataCollectionStageFinishedEvent(SystemEvent): + """Data collection stage finished. + + This event is published when data collection stage finished. + """ + + +@dataclass +class DataAnalysisStageStartedEvent(SystemEvent): + """Data analysis stage started. + + This event is published when data analysis stage started. + """ + + +@dataclass +class DataAnalysisStageFinishedEvent(SystemEvent): + """Data analysis stage finished. + + This event is published when data analysis stage finished. + """ + + +@dataclass +class AdviceStageStartedEvent(SystemEvent): + """Advace producing stage started. + + This event is published when advice generation stage started. + """ + + +@dataclass +class AdviceStageFinishedEvent(SystemEvent): + """Advace producing stage finished. + + This event is published when advice generation stage finished. 
+ """ + + +@dataclass +class CollectedDataEvent(SystemEvent): + """Collected data event. + + This event is published for every collected data item. + + :param data_item: collected data item + """ + + data_item: DataItem + + +@dataclass +class AnalyzedDataEvent(SystemEvent): + """Analyzed data event. + + This event is published for every analyzed data item. + + :param data_item: analyzed data item + """ + + data_item: DataItem + + +class EventHandler: + """Base class for the event handlers. + + Each event handler should derive from this base class. + """ + + def handle_event(self, event: Event) -> None: + """Handle the event. + + By default all published events are being passed to each + registered event handler. It is handler's responsibility + to filter events that it interested in. + """ + + +class DebugEventHandler(EventHandler): + """Event handler for debugging purposes. + + This handler could print every published event to the + standard output. + """ + + def __init__(self, with_stacktrace: bool = False) -> None: + """Init event handler. + + :param with_stacktrace: enable printing stacktrace of the + place where event publishing occurred. + """ + self.with_stacktrace = with_stacktrace + + def handle_event(self, event: Event) -> None: + """Handle event.""" + print(f"Got event {event}") + + if self.with_stacktrace: + traceback.print_stack() + + +class EventDispatcherMetaclass(type): + """Metaclass for event dispatching. + + It could be tedious to check type of the published event + inside event handler. Instead the following convention could be + established: if method name of the class starts with some + prefix then it is considered to be event handler of particular + type. + + This metaclass goes through the list of class methods and + links all methods with the prefix "on_" to the common dispatcher + method. 
+ """ + + def __new__( + cls, + clsname: str, + bases: Tuple, + namespace: Dict[str, Any], + event_handler_method_prefix: str = "on_", + ) -> Any: + """Create event dispatcher and link event handlers.""" + new_class = super().__new__(cls, clsname, bases, namespace) + + @singledispatchmethod + def dispatcher(_self: Any, _event: Event) -> Any: + """Event dispatcher.""" + + # get all class methods which starts with particular prefix + event_handler_methods = ( + (item_name, item) + for item_name in dir(new_class) + if callable((item := getattr(new_class, item_name))) + and item_name.startswith(event_handler_method_prefix) + ) + + # link all collected event handlers to one dispatcher method + for method_name, method in event_handler_methods: + event_handler = dispatcher.register(method) + setattr(new_class, method_name, event_handler) + + # override default handle_event method, replace it with the + # dispatcher + setattr(new_class, "handle_event", dispatcher) + + return new_class + + +class EventDispatcher(EventHandler, metaclass=EventDispatcherMetaclass): + """Event dispatcher.""" + + +class EventPublisher(ABC): + """Base class for the event publisher. + + Event publisher is a intermidiate component between event emitter + and event consumer. + """ + + @abstractmethod + def register_event_handler(self, event_handler: EventHandler) -> None: + """Register event handler. + + :param event_handler: instance of the event handler + """ + + def register_event_handlers( + self, event_handlers: Optional[List[EventHandler]] + ) -> None: + """Register event handlers. + + Can be used for batch registration of the event handlers: + + :param event_handlers: list of the event handler instances + """ + if not event_handlers: + return + + for handler in event_handlers: + self.register_event_handler(handler) + + @abstractmethod + def publish_event(self, event: Event) -> None: + """Publish the event. + + Deliver the event to the all registered event handlers. 
+ + :param event: event instance + """ + + +class DefaultEventPublisher(EventPublisher): + """Default event publishing implementation. + + Simple implementation that maintains list of the registered event + handlers. + """ + + def __init__(self) -> None: + """Init the event publisher.""" + self.handlers: List[EventHandler] = [] + + def register_event_handler(self, event_handler: EventHandler) -> None: + """Register the event handler. + + :param event_handler: instance of the event handler + """ + self.handlers.append(event_handler) + + def publish_event(self, event: Event) -> None: + """Publish the event. + + Publisher does not catch exceptions that could be raised by event handlers. + """ + for handler in self.handlers: + handler.handle_event(event) + + +@contextmanager +def stage( + publisher: EventPublisher, events: Tuple[Event, Event] +) -> Generator[None, None, None]: + """Generate events before and after stage. + + This context manager could be used to mark start/finish + execution of a particular logical part of the workflow. 
+ """ + started, finished = events + + publisher.publish_event(started) + yield + publisher.publish_event(finished) + + +@contextmanager +def action( + publisher: EventPublisher, action_type: str, params: Optional[Dict] = None +) -> Generator[None, None, None]: + """Generate events before and after action.""" + action_started = ActionStartedEvent(action_type, params) + action_finished = ActionFinishedEvent(action_started.event_id) + + publisher.publish_event(action_started) + yield + publisher.publish_event(action_finished) + + +class SystemEventsHandler(EventDispatcher): + """System events handler.""" + + def on_execution_started(self, event: ExecutionStartedEvent) -> None: + """Handle ExecutionStarted event.""" + + def on_execution_finished(self, event: ExecutionFinishedEvent) -> None: + """Handle ExecutionFinished event.""" + + def on_execution_failed(self, event: ExecutionFailedEvent) -> None: + """Handle ExecutionFailed event.""" + + def on_data_collection_stage_started( + self, event: DataCollectionStageStartedEvent + ) -> None: + """Handle DataCollectionStageStarted event.""" + + def on_data_collection_stage_finished( + self, event: DataCollectionStageFinishedEvent + ) -> None: + """Handle DataCollectionStageFinished event.""" + + def on_data_collector_skipped(self, event: DataCollectorSkippedEvent) -> None: + """Handle DataCollectorSkipped event.""" + + def on_data_analysis_stage_started( + self, event: DataAnalysisStageStartedEvent + ) -> None: + """Handle DataAnalysisStageStartedEvent event.""" + + def on_data_analysis_stage_finished( + self, event: DataAnalysisStageFinishedEvent + ) -> None: + """Handle DataAnalysisStageFinishedEvent event.""" + + def on_advice_stage_started(self, event: AdviceStageStartedEvent) -> None: + """Handle AdviceStageStarted event.""" + + def on_advice_stage_finished(self, event: AdviceStageFinishedEvent) -> None: + """Handle AdviceStageFinished event.""" + + def on_collected_data(self, event: CollectedDataEvent) -> None: + 
"""Handle CollectedData event.""" + + def on_analyzed_data(self, event: AnalyzedDataEvent) -> None: + """Handle AnalyzedData event.""" + + def on_action_started(self, event: ActionStartedEvent) -> None: + """Handle ActionStarted event.""" + + def on_action_finished(self, event: ActionFinishedEvent) -> None: + """Handle ActionFinished event.""" diff --git a/src/mlia/core/helpers.py b/src/mlia/core/helpers.py new file mode 100644 index 0000000..d10ea5d --- /dev/null +++ b/src/mlia/core/helpers.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for various helper classes.""" +# pylint: disable=no-self-use, unused-argument +from typing import Any +from typing import List + + +class ActionResolver: + """Helper class for generating actions (e.g. commands with parameters).""" + + def apply_optimizations(self, **kwargs: Any) -> List[str]: + """Return action details for applying optimizations.""" + return [] + + def supported_operators_info(self) -> List[str]: + """Return action details for generating supported ops report.""" + return [] + + def check_performance(self) -> List[str]: + """Return action details for checking performance.""" + return [] + + def check_operator_compatibility(self) -> List[str]: + """Return action details for checking op compatibility.""" + return [] + + def operator_compatibility_details(self) -> List[str]: + """Return action details for getting more information about op compatibility.""" + return [] + + def optimization_details(self) -> List[str]: + """Return action detail for getting information about optimizations.""" + return [] + + +class APIActionResolver(ActionResolver): + """Helper class for the actions performed through API.""" diff --git a/src/mlia/core/mixins.py b/src/mlia/core/mixins.py new file mode 100644 index 0000000..ee03100 --- /dev/null +++ b/src/mlia/core/mixins.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm 
Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Mixins module.""" +from typing import Any +from typing import Optional + +from mlia.core.context import Context + + +class ContextMixin: + """Mixin for injecting context object.""" + + context: Context + + def set_context(self, context: Context) -> None: + """Context setter.""" + self.context = context + + +class ParameterResolverMixin: + """Mixin for parameter resolving.""" + + context: Context + + def get_parameter( + self, + section: str, + name: str, + expected: bool = True, + expected_type: Optional[type] = None, + context: Optional[Context] = None, + ) -> Any: + """Get parameter value.""" + ctx = context or self.context + + if ctx.config_parameters is None: + raise Exception("Configuration parameters are not set") + + section_params = ctx.config_parameters.get(section) + if section_params is None or not isinstance(section_params, dict): + raise Exception( + f"Parameter section {section} has wrong format, " + "expected to be a dictionary" + ) + + value = section_params.get(name) + + if not value and expected: + raise Exception(f"Parameter {name} is not set") + + if value and expected_type is not None and not isinstance(value, expected_type): + raise Exception(f"Parameter {name} expected to have type {expected_type}") + + return value diff --git a/src/mlia/core/performance.py b/src/mlia/core/performance.py new file mode 100644 index 0000000..5433d5c --- /dev/null +++ b/src/mlia/core/performance.py @@ -0,0 +1,47 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
# SPDX-License-Identifier: Apache-2.0
"""Module for performance estimation."""
from abc import ABC
from abc import abstractmethod
from typing import Callable
from typing import Generic
from typing import List
from typing import TypeVar


ModelType = TypeVar("ModelType")  # pylint: disable=invalid-name
PerfMetricsType = TypeVar("PerfMetricsType")  # pylint: disable=invalid-name


class PerformanceEstimator(ABC, Generic[ModelType, PerfMetricsType]):
    """Base class for the performance estimation.

    Inherits ABC so that @abstractmethod is actually enforced: without the
    ABC base the class was instantiable despite the abstract method.
    """

    @abstractmethod
    def estimate(self, model: ModelType) -> PerfMetricsType:
        """Estimate performance."""


def estimate_performance(
    original_model: ModelType,
    estimator: PerformanceEstimator[ModelType, PerfMetricsType],
    model_transformations: List[Callable[[ModelType], ModelType]],
) -> List[PerfMetricsType]:
    """Estimate performance impact.

    This function estimates performance impact on model performance after
    applying provided transformations/optimizations.

    :param original_model: object that represents a model, could be
           instance of the model or path to the model. This depends on
           provided performance estimator.
    :param estimator: performance estimator
    :param model_transformations: list of the callables each of those
           returns object that represents optimized model
    :return: metrics for the original model followed by metrics for each
             transformed model, in transformation order
    """
    original_metrics = estimator.estimate(original_model)

    optimized_metrics = [
        estimator.estimate(transform(original_model))
        for transform in model_transformations
    ]

    return [original_metrics, *optimized_metrics]
# SPDX-License-Identifier: Apache-2.0
"""Reporting module."""
import csv
import json
import logging
from abc import ABC
from abc import abstractmethod
from collections import defaultdict
from contextlib import contextmanager
from contextlib import ExitStack
from dataclasses import dataclass
from functools import partial
from io import TextIOWrapper
from pathlib import Path
from textwrap import fill
from textwrap import indent
from typing import Any
from typing import Callable
from typing import cast
from typing import Collection
from typing import Dict
from typing import Generator
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import numpy as np

from mlia.core._typing import FileLike
from mlia.core._typing import OutputFormat
from mlia.core._typing import PathOrFileLike
from mlia.utils.console import apply_style
from mlia.utils.console import produce_table
from mlia.utils.logging import LoggerWriter
from mlia.utils.types import is_list_of

logger = logging.getLogger(__name__)


class Report(ABC):
    """Abstract class for the report."""

    @abstractmethod
    def to_json(self, **kwargs: Any) -> Any:
        """Convert to json serializable format."""

    @abstractmethod
    def to_csv(self, **kwargs: Any) -> List[Any]:
        """Convert to csv serializable format."""

    @abstractmethod
    def to_plain_text(self, **kwargs: Any) -> str:
        """Convert to human readable format."""


class ReportItem:
    """Item of the report."""

    def __init__(
        self,
        name: str,
        alias: Optional[str] = None,
        value: Optional[Union[str, int, "Cell"]] = None,
        nested_items: Optional[List["ReportItem"]] = None,
    ) -> None:
        """Init the report item."""
        self.name = name
        self.alias = alias
        self.value = value
        self.nested_items = nested_items or []

    @property
    def compound(self) -> bool:
        """Return true if item has nested items."""
        # nested_items is always a list (see __init__), so a plain
        # truthiness check is sufficient.
        return bool(self.nested_items)

    @property
    def raw_value(self) -> Any:
        """Get actual item value."""
        val = self.value
        if isinstance(val, Cell):
            return val.value

        return val


@dataclass
class Format:
    """Column or cell format.

    Format could be applied either to a column or an individual cell.

    :param wrap_width: width of the wrapped text value
    :param str_fmt: string format to be applied to the value
    :param style: text style
    """

    wrap_width: Optional[int] = None
    str_fmt: Optional[Union[str, Callable[[Any], str]]] = None
    style: Optional[str] = None


@dataclass
class Cell:
    """Cell definition.

    This is a wrapper class for a particular value in the table. Could be
    used for applying specific format to this value.
    """

    value: Any
    fmt: Optional[Format] = None

    def _apply_style(self, value: str) -> str:
        """Apply style to the value."""
        if self.fmt and self.fmt.style:
            value = apply_style(value, self.fmt.style)

        return value

    def _get_value(self) -> str:
        """Return cell value."""
        if self.fmt:
            if isinstance(self.fmt.str_fmt, str):
                return "{:{fmt}}".format(self.value, fmt=self.fmt.str_fmt)

            if callable(self.fmt.str_fmt):
                return self.fmt.str_fmt(self.value)

        return str(self.value)

    def __str__(self) -> str:
        """Return string representation."""
        val = self._get_value()
        return self._apply_style(val)

    def to_csv(self) -> Any:
        """Cell definition for csv."""
        return self.value

    def to_json(self) -> Any:
        """Cell definition for json."""
        return self.value


class CountAwareCell(Cell):
    """Count aware cell."""

    def __init__(
        self,
        value: Optional[Union[int, float]],
        singular: str,
        plural: str,
        format_string: str = ",d",
    ):
        """Init cell instance."""
        self.unit = singular if value == 1 else plural

        def format_value(val: Optional[Union[int, float]]) -> str:
            """Provide string representation for the value."""
            if val is None:
                return ""

            if val == 1:
                return f"1 {singular}"

            return f"{val:{format_string}} {plural}"

        super().__init__(value, Format(str_fmt=format_value))

    def to_csv(self) -> Any:
        """Cell definition for csv."""
        return {"value": self.value, "unit": self.unit}

    def to_json(self) -> Any:
        """Cell definition for json."""
        return {"value": self.value, "unit": self.unit}


class BytesCell(CountAwareCell):
    """Cell that represents memory size."""

    def __init__(self, value: Optional[int]) -> None:
        """Init cell instance."""
        super().__init__(value, "byte", "bytes")


class CyclesCell(CountAwareCell):
    """Cell that represents cycles."""

    def __init__(self, value: Optional[Union[int, float]]) -> None:
        """Init cell instance."""
        super().__init__(value, "cycle", "cycles", ",.0f")


class ClockCell(CountAwareCell):
    """Cell that represents clock value."""

    def __init__(self, value: Optional[Union[int, float]]) -> None:
        """Init cell instance."""
        super().__init__(value, "Hz", "Hz", ",.0f")


class Column:
    """Column definition."""

    def __init__(
        self,
        header: str,
        alias: Optional[str] = None,
        fmt: Optional[Format] = None,
        only_for: Optional[List[str]] = None,
    ) -> None:
        """Init column definition.

        :param header: column's header
        :param alias: column's alias, could be used as column's name
        :param fmt: format that will be applied for all column's values
        :param only_for: list of the formats where this column should be
               represented. May be used to differentiate data representation
               in different formats
        """
        self.header = header
        self.alias = alias
        self.fmt = fmt
        self.only_for = only_for

    def supports_format(self, fmt: str) -> bool:
        """Return true if column should be shown."""
        return not self.only_for or fmt in self.only_for


class NestedReport(Report):
    """Report with nested items."""

    def __init__(self, name: str, alias: str, items: List[ReportItem]) -> None:
        """Init nested report."""
        self.name = name
        self.alias = alias
        self.items = items

    def to_csv(self, **kwargs: Any) -> List[Any]:
        """Convert to csv serializable format."""
        result = {}

        def collect_item_values(
            item: ReportItem,
            _parent: Optional[ReportItem],
            _prev: Optional[ReportItem],
            _level: int,
        ) -> None:
            """Collect item values into a dictionary."""
            if item.value is None:
                return

            if not isinstance(item.value, Cell):
                result[item.alias] = item.raw_value
                return

            csv_value = item.value.to_csv()
            if isinstance(csv_value, dict):
                csv_value = {
                    f"{item.alias}_{key}": value for key, value in csv_value.items()
                }
            else:
                csv_value = {item.alias: csv_value}

            result.update(csv_value)

        self._traverse(self.items, collect_item_values)

        # make list out of the result dictionary
        # first element - keys of the dictionary as headers
        # second element - list of the dictionary values
        return list(zip(*result.items()))

    def to_json(self, **kwargs: Any) -> Any:
        """Convert to json serializable format."""
        per_parent: Dict[Optional[ReportItem], Dict] = defaultdict(dict)
        result = per_parent[None]

        def collect_as_dicts(
            item: ReportItem,
            parent: Optional[ReportItem],
            _prev: Optional[ReportItem],
            _level: int,
        ) -> None:
            """Collect item values as nested dictionaries."""
            parent_dict = per_parent[parent]

            if item.compound:
                item_dict = per_parent[item]
                parent_dict[item.alias] = item_dict
            else:
                json_value = (
                    item.value.to_json()
                    if isinstance(item.value, Cell)
                    else item.raw_value
                )
                parent_dict[item.alias] = json_value

        self._traverse(self.items, collect_as_dicts)

        return {self.alias: result}

    def to_plain_text(self, **kwargs: Any) -> str:
        """Convert to human readable format."""
        header = f"{self.name}:\n"
        processed_items = []

        def convert_to_text(
            item: ReportItem,
            _parent: Optional[ReportItem],
            prev: Optional[ReportItem],
            level: int,
        ) -> None:
            """Convert item to text representation."""
            # blank line separates compound items from their siblings
            if level >= 1 and prev is not None and (item.compound or prev.compound):
                processed_items.append("")

            val = self._item_value(item, level)
            processed_items.append(val)

        self._traverse(self.items, convert_to_text)
        body = "\n".join(processed_items)

        return header + body

    @staticmethod
    def _item_value(
        item: ReportItem, level: int, tab_size: int = 2, column_width: int = 35
    ) -> str:
        """Get report item value."""
        shift = " " * tab_size * level
        if item.value is None:
            return f"{shift}{item.name}:"

        col1 = f"{shift}{item.name}".ljust(column_width)
        col2 = f"{item.value}".rjust(column_width)

        return col1 + col2

    def _traverse(
        self,
        items: List[ReportItem],
        visit_item: Callable[
            [ReportItem, Optional[ReportItem], Optional[ReportItem], int], None
        ],
        level: int = 1,
        parent: Optional[ReportItem] = None,
    ) -> None:
        """Traverse through items."""
        prev = None
        for item in items:
            visit_item(item, parent, prev, level)

            self._traverse(item.nested_items, visit_item, level + 1, item)
            prev = item


class Table(Report):
    """Table definition.

    This class could be used for representing tabular data.
    """

    def __init__(
        self,
        columns: List[Column],
        rows: Collection,
        name: str,
        alias: Optional[str] = None,
        notes: Optional[str] = None,
    ) -> None:
        """Init table definition.

        :param columns: list of the table's columns
        :param rows: list of the table's rows
        :param name: name of the table
        :param alias: alias for the table
        :param notes: optional notes appended after the table in plain text
        """
        self.columns = columns
        self.rows = rows
        self.name = name
        self.alias = alias
        self.notes = notes

    def to_json(self, **kwargs: Any) -> Iterable:
        """Convert table to dict object."""

        def item_to_json(item: Any) -> Any:
            value = item
            if isinstance(item, Cell):
                value = item.value

            if isinstance(value, Table):
                return value.to_json()

            return value

        json_data = [
            {
                col.alias or col.header: item_to_json(item)
                for (item, col) in zip(row, self.columns)
                if col.supports_format("json")
            }
            for row in self.rows
        ]

        if not self.alias:
            return json_data

        return {self.alias: json_data}

    def to_plain_text(self, **kwargs: Any) -> str:
        """Produce report in human readable format."""
        nested = kwargs.get("nested", False)
        show_headers = kwargs.get("show_headers", True)
        show_title = kwargs.get("show_title", True)
        table_style = kwargs.get("table_style", "default")
        space = kwargs.get("space", False)

        headers = (
            [] if (nested or not show_headers) else [c.header for c in self.columns]
        )

        def item_to_plain_text(item: Any, col: Column) -> str:
            """Convert item to text."""
            if isinstance(item, Table):
                return item.to_plain_text(nested=True, **kwargs)

            if is_list_of(item, str):
                as_text = "\n".join(item)
            else:
                as_text = str(item)

            if col.fmt and col.fmt.wrap_width:
                as_text = fill(as_text, col.fmt.wrap_width)

            return as_text

        title = ""
        if show_title and not nested:
            title = f"{self.name}:\n"

        if space in (True, "top"):
            title = "\n" + title

        footer = ""
        if space in (True, "bottom"):
            footer = "\n"
        if self.notes:
            footer = "\n" + self.notes

        formatted_rows = (
            (
                item_to_plain_text(item, col)
                for item, col in zip(row, self.columns)
                if col.supports_format("plain_text")
            )
            for row in self.rows
        )

        if space == "between":
            formatted_table = "\n\n".join(
                produce_table([row], table_style=table_style) for row in formatted_rows
            )
        else:
            formatted_table = produce_table(
                formatted_rows,
                headers=headers,
                table_style="nested" if nested else table_style,
            )

        return title + formatted_table + footer

    def to_csv(self, **kwargs: Any) -> List[Any]:
        """Convert table to csv format."""
        headers = [[c.header for c in self.columns if c.supports_format("csv")]]

        def item_data(item: Any) -> Any:
            if isinstance(item, Cell):
                return item.value

            if isinstance(item, Table):
                # nested tables are flattened into one ';'-separated cell
                return ";".join(
                    str(item_data(cell)) for row in item.rows for cell in row
                )

            return item

        rows = [
            [
                item_data(item)
                for (item, col) in zip(row, self.columns)
                if col.supports_format("csv")
            ]
            for row in self.rows
        ]

        return headers + rows


class SingleRow(Table):
    """Table with a single row."""

    def to_plain_text(self, **kwargs: Any) -> str:
        """Produce report in human readable format."""
        if len(self.rows) != 1:
            raise Exception("Table should have only one row")

        items = "\n".join(
            column.header.ljust(35) + str(item).rjust(25)
            for row in self.rows
            for item, column in zip(row, self.columns)
            if column.supports_format("plain_text")
        )

        return "\n".join([f"{self.name}:", indent(items, " ")])


class CompoundReport(Report):
    """Compound report.

    This class could be used for producing multiple reports at once.
    """

    def __init__(self, reports: List[Report]) -> None:
        """Init compound report instance."""
        self.reports = reports

    def to_json(self, **kwargs: Any) -> Any:
        """Convert to json serializable format.

        Method attempts to create compound dictionary based on provided
        parts.
        """
        result: Dict[str, Any] = {}
        for item in self.reports:
            result.update(item.to_json(**kwargs))

        return result

    def to_csv(self, **kwargs: Any) -> List[Any]:
        """Convert to csv serializable format.

        CSV format does support only one table. In order to be able to
        export multiple tables they should be merged before that. This
        method tries the following:

        - if all tables have the same length then just concatenate them
        - if one table has many rows and other just one (two with headers),
          then for each row in table with many rows duplicate values from
          other tables
        """
        csv_data = [item.to_csv() for item in self.reports]
        lengths = [len(csv_item_data) for csv_item_data in csv_data]

        same_length = len(set(lengths)) == 1
        if same_length:
            # all lists are of the same length, merge them into one
            return [[cell for item in row for cell in item] for row in zip(*csv_data)]

        main_obj_indexes = [i for i, item in enumerate(csv_data) if len(item) > 2]
        one_main_obj = len(main_obj_indexes) == 1

        reference_obj_indexes = [i for i, item in enumerate(csv_data) if len(item) == 2]
        other_only_ref_objs = len(reference_obj_indexes) == len(csv_data) - 1

        if one_main_obj and other_only_ref_objs:
            main_obj = csv_data[main_obj_indexes[0]]
            return [
                item
                + [
                    ref_item
                    for ref_table_index in reference_obj_indexes
                    for ref_item in csv_data[ref_table_index][0 if i == 0 else 1]
                ]
                for i, item in enumerate(main_obj)
            ]

        # write tables one after another if there is no other options
        return [row for item in csv_data for row in item]

    def to_plain_text(self, **kwargs: Any) -> str:
        """Convert to human readable format."""
        return "\n".join(item.to_plain_text(**kwargs) for item in self.reports)


class CompoundFormatter:
    """Compound data formatter."""

    def __init__(self, formatters: List[Callable]) -> None:
        """Init compound formatter."""
        self.formatters = formatters

    def __call__(self, data: Any) -> Report:
        """Produce report."""
        reports = [formatter(item) for item, formatter in zip(data, self.formatters)]
        return CompoundReport(reports)


class CustomJSONEncoder(json.JSONEncoder):
    """Custom JSON encoder."""

    def default(self, o: Any) -> Any:
        """Support numpy types."""
        if isinstance(o, np.integer):
            return int(o)

        if isinstance(o, np.floating):
            return float(o)

        return json.JSONEncoder.default(self, o)


def json_reporter(report: Report, output: FileLike, **kwargs: Any) -> None:
    """Produce report in json format."""
    json_str = json.dumps(report.to_json(**kwargs), indent=4, cls=CustomJSONEncoder)
    print(json_str, file=output)


def text_reporter(report: Report, output: FileLike, **kwargs: Any) -> None:
    """Produce report in text format."""
    print(report.to_plain_text(**kwargs), file=output)


def csv_reporter(report: Report, output: FileLike, **kwargs: Any) -> None:
    """Produce report in csv format."""
    csv_writer = csv.writer(output)
    csv_writer.writerows(report.to_csv(**kwargs))


def produce_report(
    data: Any,
    formatter: Callable[[Any], Report],
    fmt: OutputFormat = "plain_text",
    output: Optional[PathOrFileLike] = None,
    **kwargs: Any,
) -> None:
    """Produce report based on provided data."""
    # check if provided format value is supported
    formats = {"json": json_reporter, "plain_text": text_reporter, "csv": csv_reporter}
    if fmt not in formats:
        raise Exception(f"Unknown format {fmt}")

    if output is None:
        output = cast(TextIOWrapper, LoggerWriter(logger, logging.INFO))

    with ExitStack() as exit_stack:
        if isinstance(output, (str, Path)):
            # open file and add it to the ExitStack context manager
            # in that case it will be automatically closed.
            # csv.writer requires the file to be opened with newline=""
            # to avoid spurious blank rows on platforms with \r\n endings
            stream = exit_stack.enter_context(
                open(
                    output,
                    "w",
                    encoding="utf-8",
                    newline="" if fmt == "csv" else None,
                )
            )
        else:
            stream = cast(TextIOWrapper, output)

        # convert data into serializable form
        formatted_data = formatter(data)
        # find handler for the format
        format_handler = formats[fmt]
        # produce report in requested format
        format_handler(formatted_data, stream, **kwargs)


class Reporter:
    """Reporter class."""

    def __init__(
        self,
        formatter_resolver: Callable[[Any], Callable[[Any], Report]],
        output_format: OutputFormat = "plain_text",
        print_as_submitted: bool = True,
    ) -> None:
        """Init reporter instance."""
        self.formatter_resolver = formatter_resolver
        self.output_format = output_format
        self.print_as_submitted = print_as_submitted

        self.data: List[Tuple[Any, Callable[[Any], Report]]] = []
        self.delayed: List[Tuple[Any, Callable[[Any], Report]]] = []

    def submit(self, data_item: Any, delay_print: bool = False, **kwargs: Any) -> None:
        """Submit data for the report."""
        if self.print_as_submitted and not delay_print:
            produce_report(
                data_item,
                self.formatter_resolver(data_item),
                fmt="plain_text",
                **kwargs,
            )

        formatter = _apply_format_parameters(
            self.formatter_resolver(data_item), self.output_format, **kwargs
        )
        self.data.append((data_item, formatter))

        if delay_print:
            self.delayed.append((data_item, formatter))

    def print_delayed(self) -> None:
        """Print delayed reports."""
        if not self.delayed:
            return

        data, formatters = zip(*self.delayed)
        produce_report(
            data,
            formatter=CompoundFormatter(formatters),
            fmt="plain_text",
        )
        self.delayed = []

    def generate_report(self, output: Optional[PathOrFileLike]) -> None:
        """Generate report."""
        already_printed = (
            self.print_as_submitted
            and self.output_format == "plain_text"
            and output is None
        )
        if not self.data or already_printed:
            return

        data, formatters = zip(*self.data)
        produce_report(
            data,
            formatter=CompoundFormatter(formatters),
            fmt=self.output_format,
            output=output,
        )


@contextmanager
def get_reporter(
    output_format: OutputFormat,
    output: Optional[PathOrFileLike],
    formatter_resolver: Callable[[Any], Callable[[Any], Report]],
) -> Generator[Reporter, None, None]:
    """Get reporter and generate report."""
    reporter = Reporter(formatter_resolver, output_format)

    yield reporter

    reporter.generate_report(output)


def _apply_format_parameters(
    formatter: Callable[[Any], Report], output_format: OutputFormat, **kwargs: Any
) -> Callable[[Any], Report]:
    """Wrap report method."""

    def wrapper(data: Any) -> Report:
        report = formatter(data)
        method_name = f"to_{output_format}"
        method = getattr(report, method_name)
        setattr(report, method_name, partial(method, **kwargs))

        return report

    return wrapper


# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for executors.

This module contains implementation of the workflow
executors.
"""
import itertools
from abc import ABC
from abc import abstractmethod
from functools import wraps
from typing import Any
from typing import Callable
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple

from mlia.core.advice_generation import Advice
from mlia.core.advice_generation import AdviceEvent
from mlia.core.advice_generation import AdviceProducer
from mlia.core.common import DataItem
from mlia.core.context import Context
from mlia.core.data_analysis import DataAnalyzer
from mlia.core.data_collection import DataCollector
from mlia.core.errors import FunctionalityNotSupportedError
from mlia.core.events import AdviceStageFinishedEvent
from mlia.core.events import AdviceStageStartedEvent
from mlia.core.events import AnalyzedDataEvent
from mlia.core.events import CollectedDataEvent
from mlia.core.events import DataAnalysisStageFinishedEvent
from mlia.core.events import DataAnalysisStageStartedEvent
from mlia.core.events import DataCollectionStageFinishedEvent
from mlia.core.events import DataCollectionStageStartedEvent
from mlia.core.events import DataCollectorSkippedEvent
from mlia.core.events import Event
from mlia.core.events import ExecutionFailedEvent
from mlia.core.events import ExecutionFinishedEvent
from mlia.core.events import ExecutionStartedEvent
from mlia.core.events import stage
from mlia.core.mixins import ContextMixin


class WorkflowExecutor(ABC):
    """Base workflow executor."""

    @abstractmethod
    def run(self) -> None:
        """Run the module."""


# Pairs of (started, finished) events that bracket each workflow stage;
# consumed by the on_stage decorator below.
STAGE_COLLECTION = (
    DataCollectionStageStartedEvent(),
    DataCollectionStageFinishedEvent(),
)
STAGE_ANALYSIS = (DataAnalysisStageStartedEvent(), DataAnalysisStageFinishedEvent())
STAGE_ADVICE = (AdviceStageStartedEvent(), AdviceStageFinishedEvent())


def on_stage(stage_events: Tuple[Event, Event]) -> Callable:
    """Mark start/finish of the stage with appropriate events.

    :param stage_events: (started, finished) event pair that is published
        around the wrapped method via the ``stage`` context manager
    """

    def wrapper(method: Callable) -> Callable:
        """Wrap method."""

        @wraps(method)
        def publish_events(self: Any, *args: Any, **kwargs: Any) -> Any:
            """Publish events before and after execution."""
            # NOTE(review): relies on the decorated instance exposing
            # self.context.event_publisher (set via the context injection
            # performed by inject_context below).
            with stage(self.context.event_publisher, stage_events):
                return method(self, *args, **kwargs)

        return publish_events

    return wrapper


class DefaultWorkflowExecutor(WorkflowExecutor):
    """Default module executor.

    This is a default implementation of the workflow executor.
    All components are launched sequentially in the same process.
    """

    def __init__(
        self,
        context: Context,
        collectors: Sequence[DataCollector],
        analyzers: Sequence[DataAnalyzer],
        producers: Sequence[AdviceProducer],
        before_start_events: Optional[Sequence[Event]] = None,
    ):
        """Init default workflow executor.

        :param context: Context instance
        :param collectors: List of the data collectors
        :param analyzers: List of the data analyzers
        :param producers: List of the advice producers
        :param before_start_events: Optional list of the custom events that
               should be published before start of the workflow execution.
        """
        self.context = context
        self.collectors = collectors
        self.analyzers = analyzers
        self.producers = producers
        self.before_start_events = before_start_events

    def run(self) -> None:
        """Run the workflow.

        Executes the three stages (collection, analysis, advice) in order,
        publishing ExecutionStarted/Finished events around them.
        """
        self.inject_context()
        self.context.register_event_handlers()

        try:
            self.publish(ExecutionStartedEvent())

            self.before_start()

            collected_data = self.collect_data()
            analyzed_data = self.analyze_data(collected_data)

            self.produce_advice(analyzed_data)
        except Exception as err:  # pylint: disable=broad-except
            # Top-level execution boundary: any failure is converted into an
            # ExecutionFailedEvent for the event handlers instead of
            # propagating to the caller.
            self.publish(ExecutionFailedEvent(err))
        else:
            self.publish(ExecutionFinishedEvent())

    def before_start(self) -> None:
        """Run actions before start of the workflow execution."""
        events = self.before_start_events or []
        for event in events:
            self.publish(event)

    @on_stage(STAGE_COLLECTION)
    def collect_data(self) -> List[DataItem]:
        """Collect data.

        Run each of data collector components and return list of
        the collected data items.
        """
        collected_data = []
        for collector in self.collectors:
            try:
                # Collectors may return None to indicate "nothing collected";
                # such results are dropped without an event.
                if (data_item := collector.collect_data()) is not None:
                    collected_data.append(data_item)
                    self.publish(CollectedDataEvent(data_item))
            except FunctionalityNotSupportedError as err:
                # An unsupported collector is not fatal: publish a skipped
                # event and continue with the remaining collectors.
                self.publish(DataCollectorSkippedEvent(collector.name(), str(err)))

        return collected_data

    @on_stage(STAGE_ANALYSIS)
    def analyze_data(self, collected_data: List[DataItem]) -> List[DataItem]:
        """Analyze data.

        Pass each collected data item into each data analyzer and
        return analyzed data.

        :param collected_data: list of collected data items
        """
        analyzed_data = []
        for analyzer in self.analyzers:
            # Feed every collected item to the analyzer first, then drain
            # the results it accumulated.
            for item in collected_data:
                analyzer.analyze_data(item)

            for data_item in analyzer.get_analyzed_data():
                analyzed_data.append(data_item)

                self.publish(AnalyzedDataEvent(data_item))
        return analyzed_data

    @on_stage(STAGE_ADVICE)
    def produce_advice(self, analyzed_data: List[DataItem]) -> None:
        """Produce advice.

        Pass each analyzed data item into each advice producer and
        publish generated advice.

        :param analyzed_data: list of analyzed data items
        """
        for producer in self.producers:
            for data_item in analyzed_data:
                producer.produce_advice(data_item)

            advice = producer.get_advice()
            if isinstance(advice, Advice):
                # producers may return a single Advice or a list; normalize
                advice = [advice]

            for item in advice:
                self.publish(AdviceEvent(item))

    def inject_context(self) -> None:
        """Inject context object into components.

        Inject context object into components that supports context
        injection.
        """
        # Only components opting in via ContextMixin receive the context.
        context_aware_components = (
            comp
            for comp in itertools.chain(
                self.collectors,
                self.analyzers,
                self.producers,
            )
            if isinstance(comp, ContextMixin)
        )

        for component in context_aware_components:
            component.set_context(self.context)

    def publish(self, event: Event) -> None:
        """Publish event.

        Helper method for event publishing.

        :param event: event instance
        """
        self.context.event_publisher.publish_event(event)