aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/devices/ethosu/handlers.py
blob: 7a0c31cb70c760d37488b219084f6de88a1efaea (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Event handler."""
import logging
from pathlib import Path
from typing import Dict
from typing import List
from typing import Optional

from mlia.core._typing import OutputFormat
from mlia.core._typing import PathOrFileLike
from mlia.core.advice_generation import Advice
from mlia.core.advice_generation import AdviceEvent
from mlia.core.events import AdviceStageFinishedEvent
from mlia.core.events import AdviceStageStartedEvent
from mlia.core.events import CollectedDataEvent
from mlia.core.events import DataAnalysisStageFinishedEvent
from mlia.core.events import DataCollectionStageStartedEvent
from mlia.core.events import DataCollectorSkippedEvent
from mlia.core.events import ExecutionFailedEvent
from mlia.core.events import ExecutionStartedEvent
from mlia.core.events import SystemEventsHandler
from mlia.core.reporting import Reporter
from mlia.devices.ethosu.events import EthosUAdvisorEventHandler
from mlia.devices.ethosu.events import EthosUAdvisorStartedEvent
from mlia.devices.ethosu.performance import OptimizationPerformanceMetrics
from mlia.devices.ethosu.performance import PerformanceMetrics
from mlia.devices.ethosu.reporters import find_appropriate_formatter
from mlia.tools.vela_wrapper import Operators
from mlia.utils.console import create_section_header

# Module-level logger named after this module; all handlers in this file log through it.
logger = logging.getLogger(__name__)

# Section headers, built once at import time, that the event handlers in this
# module log when the corresponding workflow stage is reached.
ADV_EXECUTION_STARTED = create_section_header("ML Inference Advisor started")
MODEL_ANALYSIS_MSG = create_section_header("Model Analysis")
MODEL_ANALYSIS_RESULTS_MSG = create_section_header("Model Analysis Results")
ADV_GENERATION_MSG = create_section_header("Advice Generation")
REPORT_GENERATION_MSG = create_section_header("Report Generation")


class WorkflowEventsHandler(SystemEventsHandler):
    """Handler that reports workflow progress through the module logger."""

    def on_execution_started(self, event: ExecutionStartedEvent) -> None:
        """Log the section header announcing that the advisor has started."""
        header = ADV_EXECUTION_STARTED
        logger.info(header)

    def on_data_collection_stage_started(
        self, event: DataCollectionStageStartedEvent
    ) -> None:
        """Log the section header for the model analysis stage."""
        header = MODEL_ANALYSIS_MSG
        logger.info(header)

    def on_data_collector_skipped(self, event: DataCollectorSkippedEvent) -> None:
        """Log the reason a data collector was skipped."""
        skip_reason = event.reason
        logger.info("Skipped: %s", skip_reason)

    def on_advice_stage_started(self, event: AdviceStageStartedEvent) -> None:
        """Log the section header for the advice generation stage."""
        header = ADV_GENERATION_MSG
        logger.info(header)

    def on_execution_failed(self, event: ExecutionFailedEvent) -> None:
        """Re-raise the error carried by the event so the failure propagates."""
        failure = event.err
        raise failure


class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler):
    """CLI event handler.

    Forwards collected data items, the device description and the produced
    advice to a ``Reporter`` and asks it to generate the report once the
    advice stage has finished.
    """

    def __init__(self, output: Optional[PathOrFileLike] = None) -> None:
        """Init event handler.

        :param output: destination handed to ``Reporter.generate_report``.
            When it is a ``str`` or ``Path`` its suffix selects the report
            format (see ``resolve_output_format``).
        """
        output_format = self.resolve_output_format(output)

        self.reporter = Reporter(find_appropriate_formatter, output_format)
        self.output = output
        # Advice accumulated by on_advice_event, submitted as a whole later
        self.advice: List[Advice] = []

    def on_advice_stage_finished(self, event: AdviceStageFinishedEvent) -> None:
        """Handle AdviceStageFinishedEvent event.

        Submits the accumulated advice, generates the report for
        ``self.output`` and, when an output was given, logs where it went.
        """
        self.reporter.submit(
            self.advice,
            show_title=False,
            show_headers=False,
            space="between",
            table_style="no_borders",
        )

        self.reporter.generate_report(self.output)

        # Only mention a saved report when a destination was actually given
        if self.output is not None:
            logger.info(REPORT_GENERATION_MSG)
            logger.info("Report(s) and advice list saved to: %s", self.output)

    def on_data_analysis_stage_finished(
        self, event: DataAnalysisStageFinishedEvent
    ) -> None:
        """Handle DataAnalysisStageFinished event.

        Logs the results header and calls ``Reporter.print_delayed``.
        """
        logger.info(MODEL_ANALYSIS_RESULTS_MSG)
        # NOTE(review): presumably prints the items that were submitted with
        # delay_print=True in on_collected_data - confirm against Reporter
        self.reporter.print_delayed()

    def on_collected_data(self, event: CollectedDataEvent) -> None:
        """Handle CollectedDataEvent event.

        Submits the data item to the reporter with ``delay_print=True``;
        what is submitted depends on the type of the data item.
        """
        data_item = event.data_item

        if isinstance(data_item, Operators):
            self.reporter.submit([data_item.ops, data_item], delay_print=True)

        if isinstance(data_item, PerformanceMetrics):
            self.reporter.submit(data_item, delay_print=True)

        if isinstance(data_item, OptimizationPerformanceMetrics):
            original_metrics = data_item.original_perf_metrics
            # Nothing to compare against when no optimization was measured
            if not data_item.optimizations_perf_metrics:
                return

            # Only the first entry is reported; its settings are not used here
            _opt_settings, optimized_metrics = data_item.optimizations_perf_metrics[0]

            self.reporter.submit(
                [original_metrics, optimized_metrics],
                delay_print=True,
                columns_name="Metrics",
                title="Performance metrics",
                space=True,
            )

    def on_advice_event(self, event: AdviceEvent) -> None:
        """Handle Advice event by remembering the advice for the final report."""
        self.advice.append(event.advice)

    def on_ethos_u_advisor_started(self, event: EthosUAdvisorStartedEvent) -> None:
        """Handle EthosUAdvisorStarted event by submitting the device."""
        self.reporter.submit(event.device)

    @staticmethod
    def resolve_output_format(output: Optional[PathOrFileLike]) -> OutputFormat:
        """Resolve output format based on the output name.

        :param output: report destination; may be ``None``, a ``str`` or
            ``Path`` file name, or any other object (e.g. a file-like one)
        :return: ``"csv"`` or ``"json"`` when ``output`` is a ``str`` or
            ``Path`` whose suffix is exactly ``.csv`` or ``.json``,
            otherwise ``"plain_text"``
        """
        output_formats: Dict[str, OutputFormat] = {".csv": "csv", ".json": "json"}

        # A Path names a file just like a str does, so both are inspected;
        # previously a Path destination always fell back to plain text.
        if isinstance(output, (str, Path)):
            return output_formats.get(Path(output).suffix, "plain_text")

        return "plain_text"