aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/core
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2022-07-21 14:06:03 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2022-08-19 10:23:23 +0100
commita8ee1aee3e674c78a77801d1bf2256881ab6b4b9 (patch)
tree8463b24ba0446a49b3e012477b0834c3b5415b86 /src/mlia/core
parent76ec769ad8f8ed53ec3ff829fdd34d53db8229fd (diff)
downloadmlia-a8ee1aee3e674c78a77801d1bf2256881ab6b4b9.tar.gz
MLIA-549 Refactor API module to support several target profiles
- Move target specific details out of API module - Move common logic for workflow event handler into a separate class Change-Id: Ic4a22657b722af1c1fead1d478f606ac57325788
Diffstat (limited to 'src/mlia/core')
-rw-r--r--src/mlia/core/advisor.py64
-rw-r--r--src/mlia/core/events.py54
-rw-r--r--src/mlia/core/handlers.py166
-rw-r--r--src/mlia/core/reporting.py11
-rw-r--r--src/mlia/core/workflow.py8
5 files changed, 245 insertions, 58 deletions
diff --git a/src/mlia/core/advisor.py b/src/mlia/core/advisor.py
index 868d0c7..13689fa 100644
--- a/src/mlia/core/advisor.py
+++ b/src/mlia/core/advisor.py
@@ -2,9 +2,18 @@
# SPDX-License-Identifier: Apache-2.0
"""Inference advisor module."""
from abc import abstractmethod
+from pathlib import Path
+from typing import cast
+from typing import List
+from mlia.core.advice_generation import AdviceProducer
from mlia.core.common import NamedEntity
from mlia.core.context import Context
+from mlia.core.data_analysis import DataAnalyzer
+from mlia.core.data_collection import DataCollector
+from mlia.core.events import Event
+from mlia.core.mixins import ParameterResolverMixin
+from mlia.core.workflow import DefaultWorkflowExecutor
from mlia.core.workflow import WorkflowExecutor
@@ -19,3 +28,58 @@ class InferenceAdvisor(NamedEntity):
"""Run inference advisor."""
executor = self.configure(context)
executor.run()
+
+
+class DefaultInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin):
+ """Default implementation for the advisor."""
+
+ def configure(self, context: Context) -> WorkflowExecutor:
+ """Configure advisor."""
+ return DefaultWorkflowExecutor(
+ context,
+ self.get_collectors(context),
+ self.get_analyzers(context),
+ self.get_producers(context),
+ self.get_events(context),
+ )
+
+ @abstractmethod
+ def get_collectors(self, context: Context) -> List[DataCollector]:
+ """Return list of the data collectors."""
+
+ @abstractmethod
+ def get_analyzers(self, context: Context) -> List[DataAnalyzer]:
+ """Return list of the data analyzers."""
+
+ @abstractmethod
+ def get_producers(self, context: Context) -> List[AdviceProducer]:
+ """Return list of the advice producers."""
+
+ @abstractmethod
+ def get_events(self, context: Context) -> List[Event]:
+ """Return list of the startup events."""
+
+ def get_string_parameter(self, context: Context, param: str) -> str:
+ """Get string parameter value."""
+ value = self.get_parameter(
+ self.name(),
+ param,
+ expected_type=str,
+ context=context,
+ )
+
+ return cast(str, value)
+
+ def get_model(self, context: Context) -> Path:
+ """Get path to the model."""
+ model_param = self.get_string_parameter(context, "model")
+
+ model = Path(model_param)
+ if not model.exists():
+ raise Exception(f"Path {model} does not exist")
+
+ return model
+
+ def get_target_profile(self, context: Context) -> str:
+ """Get target profile."""
+ return self.get_string_parameter(context, "target_profile")
diff --git a/src/mlia/core/events.py b/src/mlia/core/events.py
index 10aec86..0b8461b 100644
--- a/src/mlia/core/events.py
+++ b/src/mlia/core/events.py
@@ -399,57 +399,3 @@ def action(
publisher.publish_event(action_started)
yield
publisher.publish_event(action_finished)
-
-
-class SystemEventsHandler(EventDispatcher):
- """System events handler."""
-
- def on_execution_started(self, event: ExecutionStartedEvent) -> None:
- """Handle ExecutionStarted event."""
-
- def on_execution_finished(self, event: ExecutionFinishedEvent) -> None:
- """Handle ExecutionFinished event."""
-
- def on_execution_failed(self, event: ExecutionFailedEvent) -> None:
- """Handle ExecutionFailed event."""
-
- def on_data_collection_stage_started(
- self, event: DataCollectionStageStartedEvent
- ) -> None:
- """Handle DataCollectionStageStarted event."""
-
- def on_data_collection_stage_finished(
- self, event: DataCollectionStageFinishedEvent
- ) -> None:
- """Handle DataCollectionStageFinished event."""
-
- def on_data_collector_skipped(self, event: DataCollectorSkippedEvent) -> None:
- """Handle DataCollectorSkipped event."""
-
- def on_data_analysis_stage_started(
- self, event: DataAnalysisStageStartedEvent
- ) -> None:
- """Handle DataAnalysisStageStartedEvent event."""
-
- def on_data_analysis_stage_finished(
- self, event: DataAnalysisStageFinishedEvent
- ) -> None:
- """Handle DataAnalysisStageFinishedEvent event."""
-
- def on_advice_stage_started(self, event: AdviceStageStartedEvent) -> None:
- """Handle AdviceStageStarted event."""
-
- def on_advice_stage_finished(self, event: AdviceStageFinishedEvent) -> None:
- """Handle AdviceStageFinished event."""
-
- def on_collected_data(self, event: CollectedDataEvent) -> None:
- """Handle CollectedData event."""
-
- def on_analyzed_data(self, event: AnalyzedDataEvent) -> None:
- """Handle AnalyzedData event."""
-
- def on_action_started(self, event: ActionStartedEvent) -> None:
- """Handle ActionStarted event."""
-
- def on_action_finished(self, event: ActionFinishedEvent) -> None:
- """Handle ActionFinished event."""
diff --git a/src/mlia/core/handlers.py b/src/mlia/core/handlers.py
new file mode 100644
index 0000000..e576f74
--- /dev/null
+++ b/src/mlia/core/handlers.py
@@ -0,0 +1,166 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Event handlers module."""
+import logging
+from typing import Any
+from typing import Callable
+from typing import List
+from typing import Optional
+
+from mlia.core._typing import PathOrFileLike
+from mlia.core.advice_generation import Advice
+from mlia.core.advice_generation import AdviceEvent
+from mlia.core.events import ActionFinishedEvent
+from mlia.core.events import ActionStartedEvent
+from mlia.core.events import AdviceStageFinishedEvent
+from mlia.core.events import AdviceStageStartedEvent
+from mlia.core.events import AnalyzedDataEvent
+from mlia.core.events import CollectedDataEvent
+from mlia.core.events import DataAnalysisStageFinishedEvent
+from mlia.core.events import DataAnalysisStageStartedEvent
+from mlia.core.events import DataCollectionStageFinishedEvent
+from mlia.core.events import DataCollectionStageStartedEvent
+from mlia.core.events import DataCollectorSkippedEvent
+from mlia.core.events import EventDispatcher
+from mlia.core.events import ExecutionFailedEvent
+from mlia.core.events import ExecutionFinishedEvent
+from mlia.core.events import ExecutionStartedEvent
+from mlia.core.reporting import Report
+from mlia.core.reporting import Reporter
+from mlia.core.reporting import resolve_output_format
+from mlia.utils.console import create_section_header
+
+
+logger = logging.getLogger(__name__)
+
+
+class SystemEventsHandler(EventDispatcher):
+ """System events handler."""
+
+ def on_execution_started(self, event: ExecutionStartedEvent) -> None:
+ """Handle ExecutionStarted event."""
+
+ def on_execution_finished(self, event: ExecutionFinishedEvent) -> None:
+ """Handle ExecutionFinished event."""
+
+ def on_execution_failed(self, event: ExecutionFailedEvent) -> None:
+ """Handle ExecutionFailed event."""
+
+ def on_data_collection_stage_started(
+ self, event: DataCollectionStageStartedEvent
+ ) -> None:
+ """Handle DataCollectionStageStarted event."""
+
+ def on_data_collection_stage_finished(
+ self, event: DataCollectionStageFinishedEvent
+ ) -> None:
+ """Handle DataCollectionStageFinished event."""
+
+ def on_data_collector_skipped(self, event: DataCollectorSkippedEvent) -> None:
+ """Handle DataCollectorSkipped event."""
+
+ def on_data_analysis_stage_started(
+ self, event: DataAnalysisStageStartedEvent
+ ) -> None:
+ """Handle DataAnalysisStageStartedEvent event."""
+
+ def on_data_analysis_stage_finished(
+ self, event: DataAnalysisStageFinishedEvent
+ ) -> None:
+ """Handle DataAnalysisStageFinishedEvent event."""
+
+ def on_advice_stage_started(self, event: AdviceStageStartedEvent) -> None:
+ """Handle AdviceStageStarted event."""
+
+ def on_advice_stage_finished(self, event: AdviceStageFinishedEvent) -> None:
+ """Handle AdviceStageFinished event."""
+
+ def on_collected_data(self, event: CollectedDataEvent) -> None:
+ """Handle CollectedData event."""
+
+ def on_analyzed_data(self, event: AnalyzedDataEvent) -> None:
+ """Handle AnalyzedData event."""
+
+ def on_action_started(self, event: ActionStartedEvent) -> None:
+ """Handle ActionStarted event."""
+
+ def on_action_finished(self, event: ActionFinishedEvent) -> None:
+ """Handle ActionFinished event."""
+
+
+_ADV_EXECUTION_STARTED = create_section_header("ML Inference Advisor started")
+_MODEL_ANALYSIS_MSG = create_section_header("Model Analysis")
+_MODEL_ANALYSIS_RESULTS_MSG = create_section_header("Model Analysis Results")
+_ADV_GENERATION_MSG = create_section_header("Advice Generation")
+_REPORT_GENERATION_MSG = create_section_header("Report Generation")
+
+
+class WorkflowEventsHandler(SystemEventsHandler):
+ """Event handler for the system events."""
+
+ def __init__(
+ self,
+ formatter_resolver: Callable[[Any], Callable[[Any], Report]],
+ output: Optional[PathOrFileLike] = None,
+ ) -> None:
+ """Init event handler."""
+ output_format = resolve_output_format(output)
+ self.reporter = Reporter(formatter_resolver, output_format)
+ self.output = output
+
+ self.advice: List[Advice] = []
+
+ def on_execution_started(self, event: ExecutionStartedEvent) -> None:
+ """Handle ExecutionStarted event."""
+ logger.info(_ADV_EXECUTION_STARTED)
+
+ def on_execution_failed(self, event: ExecutionFailedEvent) -> None:
+ """Handle ExecutionFailed event."""
+ raise event.err
+
+ def on_data_collection_stage_started(
+ self, event: DataCollectionStageStartedEvent
+ ) -> None:
+ """Handle DataCollectionStageStarted event."""
+ logger.info(_MODEL_ANALYSIS_MSG)
+
+ def on_advice_stage_started(self, event: AdviceStageStartedEvent) -> None:
+ """Handle AdviceStageStarted event."""
+ logger.info(_ADV_GENERATION_MSG)
+
+ def on_data_collector_skipped(self, event: DataCollectorSkippedEvent) -> None:
+ """Handle DataCollectorSkipped event."""
+ logger.info("Skipped: %s", event.reason)
+
+ @staticmethod
+ def report_generated(output: PathOrFileLike) -> None:
+ """Log report generation."""
+ logger.info(_REPORT_GENERATION_MSG)
+ logger.info("Report(s) and advice list saved to: %s", output)
+
+ def on_data_analysis_stage_finished(
+ self, event: DataAnalysisStageFinishedEvent
+ ) -> None:
+ """Handle DataAnalysisStageFinished event."""
+ logger.info(_MODEL_ANALYSIS_RESULTS_MSG)
+
+ self.reporter.print_delayed()
+
+ def on_advice_event(self, event: AdviceEvent) -> None:
+ """Handle Advice event."""
+ self.advice.append(event.advice)
+
+ def on_advice_stage_finished(self, event: AdviceStageFinishedEvent) -> None:
+ """Handle AdviceStageFinishedEvent event."""
+ self.reporter.submit(
+ self.advice,
+ show_title=False,
+ show_headers=False,
+ space="between",
+ table_style="no_borders",
+ )
+
+ self.reporter.generate_report(self.output)
+
+ if self.output is not None:
+ self.report_generated(self.output)
diff --git a/src/mlia/core/reporting.py b/src/mlia/core/reporting.py
index 9006602..58a41d3 100644
--- a/src/mlia/core/reporting.py
+++ b/src/mlia/core/reporting.py
@@ -760,3 +760,14 @@ def _apply_format_parameters(
return report
return wrapper
+
+
+def resolve_output_format(output: Optional[PathOrFileLike]) -> OutputFormat:
+ """Resolve output format based on the output name."""
+ if isinstance(output, (str, Path)):
+ format_from_filename = Path(output).suffix.lstrip(".")
+
+ if format_from_filename in ["json", "csv"]:
+ return cast(OutputFormat, format_from_filename)
+
+ return "plain_text"
diff --git a/src/mlia/core/workflow.py b/src/mlia/core/workflow.py
index 0245087..03f3d1c 100644
--- a/src/mlia/core/workflow.py
+++ b/src/mlia/core/workflow.py
@@ -87,7 +87,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor):
collectors: Sequence[DataCollector],
analyzers: Sequence[DataAnalyzer],
producers: Sequence[AdviceProducer],
- before_start_events: Optional[Sequence[Event]] = None,
+ startup_events: Optional[Sequence[Event]] = None,
):
"""Init default workflow executor.
@@ -95,14 +95,14 @@ class DefaultWorkflowExecutor(WorkflowExecutor):
:param collectors: List of the data collectors
:param analyzers: List of the data analyzers
:param producers: List of the advice producers
- :param before_start_events: Optional list of the custom events that
+ :param startup_events: Optional list of the custom events that
should be published before start of the worfkow execution.
"""
self.context = context
self.collectors = collectors
self.analyzers = analyzers
self.producers = producers
- self.before_start_events = before_start_events
+ self.startup_events = startup_events
def run(self) -> None:
"""Run the workflow."""
@@ -125,7 +125,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor):
def before_start(self) -> None:
"""Run actions before start of the workflow execution."""
- events = self.before_start_events or []
+ events = self.startup_events or []
for event in events:
self.publish(event)