diff options
author | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2022-07-21 14:06:03 +0100 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2022-08-19 10:23:23 +0100 |
commit | a8ee1aee3e674c78a77801d1bf2256881ab6b4b9 (patch) | |
tree | 8463b24ba0446a49b3e012477b0834c3b5415b86 /src/mlia/devices/ethosu/handlers.py | |
parent | 76ec769ad8f8ed53ec3ff829fdd34d53db8229fd (diff) | |
download | mlia-a8ee1aee3e674c78a77801d1bf2256881ab6b4b9.tar.gz |
MLIA-549 Refactor API module to support several target profiles
- Move target specific details out of API module
- Move common logic for workflow event handler into a
separate class
Change-Id: Ic4a22657b722af1c1fead1d478f606ac57325788
Diffstat (limited to 'src/mlia/devices/ethosu/handlers.py')
-rw-r--r-- | src/mlia/devices/ethosu/handlers.py | 100 |
1 files changed, 4 insertions, 96 deletions
diff --git a/src/mlia/devices/ethosu/handlers.py b/src/mlia/devices/ethosu/handlers.py index 7a0c31c..ee0b809 100644 --- a/src/mlia/devices/ethosu/handlers.py +++ b/src/mlia/devices/ethosu/handlers.py @@ -2,101 +2,27 @@ # SPDX-License-Identifier: Apache-2.0 """Event handler.""" import logging -from pathlib import Path -from typing import Dict -from typing import List from typing import Optional -from mlia.core._typing import OutputFormat from mlia.core._typing import PathOrFileLike -from mlia.core.advice_generation import Advice -from mlia.core.advice_generation import AdviceEvent -from mlia.core.events import AdviceStageFinishedEvent -from mlia.core.events import AdviceStageStartedEvent from mlia.core.events import CollectedDataEvent -from mlia.core.events import DataAnalysisStageFinishedEvent -from mlia.core.events import DataCollectionStageStartedEvent -from mlia.core.events import DataCollectorSkippedEvent -from mlia.core.events import ExecutionFailedEvent -from mlia.core.events import ExecutionStartedEvent -from mlia.core.events import SystemEventsHandler -from mlia.core.reporting import Reporter +from mlia.core.handlers import WorkflowEventsHandler from mlia.devices.ethosu.events import EthosUAdvisorEventHandler from mlia.devices.ethosu.events import EthosUAdvisorStartedEvent from mlia.devices.ethosu.performance import OptimizationPerformanceMetrics from mlia.devices.ethosu.performance import PerformanceMetrics -from mlia.devices.ethosu.reporters import find_appropriate_formatter +from mlia.devices.ethosu.reporters import ethos_u_formatters from mlia.tools.vela_wrapper import Operators -from mlia.utils.console import create_section_header logger = logging.getLogger(__name__) -ADV_EXECUTION_STARTED = create_section_header("ML Inference Advisor started") -MODEL_ANALYSIS_MSG = create_section_header("Model Analysis") -MODEL_ANALYSIS_RESULTS_MSG = create_section_header("Model Analysis Results") -ADV_GENERATION_MSG = create_section_header("Advice Generation") -REPORT_GENERATION_MSG = create_section_header("Report Generation") - - -class WorkflowEventsHandler(SystemEventsHandler): - """Event handler for the system events.""" - - def on_execution_started(self, event: ExecutionStartedEvent) -> None: - """Handle ExecutionStarted event.""" - logger.info(ADV_EXECUTION_STARTED) - - def on_execution_failed(self, event: ExecutionFailedEvent) -> None: - """Handle ExecutionFailed event.""" - raise event.err - - def on_data_collection_stage_started( - self, event: DataCollectionStageStartedEvent - ) -> None: - """Handle DataCollectionStageStarted event.""" - logger.info(MODEL_ANALYSIS_MSG) - - def on_advice_stage_started(self, event: AdviceStageStartedEvent) -> None: - """Handle AdviceStageStarted event.""" - logger.info(ADV_GENERATION_MSG) - - def on_data_collector_skipped(self, event: DataCollectorSkippedEvent) -> None: - """Handle DataCollectorSkipped event.""" - logger.info("Skipped: %s", event.reason) - class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler): """CLI event handler.""" def __init__(self, output: Optional[PathOrFileLike] = None) -> None: """Init event handler.""" - output_format = self.resolve_output_format(output) - - self.reporter = Reporter(find_appropriate_formatter, output_format) - self.output = output - self.advice: List[Advice] = [] - - def on_advice_stage_finished(self, event: AdviceStageFinishedEvent) -> None: - """Handle AdviceStageFinishedEvent event.""" - self.reporter.submit( - self.advice, - show_title=False, - show_headers=False, - space="between", - table_style="no_borders", - ) - - self.reporter.generate_report(self.output) - - if self.output is not None: - logger.info(REPORT_GENERATION_MSG) - logger.info("Report(s) and advice list saved to: %s", self.output) - - def on_data_analysis_stage_finished( - self, event: DataAnalysisStageFinishedEvent - ) -> None: - """Handle DataAnalysisStageFinished event.""" - logger.info(MODEL_ANALYSIS_RESULTS_MSG) - self.reporter.print_delayed() + super().__init__(ethos_u_formatters, output) def on_collected_data(self, event: CollectedDataEvent) -> None: """Handle CollectedDataEvent event.""" @@ -106,7 +32,7 @@ class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler): self.reporter.submit([data_item.ops, data_item], delay_print=True) if isinstance(data_item, PerformanceMetrics): - self.reporter.submit(data_item, delay_print=True) + self.reporter.submit(data_item, delay_print=True, space=True) if isinstance(data_item, OptimizationPerformanceMetrics): original_metrics = data_item.original_perf_metrics @@ -123,24 +49,6 @@ class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler): space=True, ) - def on_advice_event(self, event: AdviceEvent) -> None: - """Handle Advice event.""" - self.advice.append(event.advice) - def on_ethos_u_advisor_started(self, event: EthosUAdvisorStartedEvent) -> None: """Handle EthosUAdvisorStarted event.""" self.reporter.submit(event.device) - - @staticmethod - def resolve_output_format(output: Optional[PathOrFileLike]) -> OutputFormat: - """Resolve output format based on the output name.""" - output_format: OutputFormat = "plain_text" - - if isinstance(output, str): - output_path = Path(output) - output_formats: Dict[str, OutputFormat] = {".csv": "csv", ".json": "json"} - - if (suffix := output_path.suffix) in output_formats: - return output_formats[suffix] - - return output_format |