diff options
Diffstat (limited to 'src/mlia/core/handlers.py')
-rw-r--r-- | src/mlia/core/handlers.py | 31 |
1 files changed, 12 insertions, 19 deletions
diff --git a/src/mlia/core/handlers.py b/src/mlia/core/handlers.py index 6e50934..24d4881 100644 --- a/src/mlia/core/handlers.py +++ b/src/mlia/core/handlers.py @@ -9,7 +9,6 @@ from typing import Callable from mlia.core.advice_generation import Advice from mlia.core.advice_generation import AdviceEvent -from mlia.core.common import FormattedFilePath from mlia.core.events import ActionFinishedEvent from mlia.core.events import ActionStartedEvent from mlia.core.events import AdviceStageFinishedEvent @@ -25,9 +24,11 @@ from mlia.core.events import EventDispatcher from mlia.core.events import ExecutionFailedEvent from mlia.core.events import ExecutionFinishedEvent from mlia.core.events import ExecutionStartedEvent +from mlia.core.mixins import ContextMixin +from mlia.core.reporting import JSONReporter from mlia.core.reporting import Report from mlia.core.reporting import Reporter -from mlia.core.typing import PathOrFileLike +from mlia.core.reporting import TextReporter from mlia.utils.console import create_section_header @@ -92,26 +93,27 @@ _ADV_EXECUTION_STARTED = create_section_header("ML Inference Advisor started") _MODEL_ANALYSIS_MSG = create_section_header("Model Analysis") _MODEL_ANALYSIS_RESULTS_MSG = create_section_header("Model Analysis Results") _ADV_GENERATION_MSG = create_section_header("Advice Generation") -_REPORT_GENERATION_MSG = create_section_header("Report Generation") -class WorkflowEventsHandler(SystemEventsHandler): +class WorkflowEventsHandler(SystemEventsHandler, ContextMixin): """Event handler for the system events.""" + reporter: Reporter + def __init__( self, formatter_resolver: Callable[[Any], Callable[[Any], Report]], - output: FormattedFilePath | None = None, ) -> None: """Init event handler.""" - output_format = output.fmt if output else "plain_text" - self.reporter = Reporter(formatter_resolver, output_format) - self.output = output.path if output else None - + self.formatter_resolver = formatter_resolver self.advice: list[Advice] = [] def on_execution_started(self, event: ExecutionStartedEvent) -> None: """Handle ExecutionStarted event.""" + if self.context.output_format == "json": + self.reporter = JSONReporter(self.formatter_resolver) + else: + self.reporter = TextReporter(self.formatter_resolver) logger.info(_ADV_EXECUTION_STARTED) def on_execution_failed(self, event: ExecutionFailedEvent) -> None: @@ -132,12 +134,6 @@ class WorkflowEventsHandler(SystemEventsHandler): """Handle DataCollectorSkipped event.""" logger.info("Skipped: %s", event.reason) - @staticmethod - def report_generated(output: PathOrFileLike) -> None: - """Log report generation.""" - logger.info(_REPORT_GENERATION_MSG) - logger.info("Report(s) and advice list saved to: %s", output) - def on_data_analysis_stage_finished( self, event: DataAnalysisStageFinishedEvent ) -> None: @@ -160,7 +156,4 @@ class WorkflowEventsHandler(SystemEventsHandler): table_style="no_borders", ) - self.reporter.generate_report(self.output) - - if self.output is not None: - self.report_generated(self.output) + self.reporter.generate_report() |