aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/devices/ethosu/handlers.py
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2022-07-21 14:06:03 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2022-08-19 10:23:23 +0100
commita8ee1aee3e674c78a77801d1bf2256881ab6b4b9 (patch)
tree8463b24ba0446a49b3e012477b0834c3b5415b86 /src/mlia/devices/ethosu/handlers.py
parent76ec769ad8f8ed53ec3ff829fdd34d53db8229fd (diff)
downloadmlia-a8ee1aee3e674c78a77801d1bf2256881ab6b4b9.tar.gz
MLIA-549 Refactor API module to support several target profiles
- Move target specific details out of API module - Move common logic for workflow event handler into a separate class Change-Id: Ic4a22657b722af1c1fead1d478f606ac57325788
Diffstat (limited to 'src/mlia/devices/ethosu/handlers.py')
-rw-r--r--src/mlia/devices/ethosu/handlers.py100
1 files changed, 4 insertions, 96 deletions
diff --git a/src/mlia/devices/ethosu/handlers.py b/src/mlia/devices/ethosu/handlers.py
index 7a0c31c..ee0b809 100644
--- a/src/mlia/devices/ethosu/handlers.py
+++ b/src/mlia/devices/ethosu/handlers.py
@@ -2,101 +2,27 @@
# SPDX-License-Identifier: Apache-2.0
"""Event handler."""
import logging
-from pathlib import Path
-from typing import Dict
-from typing import List
from typing import Optional
-from mlia.core._typing import OutputFormat
from mlia.core._typing import PathOrFileLike
-from mlia.core.advice_generation import Advice
-from mlia.core.advice_generation import AdviceEvent
-from mlia.core.events import AdviceStageFinishedEvent
-from mlia.core.events import AdviceStageStartedEvent
from mlia.core.events import CollectedDataEvent
-from mlia.core.events import DataAnalysisStageFinishedEvent
-from mlia.core.events import DataCollectionStageStartedEvent
-from mlia.core.events import DataCollectorSkippedEvent
-from mlia.core.events import ExecutionFailedEvent
-from mlia.core.events import ExecutionStartedEvent
-from mlia.core.events import SystemEventsHandler
-from mlia.core.reporting import Reporter
+from mlia.core.handlers import WorkflowEventsHandler
from mlia.devices.ethosu.events import EthosUAdvisorEventHandler
from mlia.devices.ethosu.events import EthosUAdvisorStartedEvent
from mlia.devices.ethosu.performance import OptimizationPerformanceMetrics
from mlia.devices.ethosu.performance import PerformanceMetrics
-from mlia.devices.ethosu.reporters import find_appropriate_formatter
+from mlia.devices.ethosu.reporters import ethos_u_formatters
from mlia.tools.vela_wrapper import Operators
-from mlia.utils.console import create_section_header
logger = logging.getLogger(__name__)
-ADV_EXECUTION_STARTED = create_section_header("ML Inference Advisor started")
-MODEL_ANALYSIS_MSG = create_section_header("Model Analysis")
-MODEL_ANALYSIS_RESULTS_MSG = create_section_header("Model Analysis Results")
-ADV_GENERATION_MSG = create_section_header("Advice Generation")
-REPORT_GENERATION_MSG = create_section_header("Report Generation")
-
-
-class WorkflowEventsHandler(SystemEventsHandler):
- """Event handler for the system events."""
-
- def on_execution_started(self, event: ExecutionStartedEvent) -> None:
- """Handle ExecutionStarted event."""
- logger.info(ADV_EXECUTION_STARTED)
-
- def on_execution_failed(self, event: ExecutionFailedEvent) -> None:
- """Handle ExecutionFailed event."""
- raise event.err
-
- def on_data_collection_stage_started(
- self, event: DataCollectionStageStartedEvent
- ) -> None:
- """Handle DataCollectionStageStarted event."""
- logger.info(MODEL_ANALYSIS_MSG)
-
- def on_advice_stage_started(self, event: AdviceStageStartedEvent) -> None:
- """Handle AdviceStageStarted event."""
- logger.info(ADV_GENERATION_MSG)
-
- def on_data_collector_skipped(self, event: DataCollectorSkippedEvent) -> None:
- """Handle DataCollectorSkipped event."""
- logger.info("Skipped: %s", event.reason)
-
class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler):
"""CLI event handler."""
def __init__(self, output: Optional[PathOrFileLike] = None) -> None:
"""Init event handler."""
- output_format = self.resolve_output_format(output)
-
- self.reporter = Reporter(find_appropriate_formatter, output_format)
- self.output = output
- self.advice: List[Advice] = []
-
- def on_advice_stage_finished(self, event: AdviceStageFinishedEvent) -> None:
- """Handle AdviceStageFinishedEvent event."""
- self.reporter.submit(
- self.advice,
- show_title=False,
- show_headers=False,
- space="between",
- table_style="no_borders",
- )
-
- self.reporter.generate_report(self.output)
-
- if self.output is not None:
- logger.info(REPORT_GENERATION_MSG)
- logger.info("Report(s) and advice list saved to: %s", self.output)
-
- def on_data_analysis_stage_finished(
- self, event: DataAnalysisStageFinishedEvent
- ) -> None:
- """Handle DataAnalysisStageFinished event."""
- logger.info(MODEL_ANALYSIS_RESULTS_MSG)
- self.reporter.print_delayed()
+ super().__init__(ethos_u_formatters, output)
def on_collected_data(self, event: CollectedDataEvent) -> None:
"""Handle CollectedDataEvent event."""
@@ -106,7 +32,7 @@ class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler):
self.reporter.submit([data_item.ops, data_item], delay_print=True)
if isinstance(data_item, PerformanceMetrics):
- self.reporter.submit(data_item, delay_print=True)
+ self.reporter.submit(data_item, delay_print=True, space=True)
if isinstance(data_item, OptimizationPerformanceMetrics):
original_metrics = data_item.original_perf_metrics
@@ -123,24 +49,6 @@ class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler):
space=True,
)
- def on_advice_event(self, event: AdviceEvent) -> None:
- """Handle Advice event."""
- self.advice.append(event.advice)
-
def on_ethos_u_advisor_started(self, event: EthosUAdvisorStartedEvent) -> None:
"""Handle EthosUAdvisorStarted event."""
self.reporter.submit(event.device)
-
- @staticmethod
- def resolve_output_format(output: Optional[PathOrFileLike]) -> OutputFormat:
- """Resolve output format based on the output name."""
- output_format: OutputFormat = "plain_text"
-
- if isinstance(output, str):
- output_path = Path(output)
- output_formats: Dict[str, OutputFormat] = {".csv": "csv", ".json": "json"}
-
- if (suffix := output_path.suffix) in output_formats:
- return output_formats[suffix]
-
- return output_format