aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/core
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/core')
-rw-r--r--src/mlia/core/__init__.py21
-rw-r--r--src/mlia/core/_typing.py12
-rw-r--r--src/mlia/core/advice_generation.py106
-rw-r--r--src/mlia/core/advisor.py21
-rw-r--r--src/mlia/core/common.py47
-rw-r--r--src/mlia/core/context.py218
-rw-r--r--src/mlia/core/data_analysis.py70
-rw-r--r--src/mlia/core/data_collection.py37
-rw-r--r--src/mlia/core/errors.py18
-rw-r--r--src/mlia/core/events.py455
-rw-r--r--src/mlia/core/helpers.py38
-rw-r--r--src/mlia/core/mixins.py54
-rw-r--r--src/mlia/core/performance.py47
-rw-r--r--src/mlia/core/reporting.py762
-rw-r--r--src/mlia/core/workflow.py216
15 files changed, 2122 insertions, 0 deletions
diff --git a/src/mlia/core/__init__.py b/src/mlia/core/__init__.py
new file mode 100644
index 0000000..49b1830
--- /dev/null
+++ b/src/mlia/core/__init__.py
@@ -0,0 +1,21 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Core module.
+
+Core module contains the main components that are used in the workflow of
+ML Inference Advisor:
+ - data collectors
+ - data analyzers
+ - advice producers
+ - event publishers
+ - event handlers
+
+The workflow of ML Inference Advisor consists of 3 stages:
+ - data collection
+ - data analysis
+ - advice generation
+
+Data is being passed from one stage to another via workflow executor.
+Results (collected data, analyzed data, advice, etc) are being published via
+publish/subscribe mechanism.
+"""
diff --git a/src/mlia/core/_typing.py b/src/mlia/core/_typing.py
new file mode 100644
index 0000000..bda995c
--- /dev/null
+++ b/src/mlia/core/_typing.py
@@ -0,0 +1,12 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for custom type hints."""
+from pathlib import Path
+from typing import Literal
+from typing import TextIO
+from typing import Union
+
+
+FileLike = TextIO
+PathOrFileLike = Union[str, Path, FileLike]
+OutputFormat = Literal["plain_text", "csv", "json"]
diff --git a/src/mlia/core/advice_generation.py b/src/mlia/core/advice_generation.py
new file mode 100644
index 0000000..76cc1f2
--- /dev/null
+++ b/src/mlia/core/advice_generation.py
@@ -0,0 +1,106 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for advice generation."""
+from abc import ABC
+from abc import abstractmethod
+from dataclasses import dataclass
+from functools import wraps
+from typing import Any
+from typing import Callable
+from typing import List
+from typing import Union
+
+from mlia.core.common import AdviceCategory
+from mlia.core.common import DataItem
+from mlia.core.events import SystemEvent
+from mlia.core.mixins import ContextMixin
+
+
+@dataclass
+class Advice:
+ """Base class for the advice."""
+
+ messages: List[str]
+
+
+@dataclass
+class AdviceEvent(SystemEvent):
+ """Advice event.
+
+ This event is published for every produced advice.
+
+ :param advice: Advice instance
+ """
+
+ advice: Advice
+
+
+class AdviceProducer(ABC):
+ """Base class for the advice producer.
+
+ Producer has two methods for advice generation:
+ - produce_advice - used to generate advice based on provided
+ data (analyzed data item from analyze stage)
+ - get_advice - used for getting generated advice
+
+ Advice producers that have predefined advice could skip
+ implementation of produce_advice method.
+ """
+
+ @abstractmethod
+ def produce_advice(self, data_item: DataItem) -> None:
+ """Process data item and produce advice.
+
+ :param data_item: piece of data that could be used
+ for advice generation
+ """
+
+ @abstractmethod
+ def get_advice(self) -> Union[Advice, List[Advice]]:
+ """Get produced advice."""
+
+
+class ContextAwareAdviceProducer(AdviceProducer, ContextMixin):
+ """Context aware advice producer.
+
+    This class makes access to the Context object easier. The Context object
+    can be automatically injected during workflow configuration.
+ """
+
+
+class FactBasedAdviceProducer(ContextAwareAdviceProducer):
+ """Advice producer based on provided facts.
+
+    This is a utility class that maintains a list of generated Advice instances.
+ """
+
+ def __init__(self) -> None:
+ """Init advice producer."""
+ self.advice: List[Advice] = []
+
+ def get_advice(self) -> Union[Advice, List[Advice]]:
+ """Get produced advice."""
+ return self.advice
+
+ def add_advice(self, messages: List[str]) -> None:
+ """Add advice."""
+ self.advice.append(Advice(messages))
+
+
+def advice_category(*categories: AdviceCategory) -> Callable:
+ """Filter advice generation handler by advice category."""
+
+ def wrapper(handler: Callable) -> Callable:
+ """Wrap data handler."""
+
+ @wraps(handler)
+ def check_category(self: Any, *args: Any, **kwargs: Any) -> Any:
+ """Check if handler can produce advice for the requested category."""
+ if not self.context.any_category_enabled(*categories):
+ return
+
+ handler(self, *args, **kwargs)
+
+ return check_category
+
+ return wrapper
diff --git a/src/mlia/core/advisor.py b/src/mlia/core/advisor.py
new file mode 100644
index 0000000..868d0c7
--- /dev/null
+++ b/src/mlia/core/advisor.py
@@ -0,0 +1,21 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Inference advisor module."""
+from abc import abstractmethod
+
+from mlia.core.common import NamedEntity
+from mlia.core.context import Context
+from mlia.core.workflow import WorkflowExecutor
+
+
+class InferenceAdvisor(NamedEntity):
+ """Base class for inference advisors."""
+
+ @abstractmethod
+ def configure(self, context: Context) -> WorkflowExecutor:
+ """Configure advisor execution."""
+
+ def run(self, context: Context) -> None:
+ """Run inference advisor."""
+ executor = self.configure(context)
+ executor.run()
diff --git a/src/mlia/core/common.py b/src/mlia/core/common.py
new file mode 100644
index 0000000..5fbad42
--- /dev/null
+++ b/src/mlia/core/common.py
@@ -0,0 +1,47 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Common module.
+
+This module contains common interfaces/classes shared across
+the core module.
+"""
+from abc import ABC
+from abc import abstractmethod
+from enum import Enum
+from typing import Any
+
+# This type is used as type alias for the items which are being passed around
+# in advisor workflow. There are no restrictions on the type of the
+# object. This alias used only to emphasize the nature of the input/output
+# arguments.
+DataItem = Any
+
+
+class AdviceCategory(Enum):
+ """Advice category.
+
+ Enumeration of advice categories supported by ML Inference Advisor.
+ """
+
+ OPERATORS = 1
+ PERFORMANCE = 2
+ OPTIMIZATION = 3
+ ALL = 4
+
+ @classmethod
+ def from_string(cls, value: str) -> "AdviceCategory":
+ """Resolve enum value from string value."""
+ category_names = [item.name for item in AdviceCategory]
+ if not value or value.upper() not in category_names:
+ raise Exception(f"Invalid advice category {value}")
+
+ return AdviceCategory[value.upper()]
+
+
+class NamedEntity(ABC):
+ """Entity with a name and description."""
+
+ @classmethod
+ @abstractmethod
+ def name(cls) -> str:
+ """Return name of the entity."""
diff --git a/src/mlia/core/context.py b/src/mlia/core/context.py
new file mode 100644
index 0000000..8b3dd2c
--- /dev/null
+++ b/src/mlia/core/context.py
@@ -0,0 +1,218 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Context module.
+
+This module contains functionality related to the Context.
+Context is an object that describes advisor working environment
+and requested behavior (advice categories, input configuration
+parameters).
+"""
+import logging
+from abc import ABC
+from abc import abstractmethod
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Mapping
+from typing import Optional
+from typing import Union
+
+from mlia.core.common import AdviceCategory
+from mlia.core.events import DefaultEventPublisher
+from mlia.core.events import EventHandler
+from mlia.core.events import EventPublisher
+from mlia.core.helpers import ActionResolver
+from mlia.core.helpers import APIActionResolver
+
+logger = logging.getLogger(__name__)
+
+
+class Context(ABC):
+ """Abstract class for the execution context."""
+
+ @abstractmethod
+ def get_model_path(self, model_filename: str) -> Path:
+ """Return path for the intermediate/optimized models.
+
+ During workflow execution different parts of the advisor
+ require creating intermediate files for models.
+
+ This method allows to provide paths where those models
+ could be saved.
+
+ :param model_filename: filename of the model
+ """
+
+ @property
+ @abstractmethod
+ def event_publisher(self) -> EventPublisher:
+ """Return event publisher."""
+
+ @property
+ @abstractmethod
+ def event_handlers(self) -> Optional[List[EventHandler]]:
+ """Return list of the event_handlers."""
+
+ @property
+ @abstractmethod
+ def advice_category(self) -> Optional[AdviceCategory]:
+ """Return advice category."""
+
+ @property
+ @abstractmethod
+ def config_parameters(self) -> Optional[Mapping[str, Any]]:
+ """Return configuration parameters."""
+
+ @property
+ @abstractmethod
+ def action_resolver(self) -> ActionResolver:
+ """Return action resolver."""
+
+ @abstractmethod
+ def update(
+ self,
+ *,
+ advice_category: AdviceCategory,
+ event_handlers: List[EventHandler],
+ config_parameters: Mapping[str, Any],
+ ) -> None:
+ """Update context parameters."""
+
+ def category_enabled(self, category: AdviceCategory) -> bool:
+ """Check if category enabled."""
+ return category == self.advice_category
+
+ def any_category_enabled(self, *categories: AdviceCategory) -> bool:
+ """Return true if any category is enabled."""
+ return self.advice_category in categories
+
+ def register_event_handlers(self) -> None:
+ """Register event handlers."""
+ self.event_publisher.register_event_handlers(self.event_handlers)
+
+
+class ExecutionContext(Context):
+ """Execution context."""
+
+ def __init__(
+ self,
+ *,
+ advice_category: Optional[AdviceCategory] = None,
+ config_parameters: Optional[Mapping[str, Any]] = None,
+ working_dir: Optional[Union[str, Path]] = None,
+ event_handlers: Optional[List[EventHandler]] = None,
+ event_publisher: Optional[EventPublisher] = None,
+ verbose: bool = False,
+ logs_dir: str = "logs",
+ models_dir: str = "models",
+ action_resolver: Optional[ActionResolver] = None,
+ ) -> None:
+ """Init execution context.
+
+ :param advice_category: requested advice category
+ :param config_parameters: dictionary like object with input parameters
+ :param working_dir: path to the directory that will be used as a place
+ to store temporary files, logs, models. If not provided then
+ current working directory will be used instead
+ :param event_handlers: optional list of event handlers
+ :param event_publisher: optional event publisher instance. If not provided
+ then default implementation of event publisher will be used
+ :param verbose: enable verbose output
+ :param logs_dir: name of the directory inside working directory where
+ log files will be stored
+ :param models_dir: name of the directory inside working directory where
+ temporary models will be stored
+ :param action_resolver: instance of the action resolver that could make
+ advice actionable
+ """
+ self._advice_category = advice_category
+ self._config_parameters = config_parameters
+
+ self._working_dir_path = Path.cwd()
+ if working_dir:
+ self._working_dir_path = Path(working_dir)
+ self._working_dir_path.mkdir(exist_ok=True)
+
+ self._event_handlers = event_handlers
+ self._event_publisher = event_publisher or DefaultEventPublisher()
+ self.verbose = verbose
+ self.logs_dir = logs_dir
+ self.models_dir = models_dir
+ self._action_resolver = action_resolver or APIActionResolver()
+
+ @property
+ def advice_category(self) -> Optional[AdviceCategory]:
+ """Return advice category."""
+ return self._advice_category
+
+ @advice_category.setter
+ def advice_category(self, advice_category: AdviceCategory) -> None:
+ """Setter for the advice category."""
+ self._advice_category = advice_category
+
+ @property
+ def config_parameters(self) -> Optional[Mapping[str, Any]]:
+ """Return configuration parameters."""
+ return self._config_parameters
+
+ @config_parameters.setter
+ def config_parameters(self, config_parameters: Optional[Mapping[str, Any]]) -> None:
+ """Setter for the configuration parameters."""
+ self._config_parameters = config_parameters
+
+ @property
+ def event_handlers(self) -> Optional[List[EventHandler]]:
+ """Return list of the event handlers."""
+ return self._event_handlers
+
+ @event_handlers.setter
+ def event_handlers(self, event_handlers: List[EventHandler]) -> None:
+ """Setter for the event handlers."""
+ self._event_handlers = event_handlers
+
+ @property
+ def event_publisher(self) -> EventPublisher:
+ """Return event publisher."""
+ return self._event_publisher
+
+ @property
+ def action_resolver(self) -> ActionResolver:
+ """Return action resolver."""
+ return self._action_resolver
+
+ def get_model_path(self, model_filename: str) -> Path:
+ """Return path for the model."""
+ models_dir_path = self._working_dir_path / self.models_dir
+ models_dir_path.mkdir(exist_ok=True)
+
+ return models_dir_path / model_filename
+
+ @property
+ def logs_path(self) -> Path:
+ """Return path to the logs directory."""
+ return self._working_dir_path / self.logs_dir
+
+ def update(
+ self,
+ *,
+ advice_category: AdviceCategory,
+ event_handlers: List[EventHandler],
+ config_parameters: Mapping[str, Any],
+ ) -> None:
+ """Update context parameters."""
+ self._advice_category = advice_category
+ self._event_handlers = event_handlers
+ self._config_parameters = config_parameters
+
+ def __str__(self) -> str:
+ """Return string representation."""
+ category = (
+ "<not set>" if self.advice_category is None else self.advice_category.name
+ )
+
+ return (
+ f"ExecutionContext: working_dir={self._working_dir_path}, "
+ f"advice_category={category}, "
+ f"config_parameters={self.config_parameters}, "
+ f"verbose={self.verbose}"
+ )
diff --git a/src/mlia/core/data_analysis.py b/src/mlia/core/data_analysis.py
new file mode 100644
index 0000000..6adb41e
--- /dev/null
+++ b/src/mlia/core/data_analysis.py
@@ -0,0 +1,70 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for data analysis."""
+from abc import ABC
+from abc import abstractmethod
+from dataclasses import dataclass
+from typing import List
+
+from mlia.core.common import DataItem
+from mlia.core.mixins import ContextMixin
+
+
+class DataAnalyzer(ABC):
+ """Base class for the data analysis.
+
+ Purpose of this class is to extract valuable data out of
+ collected data which could be used for advice generation.
+
+ This process consists of two steps:
+ - analyze every item of the collected data
+ - get analyzed data
+ """
+
+ @abstractmethod
+ def analyze_data(self, data_item: DataItem) -> None:
+ """Analyze data.
+
+ :param data_item: item of the collected data
+ """
+
+ @abstractmethod
+ def get_analyzed_data(self) -> List[DataItem]:
+ """Get analyzed data."""
+
+
+class ContextAwareDataAnalyzer(DataAnalyzer, ContextMixin):
+ """Context aware data analyzer.
+
+    This class makes access to the Context object easier. The Context object
+    can be automatically injected during workflow configuration.
+ """
+
+
+@dataclass
+class Fact:
+ """Base class for the facts.
+
+ Fact represents some piece of knowledge about collected
+ data.
+ """
+
+
+class FactExtractor(ContextAwareDataAnalyzer):
+ """Data analyzer based on extracting facts.
+
+ Utility class that makes fact extraction easier.
+ Class maintains list of the extracted facts.
+ """
+
+ def __init__(self) -> None:
+ """Init fact extractor."""
+ self.facts: List[Fact] = []
+
+ def get_analyzed_data(self) -> List[DataItem]:
+ """Return list of the collected facts."""
+ return self.facts
+
+ def add_fact(self, fact: Fact) -> None:
+ """Add fact."""
+ self.facts.append(fact)
diff --git a/src/mlia/core/data_collection.py b/src/mlia/core/data_collection.py
new file mode 100644
index 0000000..43b6d1c
--- /dev/null
+++ b/src/mlia/core/data_collection.py
@@ -0,0 +1,37 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for data collection.
+
+This module contains base classes for the first stage
+of the ML Inference Advisor workflow - data collection.
+"""
+from abc import abstractmethod
+
+from mlia.core.common import DataItem
+from mlia.core.common import NamedEntity
+from mlia.core.mixins import ContextMixin
+from mlia.core.mixins import ParameterResolverMixin
+
+
+class DataCollector(NamedEntity):
+ """Base class for the data collection.
+
+ Data collection is the first step in the process of the advice
+ generation.
+
+ Different implementations of this class can provide various
+ information about model or device. This information is being used
+ at later stages.
+ """
+
+ @abstractmethod
+ def collect_data(self) -> DataItem:
+ """Collect data."""
+
+
+class ContextAwareDataCollector(DataCollector, ContextMixin, ParameterResolverMixin):
+ """Context aware data collector.
+
+    This class makes access to the Context object easier. The Context object
+    can be automatically injected during workflow configuration.
+ """
diff --git a/src/mlia/core/errors.py b/src/mlia/core/errors.py
new file mode 100644
index 0000000..7d6beb1
--- /dev/null
+++ b/src/mlia/core/errors.py
@@ -0,0 +1,18 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""MLIA exceptions module."""
+
+
+class ConfigurationError(Exception):
+ """Configuration error."""
+
+
+class FunctionalityNotSupportedError(Exception):
+ """Functionality is not supported error."""
+
+ def __init__(self, reason: str, description: str) -> None:
+ """Init exception."""
+ super().__init__(f"{reason}: {description}")
+
+ self.reason = reason
+ self.description = description
diff --git a/src/mlia/core/events.py b/src/mlia/core/events.py
new file mode 100644
index 0000000..10aec86
--- /dev/null
+++ b/src/mlia/core/events.py
@@ -0,0 +1,455 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for the events and related functionality.
+
+This module represents one of the main component of the workflow -
+events publishing and provides a way for delivering results to the
+calling application.
+
+Each component of the workflow can generate events of specific type.
+Application can subscribe and react to those events.
+"""
+import traceback
+import uuid
+from abc import ABC
+from abc import abstractmethod
+from contextlib import contextmanager
+from dataclasses import asdict
+from dataclasses import dataclass
+from dataclasses import field
+from functools import singledispatchmethod
+from typing import Any
+from typing import Dict
+from typing import Generator
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+from mlia.core.common import DataItem
+
+
+@dataclass
+class Event:
+ """Base class for the events.
+
+ This class is used as a root node of the events class hierarchy.
+ """
+
+ event_id: str = field(init=False)
+
+ def __post_init__(self) -> None:
+ """Generate unique ID for the event."""
+ self.event_id = str(uuid.uuid4())
+
+ def compare_without_id(self, other: "Event") -> bool:
+ """Compare two events without event_id field."""
+ if not isinstance(other, Event) or self.__class__ != other.__class__:
+ return False
+
+ self_as_dict = asdict(self)
+ self_as_dict.pop("event_id")
+
+ other_as_dict = asdict(other)
+ other_as_dict.pop("event_id")
+
+ return self_as_dict == other_as_dict
+
+
+@dataclass
+class ChildEvent(Event):
+ """Child event.
+
+ This class could be used to link event with the parent event.
+ """
+
+ parent_event_id: str
+
+
+@dataclass
+class ActionStartedEvent(Event):
+ """Action started event.
+
+ This event is published when some action has been started.
+ """
+
+ action_type: str
+ params: Optional[Dict] = None
+
+
+@dataclass
+class SubActionEvent(ChildEvent):
+ """SubAction event.
+
+ This event could be used to represent some action during parent action.
+ """
+
+ action_type: str
+ params: Optional[Dict] = None
+
+
+@dataclass
+class ActionFinishedEvent(ChildEvent):
+ """Action finished event.
+
+ This event is published when some action has been finished.
+ """
+
+
+@dataclass
+class SystemEvent(Event):
+ """System event.
+
+    System event class represents events that are published by components
+    of the core module. The most common example is a workflow executor
+    that publishes a number of system events for the starting/completion
+    of different stages/workflow.
+
+    Events that are published by components outside of the core module
+    should not use this class as a base class.
+ """
+
+
+@dataclass
+class ExecutionStartedEvent(SystemEvent):
+ """Execution started event.
+
+ This event is published when workflow execution started.
+ """
+
+
+@dataclass
+class ExecutionFinishedEvent(SystemEvent):
+ """Execution finished event.
+
+ This event is published when workflow execution finished.
+ """
+
+
+@dataclass
+class ExecutionFailedEvent(SystemEvent):
+ """Execution failed event."""
+
+ err: Exception
+
+
+@dataclass
+class DataCollectionStageStartedEvent(SystemEvent):
+ """Data collection stage started.
+
+ This event is published when data collection stage started.
+ """
+
+
+@dataclass
+class DataCollectorSkippedEvent(SystemEvent):
+ """Data collector skipped event.
+
+ This event is published when particular data collector can
+ not provide data for the provided parameters.
+ """
+
+ data_collector: str
+ reason: str
+
+
+@dataclass
+class DataCollectionStageFinishedEvent(SystemEvent):
+ """Data collection stage finished.
+
+ This event is published when data collection stage finished.
+ """
+
+
+@dataclass
+class DataAnalysisStageStartedEvent(SystemEvent):
+ """Data analysis stage started.
+
+ This event is published when data analysis stage started.
+ """
+
+
+@dataclass
+class DataAnalysisStageFinishedEvent(SystemEvent):
+ """Data analysis stage finished.
+
+ This event is published when data analysis stage finished.
+ """
+
+
+@dataclass
+class AdviceStageStartedEvent(SystemEvent):
+    """Advice producing stage started.
+
+ This event is published when advice generation stage started.
+ """
+
+
+@dataclass
+class AdviceStageFinishedEvent(SystemEvent):
+ """Advace producing stage finished.
+
+ This event is published when advice generation stage finished.
+ """
+
+
+@dataclass
+class CollectedDataEvent(SystemEvent):
+ """Collected data event.
+
+ This event is published for every collected data item.
+
+ :param data_item: collected data item
+ """
+
+ data_item: DataItem
+
+
+@dataclass
+class AnalyzedDataEvent(SystemEvent):
+ """Analyzed data event.
+
+ This event is published for every analyzed data item.
+
+ :param data_item: analyzed data item
+ """
+
+ data_item: DataItem
+
+
+class EventHandler:
+ """Base class for the event handlers.
+
+ Each event handler should derive from this base class.
+ """
+
+ def handle_event(self, event: Event) -> None:
+ """Handle the event.
+
+        By default, all published events are passed to each
+        registered event handler. It is the handler's responsibility
+        to filter the events that it is interested in.
+ """
+
+
+class DebugEventHandler(EventHandler):
+ """Event handler for debugging purposes.
+
+ This handler could print every published event to the
+ standard output.
+ """
+
+ def __init__(self, with_stacktrace: bool = False) -> None:
+ """Init event handler.
+
+ :param with_stacktrace: enable printing stacktrace of the
+ place where event publishing occurred.
+ """
+ self.with_stacktrace = with_stacktrace
+
+ def handle_event(self, event: Event) -> None:
+ """Handle event."""
+ print(f"Got event {event}")
+
+ if self.with_stacktrace:
+ traceback.print_stack()
+
+
+class EventDispatcherMetaclass(type):
+ """Metaclass for event dispatching.
+
+ It could be tedious to check type of the published event
+ inside event handler. Instead the following convention could be
+ established: if method name of the class starts with some
+ prefix then it is considered to be event handler of particular
+ type.
+
+ This metaclass goes through the list of class methods and
+ links all methods with the prefix "on_" to the common dispatcher
+ method.
+ """
+
+ def __new__(
+ cls,
+ clsname: str,
+ bases: Tuple,
+ namespace: Dict[str, Any],
+ event_handler_method_prefix: str = "on_",
+ ) -> Any:
+ """Create event dispatcher and link event handlers."""
+ new_class = super().__new__(cls, clsname, bases, namespace)
+
+ @singledispatchmethod
+ def dispatcher(_self: Any, _event: Event) -> Any:
+ """Event dispatcher."""
+
+ # get all class methods which starts with particular prefix
+ event_handler_methods = (
+ (item_name, item)
+ for item_name in dir(new_class)
+ if callable((item := getattr(new_class, item_name)))
+ and item_name.startswith(event_handler_method_prefix)
+ )
+
+ # link all collected event handlers to one dispatcher method
+ for method_name, method in event_handler_methods:
+ event_handler = dispatcher.register(method)
+ setattr(new_class, method_name, event_handler)
+
+ # override default handle_event method, replace it with the
+ # dispatcher
+ setattr(new_class, "handle_event", dispatcher)
+
+ return new_class
+
+
+class EventDispatcher(EventHandler, metaclass=EventDispatcherMetaclass):
+ """Event dispatcher."""
+
+
+class EventPublisher(ABC):
+ """Base class for the event publisher.
+
+    Event publisher is an intermediate component between event emitters
+    and event consumers.
+ """
+
+ @abstractmethod
+ def register_event_handler(self, event_handler: EventHandler) -> None:
+ """Register event handler.
+
+ :param event_handler: instance of the event handler
+ """
+
+ def register_event_handlers(
+ self, event_handlers: Optional[List[EventHandler]]
+ ) -> None:
+ """Register event handlers.
+
+ Can be used for batch registration of the event handlers:
+
+ :param event_handlers: list of the event handler instances
+ """
+ if not event_handlers:
+ return
+
+ for handler in event_handlers:
+ self.register_event_handler(handler)
+
+ @abstractmethod
+ def publish_event(self, event: Event) -> None:
+ """Publish the event.
+
+ Deliver the event to the all registered event handlers.
+
+ :param event: event instance
+ """
+
+
+class DefaultEventPublisher(EventPublisher):
+ """Default event publishing implementation.
+
+ Simple implementation that maintains list of the registered event
+ handlers.
+ """
+
+ def __init__(self) -> None:
+ """Init the event publisher."""
+ self.handlers: List[EventHandler] = []
+
+ def register_event_handler(self, event_handler: EventHandler) -> None:
+ """Register the event handler.
+
+ :param event_handler: instance of the event handler
+ """
+ self.handlers.append(event_handler)
+
+ def publish_event(self, event: Event) -> None:
+ """Publish the event.
+
+ Publisher does not catch exceptions that could be raised by event handlers.
+ """
+ for handler in self.handlers:
+ handler.handle_event(event)
+
+
+@contextmanager
+def stage(
+ publisher: EventPublisher, events: Tuple[Event, Event]
+) -> Generator[None, None, None]:
+ """Generate events before and after stage.
+
+ This context manager could be used to mark start/finish
+ execution of a particular logical part of the workflow.
+ """
+ started, finished = events
+
+ publisher.publish_event(started)
+ yield
+ publisher.publish_event(finished)
+
+
+@contextmanager
+def action(
+ publisher: EventPublisher, action_type: str, params: Optional[Dict] = None
+) -> Generator[None, None, None]:
+ """Generate events before and after action."""
+ action_started = ActionStartedEvent(action_type, params)
+ action_finished = ActionFinishedEvent(action_started.event_id)
+
+ publisher.publish_event(action_started)
+ yield
+ publisher.publish_event(action_finished)
+
+
+class SystemEventsHandler(EventDispatcher):
+ """System events handler."""
+
+ def on_execution_started(self, event: ExecutionStartedEvent) -> None:
+ """Handle ExecutionStarted event."""
+
+ def on_execution_finished(self, event: ExecutionFinishedEvent) -> None:
+ """Handle ExecutionFinished event."""
+
+ def on_execution_failed(self, event: ExecutionFailedEvent) -> None:
+ """Handle ExecutionFailed event."""
+
+ def on_data_collection_stage_started(
+ self, event: DataCollectionStageStartedEvent
+ ) -> None:
+ """Handle DataCollectionStageStarted event."""
+
+ def on_data_collection_stage_finished(
+ self, event: DataCollectionStageFinishedEvent
+ ) -> None:
+ """Handle DataCollectionStageFinished event."""
+
+ def on_data_collector_skipped(self, event: DataCollectorSkippedEvent) -> None:
+ """Handle DataCollectorSkipped event."""
+
+ def on_data_analysis_stage_started(
+ self, event: DataAnalysisStageStartedEvent
+ ) -> None:
+ """Handle DataAnalysisStageStartedEvent event."""
+
+ def on_data_analysis_stage_finished(
+ self, event: DataAnalysisStageFinishedEvent
+ ) -> None:
+ """Handle DataAnalysisStageFinishedEvent event."""
+
+ def on_advice_stage_started(self, event: AdviceStageStartedEvent) -> None:
+ """Handle AdviceStageStarted event."""
+
+ def on_advice_stage_finished(self, event: AdviceStageFinishedEvent) -> None:
+ """Handle AdviceStageFinished event."""
+
+ def on_collected_data(self, event: CollectedDataEvent) -> None:
+ """Handle CollectedData event."""
+
+ def on_analyzed_data(self, event: AnalyzedDataEvent) -> None:
+ """Handle AnalyzedData event."""
+
+ def on_action_started(self, event: ActionStartedEvent) -> None:
+ """Handle ActionStarted event."""
+
+ def on_action_finished(self, event: ActionFinishedEvent) -> None:
+ """Handle ActionFinished event."""
diff --git a/src/mlia/core/helpers.py b/src/mlia/core/helpers.py
new file mode 100644
index 0000000..d10ea5d
--- /dev/null
+++ b/src/mlia/core/helpers.py
@@ -0,0 +1,38 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for various helper classes."""
+# pylint: disable=no-self-use, unused-argument
+from typing import Any
+from typing import List
+
+
+class ActionResolver:
+ """Helper class for generating actions (e.g. commands with parameters)."""
+
+ def apply_optimizations(self, **kwargs: Any) -> List[str]:
+ """Return action details for applying optimizations."""
+ return []
+
+ def supported_operators_info(self) -> List[str]:
+ """Return action details for generating supported ops report."""
+ return []
+
+ def check_performance(self) -> List[str]:
+ """Return action details for checking performance."""
+ return []
+
+ def check_operator_compatibility(self) -> List[str]:
+ """Return action details for checking op compatibility."""
+ return []
+
+ def operator_compatibility_details(self) -> List[str]:
+ """Return action details for getting more information about op compatibility."""
+ return []
+
+ def optimization_details(self) -> List[str]:
+ """Return action detail for getting information about optimizations."""
+ return []
+
+
+class APIActionResolver(ActionResolver):
+ """Helper class for the actions performed through API."""
diff --git a/src/mlia/core/mixins.py b/src/mlia/core/mixins.py
new file mode 100644
index 0000000..ee03100
--- /dev/null
+++ b/src/mlia/core/mixins.py
@@ -0,0 +1,54 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Mixins module."""
+from typing import Any
+from typing import Optional
+
+from mlia.core.context import Context
+
+
class ContextMixin:
    """Mixin that adds context-injection support to a component."""

    # Assigned via set_context before the component is used.
    context: Context

    def set_context(self, context: Context) -> None:
        """Inject the given context object into this component."""
        self.context = context
+
+
class ParameterResolverMixin:
    """Mixin for parameter resolving.

    Provides typed access to the values stored in the context's
    ``config_parameters`` dictionary.
    """

    # Injected by ContextMixin.set_context or assigned by the subclass.
    context: "Context"

    def get_parameter(
        self,
        section: str,
        name: str,
        expected: bool = True,
        expected_type: Optional[type] = None,
        context: Optional["Context"] = None,
    ) -> Any:
        """Get parameter value.

        :param section: name of the configuration section to look in
        :param name: name of the parameter inside the section
        :param expected: if True, raise when the parameter is not set
        :param expected_type: if provided, raise when a present value is
            not an instance of this type
        :param context: optional context to read from, defaults to the
            context injected into this component
        :raises ValueError: if configuration parameters are missing or
            malformed, if an expected parameter is not set, or if the
            value has the wrong type
        """
        ctx = context or self.context

        if ctx.config_parameters is None:
            raise ValueError("Configuration parameters are not set")

        section_params = ctx.config_parameters.get(section)
        # isinstance also rejects None, covering the missing-section case
        if not isinstance(section_params, dict):
            raise ValueError(
                f"Parameter section {section} has wrong format, "
                "expected to be a dictionary"
            )

        value = section_params.get(name)

        # Compare against None explicitly: falsy values such as 0, False
        # or "" are legitimate parameter values and must not be reported
        # as unset (the previous truthiness check rejected them).
        if value is None and expected:
            raise ValueError(f"Parameter {name} is not set")

        # Type-check every present value, including falsy ones.
        if (
            value is not None
            and expected_type is not None
            and not isinstance(value, expected_type)
        ):
            raise ValueError(f"Parameter {name} expected to have type {expected_type}")

        return value
diff --git a/src/mlia/core/performance.py b/src/mlia/core/performance.py
new file mode 100644
index 0000000..5433d5c
--- /dev/null
+++ b/src/mlia/core/performance.py
@@ -0,0 +1,47 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for performance estimation."""
+from abc import abstractmethod
+from typing import Callable
+from typing import Generic
+from typing import List
+from typing import TypeVar
+
+
ModelType = TypeVar("ModelType")  # pylint: disable=invalid-name
PerfMetricsType = TypeVar("PerfMetricsType")  # pylint: disable=invalid-name


class PerformanceEstimator(Generic[ModelType, PerfMetricsType]):
    """Base class for the performance estimation."""

    @abstractmethod
    def estimate(self, model: ModelType) -> PerfMetricsType:
        """Estimate performance.

        Raise explicitly because @abstractmethod alone is not enforced
        here: the class does not use ABCMeta, so the base class can be
        instantiated and this method could otherwise silently return None.
        """
        raise NotImplementedError
+
+
def estimate_performance(
    original_model: ModelType,
    estimator: PerformanceEstimator[ModelType, PerfMetricsType],
    model_transformations: List[Callable[[ModelType], ModelType]],
) -> List[PerfMetricsType]:
    """Estimate performance impact.

    This function estimates performance impact on model performance after
    applying provided transformations/optimizations.

    :param original_model: object that represents a model, could be
        instance of the model or path to the model. This depends on
        provided performance estimator.
    :param estimator: performance estimator
    :param model_transformations: list of the callables each of those
        returns object that represents optimized model
    :return: metrics for the original model followed by metrics for each
        transformed model, in transformation order
    """
    # First entry is always the unmodified model's metrics.
    metrics = [estimator.estimate(original_model)]

    for transform in model_transformations:
        optimized_model = transform(original_model)
        metrics.append(estimator.estimate(optimized_model))

    return metrics
diff --git a/src/mlia/core/reporting.py b/src/mlia/core/reporting.py
new file mode 100644
index 0000000..1b75bb4
--- /dev/null
+++ b/src/mlia/core/reporting.py
@@ -0,0 +1,762 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Reporting module."""
+import csv
+import json
+import logging
+from abc import ABC
+from abc import abstractmethod
+from collections import defaultdict
+from contextlib import contextmanager
+from contextlib import ExitStack
+from dataclasses import dataclass
+from functools import partial
+from io import TextIOWrapper
+from pathlib import Path
+from textwrap import fill
+from textwrap import indent
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Collection
+from typing import Dict
+from typing import Generator
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+
+from mlia.core._typing import FileLike
+from mlia.core._typing import OutputFormat
+from mlia.core._typing import PathOrFileLike
+from mlia.utils.console import apply_style
+from mlia.utils.console import produce_table
+from mlia.utils.logging import LoggerWriter
+from mlia.utils.types import is_list_of
+
+logger = logging.getLogger(__name__)
+
+
class Report(ABC):
    """Interface that every report implementation must provide."""

    @abstractmethod
    def to_json(self, **kwargs: Any) -> Any:
        """Return a JSON-serializable representation of the report."""

    @abstractmethod
    def to_csv(self, **kwargs: Any) -> List[Any]:
        """Return a CSV-serializable representation of the report."""

    @abstractmethod
    def to_plain_text(self, **kwargs: Any) -> str:
        """Return a human readable representation of the report."""
+
+
class ReportItem:
    """Single, possibly nested, entry of a report."""

    def __init__(
        self,
        name: str,
        alias: Optional[str] = None,
        value: Optional[Union[str, int, "Cell"]] = None,
        nested_items: Optional[List["ReportItem"]] = None,
    ) -> None:
        """Init the report item.

        :param name: human readable item name
        :param alias: machine friendly item identifier
        :param value: item value, optionally wrapped into a Cell
        :param nested_items: child report items
        """
        self.name = name
        self.alias = alias
        self.value = value
        if not nested_items:
            nested_items = []
        self.nested_items = nested_items

    @property
    def compound(self) -> bool:
        """Return true if item has nested items."""
        return bool(self.nested_items)

    @property
    def raw_value(self) -> Any:
        """Get actual item value, unwrapping Cell instances."""
        if isinstance(self.value, Cell):
            return self.value.value
        return self.value
+
+
@dataclass
class Format:
    """Formatting options for a column or an individual cell.

    Format could be applied either to a column or an individual cell.

    :param wrap_width: maximum width used when wrapping the text value
    :param str_fmt: format spec string, or a callable applied to the value
    :param style: name of the text style to apply
    """

    wrap_width: Optional[int] = None
    str_fmt: Optional[Union[str, Callable[[Any], str]]] = None
    style: Optional[str] = None
+
+
@dataclass
class Cell:
    """Wrapper for a single table value.

    Allows attaching a specific Format to one value in a table.
    """

    value: Any
    fmt: Optional[Format] = None

    def _apply_style(self, value: str) -> str:
        """Return the value with the configured style applied, if any."""
        if not (self.fmt and self.fmt.style):
            return value
        return apply_style(value, self.fmt.style)

    def _get_value(self) -> str:
        """Return the formatted string representation of the value."""
        spec = self.fmt.str_fmt if self.fmt else None
        if isinstance(spec, str):
            return format(self.value, spec)
        if callable(spec):
            return spec(self.value)
        return str(self.value)

    def __str__(self) -> str:
        """Return string representation."""
        return self._apply_style(self._get_value())

    def to_csv(self) -> Any:
        """Cell representation for csv output."""
        return self.value

    def to_json(self) -> Any:
        """Cell representation for json output."""
        return self.value
+
+
class CountAwareCell(Cell):
    """Cell that renders its numeric value together with a count unit."""

    def __init__(
        self,
        value: Optional[Union[int, float]],
        singular: str,
        plural: str,
        format_string: str = ",d",
    ):
        """Init cell instance.

        :param value: counted quantity, may be None
        :param singular: unit name used when the value equals 1
        :param plural: unit name used for all other values
        :param format_string: format spec for the numeric part
        """
        self.unit = plural if value != 1 else singular

        def format_value(val: Optional[Union[int, float]]) -> str:
            """Render the value with its unit, empty string for None."""
            if val is None:
                return ""
            if val == 1:
                return f"1 {singular}"
            return f"{val:{format_string}} {plural}"

        super().__init__(value, Format(str_fmt=format_value))

    def to_csv(self) -> Any:
        """Cell representation for csv output."""
        return {"value": self.value, "unit": self.unit}

    def to_json(self) -> Any:
        """Cell representation for json output."""
        return {"value": self.value, "unit": self.unit}
+
+
class BytesCell(CountAwareCell):
    """Cell holding a memory size in bytes."""

    def __init__(self, value: Optional[int]) -> None:
        """Init cell with byte/bytes units."""
        super().__init__(value, singular="byte", plural="bytes")
+
+
class CyclesCell(CountAwareCell):
    """Cell holding a number of processor cycles."""

    def __init__(self, value: Optional[Union[int, float]]) -> None:
        """Init cell with cycle/cycles units."""
        super().__init__(value, "cycle", "cycles", format_string=",.0f")
+
+
class ClockCell(CountAwareCell):
    """Cell holding a clock frequency value."""

    def __init__(self, value: Optional[Union[int, float]]) -> None:
        """Init cell with the Hz unit."""
        super().__init__(value, "Hz", "Hz", format_string=",.0f")
+
+
class Column:
    """Definition of a single table column."""

    def __init__(
        self,
        header: str,
        alias: Optional[str] = None,
        fmt: Optional[Format] = None,
        only_for: Optional[List[str]] = None,
    ) -> None:
        """Init column definition.

        :param header: column's header
        :param alias: column's alias, could be used as column's name
        :param fmt: format that will be applied for all column's values
        :param only_for: list of the formats where this column should be
            represented; None means the column appears in every format
        """
        self.header = header
        self.alias = alias
        self.fmt = fmt
        self.only_for = only_for

    def supports_format(self, fmt: str) -> bool:
        """Return true if the column should be shown for this format."""
        if not self.only_for:
            return True
        return fmt in self.only_for
+
+
class NestedReport(Report):
    """Report with nested items.

    Renders a tree of ReportItem objects either flattened into two rows
    (csv), as nested dictionaries (json) or as indented text (plain text).
    """

    def __init__(self, name: str, alias: str, items: List[ReportItem]) -> None:
        """Init nested report.

        :param name: human readable report name (used as plain text heading)
        :param alias: machine friendly report name (used as JSON key)
        :param items: top-level report items
        """
        self.name = name
        self.alias = alias
        self.items = items

    def to_csv(self, **kwargs: Any) -> List[Any]:
        """Convert to csv serializable format.

        Flattens the item tree into two rows: item aliases (headers) and
        the corresponding values.
        """
        result = {}

        def collect_item_values(
            item: ReportItem,
            _parent: Optional[ReportItem],
            _prev: Optional[ReportItem],
            _level: int,
        ) -> None:
            """Collect item values into a dictionary."""
            if item.value is None:
                return

            if not isinstance(item.value, Cell):
                result[item.alias] = item.raw_value
                return

            # A Cell may expand into several columns (e.g. value + unit);
            # prefix each generated key with the item alias.
            csv_value = item.value.to_csv()
            if isinstance(csv_value, dict):
                csv_value = {
                    f"{item.alias}_{key}": value for key, value in csv_value.items()
                }
            else:
                csv_value = {item.alias: csv_value}

            result.update(csv_value)

        self._traverse(self.items, collect_item_values)

        # make list out of the result dictionary
        # first element - keys of the dictionary as headers
        # second element - list of the dictionary values
        return list(zip(*result.items()))

    def to_json(self, **kwargs: Any) -> Any:
        """Convert to json serializable format."""
        # Maps each compound item to the dict of its children;
        # per_parent[None] holds the top-level result dictionary.
        per_parent: Dict[Optional[ReportItem], Dict] = defaultdict(dict)
        result = per_parent[None]

        def collect_as_dicts(
            item: ReportItem,
            parent: Optional[ReportItem],
            _prev: Optional[ReportItem],
            _level: int,
        ) -> None:
            """Collect item values as nested dictionaries."""
            parent_dict = per_parent[parent]

            if item.compound:
                # Link the (possibly still empty) child dict into the parent;
                # it is filled in as the traversal visits the children.
                item_dict = per_parent[item]
                parent_dict[item.alias] = item_dict
            else:
                out_dis = (
                    item.value.to_json()
                    if isinstance(item.value, Cell)
                    else item.raw_value
                )
                parent_dict[item.alias] = out_dis

        self._traverse(self.items, collect_as_dicts)

        return {self.alias: result}

    def to_plain_text(self, **kwargs: Any) -> str:
        """Convert to human readable format."""
        header = f"{self.name}:\n"
        processed_items = []

        def convert_to_text(
            item: ReportItem,
            _parent: Optional[ReportItem],
            prev: Optional[ReportItem],
            level: int,
        ) -> None:
            """Convert item to text representation."""
            # Blank line visually separates compound items from neighbours.
            if level >= 1 and prev is not None and (item.compound or prev.compound):
                processed_items.append("")

            val = self._item_value(item, level)
            processed_items.append(val)

        self._traverse(self.items, convert_to_text)
        body = "\n".join(processed_items)

        return header + body

    @staticmethod
    def _item_value(
        item: ReportItem, level: int, tab_size: int = 2, column_width: int = 35
    ) -> str:
        """Format one item as a "name  value" line with level-based indent."""
        shift = " " * tab_size * level
        if item.value is None:
            return f"{shift}{item.name}:"

        col1 = f"{shift}{item.name}".ljust(column_width)
        col2 = f"{item.value}".rjust(column_width)

        return col1 + col2

    def _traverse(
        self,
        items: List[ReportItem],
        visit_item: Callable[
            [ReportItem, Optional[ReportItem], Optional[ReportItem], int], None
        ],
        level: int = 1,
        parent: Optional[ReportItem] = None,
    ) -> None:
        """Depth-first traversal calling visit_item(item, parent, prev, level)."""
        prev = None
        for item in items:
            visit_item(item, parent, prev, level)

            self._traverse(item.nested_items, visit_item, level + 1, item)
            prev = item
+
+
class Table(Report):
    """Table definition.

    This class could be used for representing tabular data.
    """

    def __init__(
        self,
        columns: List[Column],
        rows: Collection,
        name: str,
        alias: Optional[str] = None,
        notes: Optional[str] = None,
    ) -> None:
        """Init table definition.

        :param columns: list of the table's columns
        :param rows: list of the table's rows
        :param name: name of the table
        :param alias: alias for the table
        :param notes: optional text appended after the table in plain text output
        """
        self.columns = columns
        self.rows = rows
        self.name = name
        self.alias = alias
        self.notes = notes

    def to_json(self, **kwargs: Any) -> Iterable:
        """Convert table to a json serializable object.

        Only columns that support the "json" format are included; nested
        Table values are serialized recursively.
        """

        def item_to_json(item: Any) -> Any:
            value = item
            if isinstance(item, Cell):
                value = item.value

            if isinstance(value, Table):
                return value.to_json()

            return value

        json_data = [
            {
                col.alias or col.header: item_to_json(item)
                for (item, col) in zip(row, self.columns)
                if col.supports_format("json")
            }
            for row in self.rows
        ]

        if not self.alias:
            return json_data

        return {self.alias: json_data}

    def to_plain_text(self, **kwargs: Any) -> str:
        """Produce report in human readable format.

        Recognized keyword arguments:
          - nested: render as an embedded table (no title, no headers)
          - show_headers: include column headers (ignored when nested)
          - show_title: include the table name as a title
          - table_style: style name passed to produce_table
          - space: True/"top"/"bottom"/"between" controls blank-line placement
        """
        nested = kwargs.get("nested", False)
        show_headers = kwargs.get("show_headers", True)
        show_title = kwargs.get("show_title", True)
        table_style = kwargs.get("table_style", "default")
        space = kwargs.get("space", False)

        headers = (
            [] if (nested or not show_headers) else [c.header for c in self.columns]
        )

        def item_to_plain_text(item: Any, col: Column) -> str:
            """Convert item to text."""
            if isinstance(item, Table):
                return item.to_plain_text(nested=True, **kwargs)

            if is_list_of(item, str):
                as_text = "\n".join(item)
            else:
                as_text = str(item)

            # Wrap long values if the column defines a wrap width.
            if col.fmt and col.fmt.wrap_width:
                as_text = fill(as_text, col.fmt.wrap_width)

            return as_text

        title = ""
        if show_title and not nested:
            title = f"{self.name}:\n"

        if space in (True, "top"):
            title = "\n" + title

        footer = ""
        if space in (True, "bottom"):
            footer = "\n"
        if self.notes:
            footer = "\n" + self.notes

        formatted_rows = (
            (
                item_to_plain_text(item, col)
                for item, col in zip(row, self.columns)
                if col.supports_format("plain_text")
            )
            for row in self.rows
        )

        if space == "between":
            # Each row becomes its own small table separated by blank lines.
            formatted_table = "\n\n".join(
                produce_table([row], table_style=table_style) for row in formatted_rows
            )
        else:
            formatted_table = produce_table(
                formatted_rows,
                headers=headers,
                table_style="nested" if nested else table_style,
            )

        return title + formatted_table + footer

    def to_csv(self, **kwargs: Any) -> List[Any]:
        """Convert table to csv format.

        Returns the header row followed by the data rows. Nested Table
        values are flattened into a single ";"-separated cell.
        """
        headers = [[c.header for c in self.columns if c.supports_format("csv")]]

        def item_data(item: Any) -> Any:
            if isinstance(item, Cell):
                return item.value

            if isinstance(item, Table):
                return ";".join(
                    str(item_data(cell)) for row in item.rows for cell in row
                )

            return item

        rows = [
            [
                item_data(item)
                for (item, col) in zip(row, self.columns)
                if col.supports_format("csv")
            ]
            for row in self.rows
        ]

        return headers + rows
+
+
class SingleRow(Table):
    """Table that renders its single row as a key/value listing."""

    def to_plain_text(self, **kwargs: Any) -> str:
        """Produce report in human readable format."""
        if len(self.rows) != 1:
            raise Exception("Table should have only one row")

        lines = []
        for row in self.rows:
            for item, column in zip(row, self.columns):
                if not column.supports_format("plain_text"):
                    continue
                lines.append(column.header.ljust(35) + str(item).rjust(25))

        return "\n".join([f"{self.name}:", indent("\n".join(lines), " ")])
+
+
class CompoundReport(Report):
    """Compound report.

    This class could be used for producing multiple reports at once.
    """

    def __init__(self, reports: List[Report]) -> None:
        """Init compound report instance.

        :param reports: reports to be combined
        """
        self.reports = reports

    def to_json(self, **kwargs: Any) -> Any:
        """Convert to json serializable format.

        Method attempts to create compound dictionary based on provided
        parts.
        """
        result: Dict[str, Any] = {}
        for item in self.reports:
            result.update(item.to_json(**kwargs))

        return result

    def to_csv(self, **kwargs: Any) -> List[Any]:
        """Convert to csv serializable format.

        CSV format does support only one table. In order to be able to export
        multiple tables they should be merged before that. This method tries to
        do next:

        - if all tables have the same length then just concatenate them
        - if one table has many rows and other just one (two with headers), then
          for each row in table with many rows duplicate values from other tables
        """
        csv_data = [item.to_csv() for item in self.reports]
        lengths = [len(csv_item_data) for csv_item_data in csv_data]

        same_length = len(set(lengths)) == 1
        if same_length:
            # all lists are of the same length, merge them into one
            return [[cell for item in row for cell in item] for row in zip(*csv_data)]

        # A "main" table has more than two rows (headers + several data rows);
        # a "reference" table has exactly two (headers + one data row).
        main_obj_indexes = [i for i, item in enumerate(csv_data) if len(item) > 2]
        one_main_obj = len(main_obj_indexes) == 1

        reference_obj_indexes = [i for i, item in enumerate(csv_data) if len(item) == 2]
        other_only_ref_objs = len(reference_obj_indexes) == len(csv_data) - 1

        if one_main_obj and other_only_ref_objs:
            # Append the reference tables' headers to the main header row
            # (i == 0) and their single data row to every main data row.
            main_obj = csv_data[main_obj_indexes[0]]
            return [
                item
                + [
                    ref_item
                    for ref_table_index in reference_obj_indexes
                    for ref_item in csv_data[ref_table_index][0 if i == 0 else 1]
                ]
                for i, item in enumerate(main_obj)
            ]

        # write tables one after another if there is no other options
        return [row for item in csv_data for row in item]

    def to_plain_text(self, **kwargs: Any) -> str:
        """Convert to human readable format."""
        return "\n".join(item.to_plain_text(**kwargs) for item in self.reports)
+
+
class CompoundFormatter:
    """Formatter that combines several formatters into one."""

    def __init__(self, formatters: List[Callable]) -> None:
        """Init compound formatter.

        :param formatters: formatters applied pairwise to the data items
        """
        self.formatters = formatters

    def __call__(self, data: Any) -> Report:
        """Format each data item and combine results into a CompoundReport."""
        paired = zip(data, self.formatters)
        reports = [formatter(item) for item, formatter in paired]
        return CompoundReport(reports)
+
+
class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that also understands numpy scalar types."""

    def default(self, o: Any) -> Any:
        """Serialize numpy integers and floats as plain Python numbers."""
        if isinstance(o, np.integer):
            return int(o)

        if isinstance(o, np.floating):
            return float(o)

        # Fall back to the default behaviour (raises TypeError).
        return super().default(o)
+
+
def json_reporter(report: Report, output: FileLike, **kwargs: Any) -> None:
    """Produce report in json format."""
    serialized = report.to_json(**kwargs)
    print(json.dumps(serialized, indent=4, cls=CustomJSONEncoder), file=output)
+
+
def text_reporter(report: Report, output: FileLike, **kwargs: Any) -> None:
    """Produce report in text format."""
    text = report.to_plain_text(**kwargs)
    print(text, file=output)
+
+
def csv_reporter(report: Report, output: FileLike, **kwargs: Any) -> None:
    """Produce report in csv format."""
    rows = report.to_csv(**kwargs)
    csv.writer(output).writerows(rows)
+
+
def produce_report(
    data: Any,
    formatter: Callable[[Any], Report],
    fmt: OutputFormat = "plain_text",
    output: Optional[PathOrFileLike] = None,
    **kwargs: Any,
) -> None:
    """Produce report based on provided data.

    :param data: raw data the formatter turns into a Report
    :param formatter: callable converting the data into a Report instance
    :param fmt: output format, one of "json", "plain_text" or "csv"
    :param output: path or file-like destination; defaults to the logger
    :raises ValueError: if fmt is not one of the supported formats
    """
    # check if provided format value is supported
    formats = {"json": json_reporter, "plain_text": text_reporter, "csv": csv_reporter}
    if fmt not in formats:
        raise ValueError(f"Unknown format {fmt}")

    if output is None:
        output = cast(TextIOWrapper, LoggerWriter(logger, logging.INFO))

    with ExitStack() as exit_stack:
        if isinstance(output, (str, Path)):
            # Open the file and register it with the ExitStack context
            # manager so it is closed automatically. csv.writer requires
            # files opened with newline="" to avoid spurious blank rows
            # on platforms that translate line endings.
            newline = "" if fmt == "csv" else None
            stream = exit_stack.enter_context(
                open(output, "w", encoding="utf-8", newline=newline)
            )
        else:
            stream = cast(TextIOWrapper, output)

        # convert data into serializable form
        formatted_data = formatter(data)
        # find handler for the format
        format_handler = formats[fmt]
        # produce report in requested format
        format_handler(formatted_data, stream, **kwargs)
+
+
class Reporter:
    """Reporter class.

    Accumulates submitted data items and produces a combined report in
    the configured output format.
    """

    def __init__(
        self,
        formatter_resolver: Callable[[Any], Callable[[Any], Report]],
        output_format: OutputFormat = "plain_text",
        print_as_submitted: bool = True,
    ) -> None:
        """Init reporter instance.

        :param formatter_resolver: returns a formatter for a given data item
        :param output_format: format used by generate_report
        :param print_as_submitted: if True, echo each item in plain text
            at submission time
        """
        self.formatter_resolver = formatter_resolver
        self.output_format = output_format
        self.print_as_submitted = print_as_submitted

        # (data_item, formatter) pairs collected for the final report.
        self.data: List[Tuple[Any, Callable[[Any], Report]]] = []
        # Subset of the collected pairs whose plain text printing was deferred.
        self.delayed: List[Tuple[Any, Callable[[Any], Report]]] = []

    def submit(self, data_item: Any, delay_print: bool = False, **kwargs: Any) -> None:
        """Submit data for the report.

        :param data_item: data to include in the report
        :param delay_print: if True, postpone plain text printing until
            print_delayed is called
        """
        if self.print_as_submitted and not delay_print:
            produce_report(
                data_item,
                self.formatter_resolver(data_item),
                fmt="plain_text",
                **kwargs,
            )

        # Bind the submission-time kwargs to the formatter so the final
        # report is rendered with the same parameters.
        formatter = _apply_format_parameters(
            self.formatter_resolver(data_item), self.output_format, **kwargs
        )
        self.data.append((data_item, formatter))

        if delay_print:
            self.delayed.append((data_item, formatter))

    def print_delayed(self) -> None:
        """Print delayed reports."""
        if not self.delayed:
            return

        data, formatters = zip(*self.delayed)
        produce_report(
            data,
            formatter=CompoundFormatter(formatters),
            fmt="plain_text",
        )
        self.delayed = []

    def generate_report(self, output: Optional[PathOrFileLike]) -> None:
        """Generate report for all submitted data.

        Skipped when every item was already echoed in plain text and no
        explicit output destination was requested.
        """
        already_printed = (
            self.print_as_submitted
            and self.output_format == "plain_text"
            and output is None
        )
        if not self.data or already_printed:
            return

        data, formatters = zip(*self.data)
        produce_report(
            data,
            formatter=CompoundFormatter(formatters),
            fmt=self.output_format,
            output=output,
        )
+
+
@contextmanager
def get_reporter(
    output_format: OutputFormat,
    output: Optional[PathOrFileLike],
    formatter_resolver: Callable[[Any], Callable[[Any], Report]],
) -> Generator[Reporter, None, None]:
    """Get reporter and generate report.

    Yields a Reporter; when the with-block completes, the collected data
    is written to ``output`` in ``output_format``.

    NOTE(review): if the with-block raises, generate_report is skipped
    because the yield is not wrapped in try/finally — confirm this is
    the intended behaviour.
    """
    reporter = Reporter(formatter_resolver, output_format)

    yield reporter

    reporter.generate_report(output)
+
+
def _apply_format_parameters(
    formatter: Callable[[Any], Report], output_format: OutputFormat, **kwargs: Any
) -> Callable[[Any], Report]:
    """Bind extra keyword arguments to the report's to_<format> method."""

    def wrapper(data: Any) -> Report:
        """Format data, then patch to_<format> so it includes kwargs."""
        report = formatter(data)
        attr_name = f"to_{output_format}"
        original_method = getattr(report, attr_name)
        setattr(report, attr_name, partial(original_method, **kwargs))

        return report

    return wrapper
diff --git a/src/mlia/core/workflow.py b/src/mlia/core/workflow.py
new file mode 100644
index 0000000..0245087
--- /dev/null
+++ b/src/mlia/core/workflow.py
@@ -0,0 +1,216 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for executors.
+
+This module contains implementation of the workflow
+executors.
+"""
+import itertools
+from abc import ABC
+from abc import abstractmethod
+from functools import wraps
+from typing import Any
+from typing import Callable
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+
+from mlia.core.advice_generation import Advice
+from mlia.core.advice_generation import AdviceEvent
+from mlia.core.advice_generation import AdviceProducer
+from mlia.core.common import DataItem
+from mlia.core.context import Context
+from mlia.core.data_analysis import DataAnalyzer
+from mlia.core.data_collection import DataCollector
+from mlia.core.errors import FunctionalityNotSupportedError
+from mlia.core.events import AdviceStageFinishedEvent
+from mlia.core.events import AdviceStageStartedEvent
+from mlia.core.events import AnalyzedDataEvent
+from mlia.core.events import CollectedDataEvent
+from mlia.core.events import DataAnalysisStageFinishedEvent
+from mlia.core.events import DataAnalysisStageStartedEvent
+from mlia.core.events import DataCollectionStageFinishedEvent
+from mlia.core.events import DataCollectionStageStartedEvent
+from mlia.core.events import DataCollectorSkippedEvent
+from mlia.core.events import Event
+from mlia.core.events import ExecutionFailedEvent
+from mlia.core.events import ExecutionFinishedEvent
+from mlia.core.events import ExecutionStartedEvent
+from mlia.core.events import stage
+from mlia.core.mixins import ContextMixin
+
+
class WorkflowExecutor(ABC):
    """Interface for the workflow executors."""

    @abstractmethod
    def run(self) -> None:
        """Execute the workflow."""
+
+
# Start/finish event pairs that bracket each workflow stage
# (consumed by the on_stage decorator below).
STAGE_COLLECTION = (
    DataCollectionStageStartedEvent(),
    DataCollectionStageFinishedEvent(),
)
STAGE_ANALYSIS = (DataAnalysisStageStartedEvent(), DataAnalysisStageFinishedEvent())
STAGE_ADVICE = (AdviceStageStartedEvent(), AdviceStageFinishedEvent())
+
+
def on_stage(stage_events: Tuple[Event, Event]) -> Callable:
    """Decorator factory marking a method as a workflow stage.

    The given start/finish events are published around the wrapped call.
    """

    def wrapper(method: Callable) -> Callable:
        """Wrap the stage method."""

        @wraps(method)
        def publish_events(self: Any, *args: Any, **kwargs: Any) -> Any:
            """Publish events before and after execution."""
            publisher = self.context.event_publisher
            with stage(publisher, stage_events):
                return method(self, *args, **kwargs)

        return publish_events

    return wrapper
+
+
class DefaultWorkflowExecutor(WorkflowExecutor):
    """Default module executor.

    This is a default implementation of the workflow executor.
    All components are launched sequentially in the same process.
    """

    def __init__(
        self,
        context: Context,
        collectors: Sequence[DataCollector],
        analyzers: Sequence[DataAnalyzer],
        producers: Sequence[AdviceProducer],
        before_start_events: Optional[Sequence[Event]] = None,
    ):
        """Init default workflow executor.

        :param context: Context instance
        :param collectors: List of the data collectors
        :param analyzers: List of the data analyzers
        :param producers: List of the advice producers
        :param before_start_events: Optional list of the custom events that
            should be published before start of the workflow execution.
        """
        self.context = context
        self.collectors = collectors
        self.analyzers = analyzers
        self.producers = producers
        self.before_start_events = before_start_events

    def run(self) -> None:
        """Run the workflow.

        Injects the context into components, then runs the
        collect -> analyze -> advise pipeline, publishing lifecycle
        events around the execution.
        """
        self.inject_context()
        self.context.register_event_handlers()

        try:
            self.publish(ExecutionStartedEvent())

            self.before_start()

            collected_data = self.collect_data()
            analyzed_data = self.analyze_data(collected_data)

            self.produce_advice(analyzed_data)
        except Exception as err:  # pylint: disable=broad-except
            # Failures are reported through the event mechanism rather
            # than propagated to the caller.
            self.publish(ExecutionFailedEvent(err))
        else:
            self.publish(ExecutionFinishedEvent())

    def before_start(self) -> None:
        """Run actions before start of the workflow execution."""
        events = self.before_start_events or []
        for event in events:
            self.publish(event)

    @on_stage(STAGE_COLLECTION)
    def collect_data(self) -> List[DataItem]:
        """Collect data.

        Run each of data collector components and return list of
        the collected data items. Collectors that raise
        FunctionalityNotSupportedError are skipped and reported via a
        DataCollectorSkippedEvent.
        """
        collected_data = []
        for collector in self.collectors:
            try:
                # Collectors may return None to signal "nothing collected".
                if (data_item := collector.collect_data()) is not None:
                    collected_data.append(data_item)
                    self.publish(CollectedDataEvent(data_item))
            except FunctionalityNotSupportedError as err:
                self.publish(DataCollectorSkippedEvent(collector.name(), str(err)))

        return collected_data

    @on_stage(STAGE_ANALYSIS)
    def analyze_data(self, collected_data: List[DataItem]) -> List[DataItem]:
        """Analyze data.

        Pass each collected data item into each data analyzer and
        return analyzed data.

        :param collected_data: list of collected data items
        """
        analyzed_data = []
        for analyzer in self.analyzers:
            # Feed every collected item to the analyzer first, then drain
            # its accumulated results.
            for item in collected_data:
                analyzer.analyze_data(item)

            for data_item in analyzer.get_analyzed_data():
                analyzed_data.append(data_item)

                self.publish(AnalyzedDataEvent(data_item))
        return analyzed_data

    @on_stage(STAGE_ADVICE)
    def produce_advice(self, analyzed_data: List[DataItem]) -> None:
        """Produce advice.

        Pass each analyzed data item into each advice producer and
        publish generated advice.

        :param analyzed_data: list of analyzed data items
        """
        for producer in self.producers:
            # Feed every analyzed item, then collect the advice once.
            for data_item in analyzed_data:
                producer.produce_advice(data_item)

            advice = producer.get_advice()
            # Producers may return a single Advice or a collection of them.
            if isinstance(advice, Advice):
                advice = [advice]

            for item in advice:
                self.publish(AdviceEvent(item))

    def inject_context(self) -> None:
        """Inject context object into components.

        Inject context object into components that support context
        injection.
        """
        context_aware_components = (
            comp
            for comp in itertools.chain(
                self.collectors,
                self.analyzers,
                self.producers,
            )
            if isinstance(comp, ContextMixin)
        )

        for component in context_aware_components:
            component.set_context(self.context)

    def publish(self, event: Event) -> None:
        """Publish event.

        Helper method for event publishing.

        :param event: event instance
        """
        self.context.event_publisher.publish_event(event)