From a8ee1aee3e674c78a77801d1bf2256881ab6b4b9 Mon Sep 17 00:00:00 2001 From: Dmitrii Agibov Date: Thu, 21 Jul 2022 14:06:03 +0100 Subject: MLIA-549 Refactor API module to support several target profiles - Move target specific details out of API module - Move common logic for workflow event handler into a separate class Change-Id: Ic4a22657b722af1c1fead1d478f606ac57325788 --- src/mlia/core/advisor.py | 64 +++++++++++++++++ src/mlia/core/events.py | 54 --------------- src/mlia/core/handlers.py | 166 +++++++++++++++++++++++++++++++++++++++++++++ src/mlia/core/reporting.py | 11 +++ src/mlia/core/workflow.py | 8 +-- 5 files changed, 245 insertions(+), 58 deletions(-) create mode 100644 src/mlia/core/handlers.py (limited to 'src/mlia/core') diff --git a/src/mlia/core/advisor.py b/src/mlia/core/advisor.py index 868d0c7..13689fa 100644 --- a/src/mlia/core/advisor.py +++ b/src/mlia/core/advisor.py @@ -2,9 +2,18 @@ # SPDX-License-Identifier: Apache-2.0 """Inference advisor module.""" from abc import abstractmethod +from pathlib import Path +from typing import cast +from typing import List +from mlia.core.advice_generation import AdviceProducer from mlia.core.common import NamedEntity from mlia.core.context import Context +from mlia.core.data_analysis import DataAnalyzer +from mlia.core.data_collection import DataCollector +from mlia.core.events import Event +from mlia.core.mixins import ParameterResolverMixin +from mlia.core.workflow import DefaultWorkflowExecutor from mlia.core.workflow import WorkflowExecutor @@ -19,3 +28,58 @@ class InferenceAdvisor(NamedEntity): """Run inference advisor.""" executor = self.configure(context) executor.run() + + +class DefaultInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin): + """Default implementation for the advisor.""" + + def configure(self, context: Context) -> WorkflowExecutor: + """Configure advisor.""" + return DefaultWorkflowExecutor( + context, + self.get_collectors(context), + self.get_analyzers(context), + self.get_producers(context), + self.get_events(context), + ) + + @abstractmethod + def get_collectors(self, context: Context) -> List[DataCollector]: + """Return list of the data collectors.""" + + @abstractmethod + def get_analyzers(self, context: Context) -> List[DataAnalyzer]: + """Return list of the data analyzers.""" + + @abstractmethod + def get_producers(self, context: Context) -> List[AdviceProducer]: + """Return list of the advice producers.""" + + @abstractmethod + def get_events(self, context: Context) -> List[Event]: + """Return list of the startup events.""" + + def get_string_parameter(self, context: Context, param: str) -> str: + """Get string parameter value.""" + value = self.get_parameter( + self.name(), + param, + expected_type=str, + context=context, + ) + + return cast(str, value) + + def get_model(self, context: Context) -> Path: + """Get path to the model.""" + model_param = self.get_string_parameter(context, "model") + + model = Path(model_param) + if not model.exists(): + raise Exception(f"Path {model} does not exist") + + return model + + def get_target_profile(self, context: Context) -> str: + """Get target profile.""" + return self.get_string_parameter(context, "target_profile") diff --git a/src/mlia/core/events.py b/src/mlia/core/events.py index 10aec86..0b8461b 100644 --- a/src/mlia/core/events.py +++ b/src/mlia/core/events.py @@ -399,57 +399,3 @@ def action( publisher.publish_event(action_started) yield publisher.publish_event(action_finished) - - -class SystemEventsHandler(EventDispatcher): - """System events handler.""" - - def on_execution_started(self, event: ExecutionStartedEvent) -> None: - """Handle ExecutionStarted event.""" - - def on_execution_finished(self, event: ExecutionFinishedEvent) -> None: - """Handle ExecutionFinished event.""" - - def on_execution_failed(self, event: ExecutionFailedEvent) -> None: - """Handle ExecutionFailed event.""" - - def on_data_collection_stage_started( - self, event: DataCollectionStageStartedEvent - ) -> None: - """Handle DataCollectionStageStarted event.""" - - def on_data_collection_stage_finished( - self, event: DataCollectionStageFinishedEvent - ) -> None: - """Handle DataCollectionStageFinished event.""" - - def on_data_collector_skipped(self, event: DataCollectorSkippedEvent) -> None: - """Handle DataCollectorSkipped event.""" - - def on_data_analysis_stage_started( - self, event: DataAnalysisStageStartedEvent - ) -> None: - """Handle DataAnalysisStageStartedEvent event.""" - - def on_data_analysis_stage_finished( - self, event: DataAnalysisStageFinishedEvent - ) -> None: - """Handle DataAnalysisStageFinishedEvent event.""" - - def on_advice_stage_started(self, event: AdviceStageStartedEvent) -> None: - """Handle AdviceStageStarted event.""" - - def on_advice_stage_finished(self, event: AdviceStageFinishedEvent) -> None: - """Handle AdviceStageFinished event.""" - - def on_collected_data(self, event: CollectedDataEvent) -> None: - """Handle CollectedData event.""" - - def on_analyzed_data(self, event: AnalyzedDataEvent) -> None: - """Handle AnalyzedData event.""" - - def on_action_started(self, event: ActionStartedEvent) -> None: - """Handle ActionStarted event.""" - - def on_action_finished(self, event: ActionFinishedEvent) -> None: - """Handle ActionFinished event.""" diff --git a/src/mlia/core/handlers.py b/src/mlia/core/handlers.py new file mode 100644 index 0000000..e576f74 --- /dev/null +++ b/src/mlia/core/handlers.py @@ -0,0 +1,166 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Event handlers module.""" +import logging +from typing import Any +from typing import Callable +from typing import List +from typing import Optional + +from mlia.core._typing import PathOrFileLike +from mlia.core.advice_generation import Advice +from mlia.core.advice_generation import AdviceEvent +from mlia.core.events import ActionFinishedEvent +from mlia.core.events import ActionStartedEvent +from mlia.core.events import AdviceStageFinishedEvent +from mlia.core.events import AdviceStageStartedEvent +from mlia.core.events import AnalyzedDataEvent +from mlia.core.events import CollectedDataEvent +from mlia.core.events import DataAnalysisStageFinishedEvent +from mlia.core.events import DataAnalysisStageStartedEvent +from mlia.core.events import DataCollectionStageFinishedEvent +from mlia.core.events import DataCollectionStageStartedEvent +from mlia.core.events import DataCollectorSkippedEvent +from mlia.core.events import EventDispatcher +from mlia.core.events import ExecutionFailedEvent +from mlia.core.events import ExecutionFinishedEvent +from mlia.core.events import ExecutionStartedEvent +from mlia.core.reporting import Report +from mlia.core.reporting import Reporter +from mlia.core.reporting import resolve_output_format +from mlia.utils.console import create_section_header + + +logger = logging.getLogger(__name__) + + +class SystemEventsHandler(EventDispatcher): + """System events handler.""" + + def on_execution_started(self, event: ExecutionStartedEvent) -> None: + """Handle ExecutionStarted event.""" + + def on_execution_finished(self, event: ExecutionFinishedEvent) -> None: + """Handle ExecutionFinished event.""" + + def on_execution_failed(self, event: ExecutionFailedEvent) -> None: + """Handle ExecutionFailed event.""" + + def on_data_collection_stage_started( + self, event: DataCollectionStageStartedEvent + ) -> None: + """Handle DataCollectionStageStarted event.""" + + def on_data_collection_stage_finished( + self, event: DataCollectionStageFinishedEvent + ) -> None: + """Handle DataCollectionStageFinished event.""" + + def on_data_collector_skipped(self, event: DataCollectorSkippedEvent) -> None: + """Handle DataCollectorSkipped event.""" + + def on_data_analysis_stage_started( + self, event: DataAnalysisStageStartedEvent + ) -> None: + """Handle DataAnalysisStageStartedEvent event.""" + + def on_data_analysis_stage_finished( + self, event: DataAnalysisStageFinishedEvent + ) -> None: + """Handle DataAnalysisStageFinishedEvent event.""" + + def on_advice_stage_started(self, event: AdviceStageStartedEvent) -> None: + """Handle AdviceStageStarted event.""" + + def on_advice_stage_finished(self, event: AdviceStageFinishedEvent) -> None: + """Handle AdviceStageFinished event.""" + + def on_collected_data(self, event: CollectedDataEvent) -> None: + """Handle CollectedData event.""" + + def on_analyzed_data(self, event: AnalyzedDataEvent) -> None: + """Handle AnalyzedData event.""" + + def on_action_started(self, event: ActionStartedEvent) -> None: + """Handle ActionStarted event.""" + + def on_action_finished(self, event: ActionFinishedEvent) -> None: + """Handle ActionFinished event.""" + + +_ADV_EXECUTION_STARTED = create_section_header("ML Inference Advisor started") +_MODEL_ANALYSIS_MSG = create_section_header("Model Analysis") +_MODEL_ANALYSIS_RESULTS_MSG = create_section_header("Model Analysis Results") +_ADV_GENERATION_MSG = create_section_header("Advice Generation") +_REPORT_GENERATION_MSG = create_section_header("Report Generation") + + +class WorkflowEventsHandler(SystemEventsHandler): + """Event handler for the system events.""" + + def __init__( + self, + formatter_resolver: Callable[[Any], Callable[[Any], Report]], + output: Optional[PathOrFileLike] = None, + ) -> None: + """Init event handler.""" + output_format = resolve_output_format(output) + self.reporter = Reporter(formatter_resolver, output_format) + self.output = output + + self.advice: List[Advice] = [] + + def on_execution_started(self, event: ExecutionStartedEvent) -> None: + """Handle ExecutionStarted event.""" + logger.info(_ADV_EXECUTION_STARTED) + + def on_execution_failed(self, event: ExecutionFailedEvent) -> None: + """Handle ExecutionFailed event.""" + raise event.err + + def on_data_collection_stage_started( + self, event: DataCollectionStageStartedEvent + ) -> None: + """Handle DataCollectionStageStarted event.""" + logger.info(_MODEL_ANALYSIS_MSG) + + def on_advice_stage_started(self, event: AdviceStageStartedEvent) -> None: + """Handle AdviceStageStarted event.""" + logger.info(_ADV_GENERATION_MSG) + + def on_data_collector_skipped(self, event: DataCollectorSkippedEvent) -> None: + """Handle DataCollectorSkipped event.""" + logger.info("Skipped: %s", event.reason) + + @staticmethod + def report_generated(output: PathOrFileLike) -> None: + """Log report generation.""" + logger.info(_REPORT_GENERATION_MSG) + logger.info("Report(s) and advice list saved to: %s", output) + + def on_data_analysis_stage_finished( + self, event: DataAnalysisStageFinishedEvent + ) -> None: + """Handle DataAnalysisStageFinished event.""" + logger.info(_MODEL_ANALYSIS_RESULTS_MSG) + + self.reporter.print_delayed() + + def on_advice_event(self, event: AdviceEvent) -> None: + """Handle Advice event.""" + self.advice.append(event.advice) + + def on_advice_stage_finished(self, event: AdviceStageFinishedEvent) -> None: + """Handle AdviceStageFinishedEvent event.""" + self.reporter.submit( + self.advice, + show_title=False, + show_headers=False, + space="between", + table_style="no_borders", + ) + + self.reporter.generate_report(self.output) + + if self.output is not None: + self.report_generated(self.output) diff --git a/src/mlia/core/reporting.py b/src/mlia/core/reporting.py index 9006602..58a41d3 100644 --- a/src/mlia/core/reporting.py +++ b/src/mlia/core/reporting.py @@ -760,3 +760,14 @@ def _apply_format_parameters( return report return wrapper + + +def resolve_output_format(output: Optional[PathOrFileLike]) -> OutputFormat: + """Resolve output format based on the output name.""" + if isinstance(output, (str, Path)): + format_from_filename = Path(output).suffix.lstrip(".") + + if format_from_filename in ["json", "csv"]: + return cast(OutputFormat, format_from_filename) + + return "plain_text" diff --git a/src/mlia/core/workflow.py b/src/mlia/core/workflow.py index 0245087..03f3d1c 100644 --- a/src/mlia/core/workflow.py +++ b/src/mlia/core/workflow.py @@ -87,7 +87,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor): collectors: Sequence[DataCollector], analyzers: Sequence[DataAnalyzer], producers: Sequence[AdviceProducer], - before_start_events: Optional[Sequence[Event]] = None, + startup_events: Optional[Sequence[Event]] = None, ): """Init default workflow executor. @@ -95,14 +95,14 @@ class DefaultWorkflowExecutor(WorkflowExecutor): :param collectors: List of the data collectors :param analyzers: List of the data analyzers :param producers: List of the advice producers - :param before_start_events: Optional list of the custom events that + :param startup_events: Optional list of the custom events that should be published before start of the worfkow execution. """ self.context = context self.collectors = collectors self.analyzers = analyzers self.producers = producers - self.before_start_events = before_start_events + self.startup_events = startup_events def run(self) -> None: """Run the workflow.""" @@ -125,7 +125,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor): def before_start(self) -> None: """Run actions before start of the workflow execution.""" - events = self.before_start_events or [] + events = self.startup_events or [] for event in events: self.publish(event) -- cgit v1.2.1