From a8ee1aee3e674c78a77801d1bf2256881ab6b4b9 Mon Sep 17 00:00:00 2001 From: Dmitrii Agibov Date: Thu, 21 Jul 2022 14:06:03 +0100 Subject: MLIA-549 Refactor API module to support several target profiles - Move target specific details out of API module - Move common logic for workflow event handler into a separate class Change-Id: Ic4a22657b722af1c1fead1d478f606ac57325788 --- src/mlia/api.py | 108 +++++++-------------- src/mlia/core/advisor.py | 64 +++++++++++++ src/mlia/core/events.py | 54 ----------- src/mlia/core/handlers.py | 166 +++++++++++++++++++++++++++++++++ src/mlia/core/reporting.py | 11 +++ src/mlia/core/workflow.py | 8 +- src/mlia/devices/ethosu/advisor.py | 159 +++++++++++++++++++------------ src/mlia/devices/ethosu/handlers.py | 100 +------------------- src/mlia/devices/ethosu/reporters.py | 4 +- src/mlia/utils/filesystem.py | 24 +++-- tests/test_api.py | 17 +++- tests/test_core_events.py | 2 +- tests/test_core_reporting.py | 22 +++++ tests/test_devices_ethosu_reporters.py | 4 +- 14 files changed, 439 insertions(+), 304 deletions(-) create mode 100644 src/mlia/core/handlers.py diff --git a/src/mlia/api.py b/src/mlia/api.py index 0f950db..024bc98 100644 --- a/src/mlia/api.py +++ b/src/mlia/api.py @@ -14,28 +14,13 @@ from mlia.core._typing import PathOrFileLike from mlia.core.advisor import InferenceAdvisor from mlia.core.common import AdviceCategory from mlia.core.context import ExecutionContext -from mlia.core.events import EventHandler -from mlia.devices.ethosu.advisor import EthosUInferenceAdvisor -from mlia.devices.ethosu.handlers import EthosUEventHandler +from mlia.devices.ethosu.advisor import configure_and_get_ethosu_advisor +from mlia.utils.filesystem import get_target logger = logging.getLogger(__name__) -_DEFAULT_OPTIMIZATION_TARGETS = [ - { - "optimization_type": "pruning", - "optimization_target": 0.5, - "layers_to_optimize": None, - }, - { - "optimization_type": "clustering", - "optimization_target": 32, - "layers_to_optimize": None, - }, -] - - def get_advice( target_profile: str, model: Union[Path, str], @@ -71,7 +56,6 @@ def get_advice( :param backends: A list of backends that should be used for the given target. Default settings will be used if None. - Examples: NB: Before launching MLIA, the logging functionality should be configured! @@ -87,75 +71,51 @@ def get_advice( """ advice_category = AdviceCategory.from_string(category) - config_parameters = _get_config_parameters( - model, target_profile, backends, optimization_targets - ) - event_handlers = _get_event_handlers(output) if context is not None: context.advice_category = advice_category - if context.config_parameters is None: - context.config_parameters = config_parameters - - if context.event_handlers is None: - context.event_handlers = event_handlers - if context is None: context = ExecutionContext( advice_category=advice_category, working_dir=working_dir, - config_parameters=config_parameters, - event_handlers=event_handlers, ) - advisor = _get_advisor(target_profile) - advisor.run(context) - - -def _get_advisor(target: Optional[str]) -> InferenceAdvisor: - """Find appropriate advisor for the target.""" - if not target: - raise Exception("Target is not provided") + advisor = get_advisor( + context, + target_profile, + model, + output, + optimization_targets=optimization_targets, + backends=backends, + ) - return EthosUInferenceAdvisor() + advisor.run(context) -def _get_config_parameters( - model: Union[Path, str], +def get_advisor( + context: ExecutionContext, target_profile: str, - backends: Optional[List[str]], - optimization_targets: Optional[List[Dict[str, Any]]], -) -> Dict[str, Any]: - """Get configuration parameters for the advisor.""" - advisor_parameters: Dict[str, Any] = { - "ethos_u_inference_advisor": { - "model": model, - "device": { - "target_profile": target_profile, - }, - }, + model: Union[Path, str], + output: Optional[PathOrFileLike] = None, + **extra_args: Any, +) -> InferenceAdvisor: + """Find appropriate advisor for the target.""" + target_factories = { + "ethos-u55": configure_and_get_ethosu_advisor, + "ethos-u65": configure_and_get_ethosu_advisor, } - # Specifying backends is optional (default is used) - if backends is not None: - advisor_parameters["ethos_u_inference_advisor"]["backends"] = backends - - if not optimization_targets: - optimization_targets = _DEFAULT_OPTIMIZATION_TARGETS - - advisor_parameters.update( - { - "ethos_u_model_optimizations": { - "optimizations": [ - optimization_targets, - ], - }, - } - ) - - return advisor_parameters - -def _get_event_handlers(output: Optional[PathOrFileLike]) -> List[EventHandler]: - """Return list of the event handlers.""" - return [EthosUEventHandler(output)] + try: + target = get_target(target_profile) + factory_function = target_factories[target] + except KeyError as err: + raise Exception(f"Unsupported profile {target_profile}") from err + + return factory_function( + context, + target_profile, + model, + output, + **extra_args, + ) diff --git a/src/mlia/core/advisor.py b/src/mlia/core/advisor.py index 868d0c7..13689fa 100644 --- a/src/mlia/core/advisor.py +++ b/src/mlia/core/advisor.py @@ -2,9 +2,18 @@ # SPDX-License-Identifier: Apache-2.0 """Inference advisor module.""" from abc import abstractmethod +from pathlib import Path +from typing import cast +from typing import List +from mlia.core.advice_generation import AdviceProducer from mlia.core.common import NamedEntity from mlia.core.context import Context +from mlia.core.data_analysis import DataAnalyzer +from mlia.core.data_collection import DataCollector +from mlia.core.events import Event +from mlia.core.mixins import ParameterResolverMixin +from mlia.core.workflow import DefaultWorkflowExecutor from mlia.core.workflow import WorkflowExecutor @@ -19,3 +28,58 @@ class InferenceAdvisor(NamedEntity): """Run inference advisor.""" executor = self.configure(context) executor.run() + + +class DefaultInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin): + """Default implementation for the advisor.""" + + def configure(self, context: Context) -> WorkflowExecutor: + """Configure advisor.""" + return DefaultWorkflowExecutor( + context, + self.get_collectors(context), + self.get_analyzers(context), + self.get_producers(context), + self.get_events(context), + ) + + @abstractmethod + def get_collectors(self, context: Context) -> List[DataCollector]: + """Return list of the data collectors.""" + + @abstractmethod + def get_analyzers(self, context: Context) -> List[DataAnalyzer]: + """Return list of the data analyzers.""" + + @abstractmethod + def get_producers(self, context: Context) -> List[AdviceProducer]: + """Return list of the advice producers.""" + + @abstractmethod + def get_events(self, context: Context) -> List[Event]: + """Return list of the startup events.""" + + def get_string_parameter(self, context: Context, param: str) -> str: + """Get string parameter value.""" + value = self.get_parameter( + self.name(), + param, + expected_type=str, + context=context, + ) + + return cast(str, value) + + def get_model(self, context: Context) -> Path: + """Get path to the model.""" + model_param = self.get_string_parameter(context, "model") + + model = Path(model_param) + if not model.exists(): + raise Exception(f"Path {model} does not exist") + + return model + + def get_target_profile(self, context: Context) -> str: + """Get target profile.""" + return self.get_string_parameter(context, "target_profile") diff --git a/src/mlia/core/events.py b/src/mlia/core/events.py index 10aec86..0b8461b 100644 --- a/src/mlia/core/events.py +++ b/src/mlia/core/events.py @@ -399,57 +399,3 @@ def action( publisher.publish_event(action_started) yield publisher.publish_event(action_finished) - - -class SystemEventsHandler(EventDispatcher): - """System events handler.""" - - def on_execution_started(self, event: ExecutionStartedEvent) -> None: - """Handle ExecutionStarted event.""" - - def on_execution_finished(self, event: ExecutionFinishedEvent) -> None: - """Handle ExecutionFinished event.""" - - def on_execution_failed(self, event: ExecutionFailedEvent) -> None: - """Handle ExecutionFailed event.""" - - def on_data_collection_stage_started( - self, event: DataCollectionStageStartedEvent - ) -> None: - """Handle DataCollectionStageStarted event.""" - - def on_data_collection_stage_finished( - self, event: DataCollectionStageFinishedEvent - ) -> None: - """Handle DataCollectionStageFinished event.""" - - def on_data_collector_skipped(self, event: DataCollectorSkippedEvent) -> None: - """Handle DataCollectorSkipped event.""" - - def on_data_analysis_stage_started( - self, event: DataAnalysisStageStartedEvent - ) -> None: - """Handle DataAnalysisStageStartedEvent event.""" - - def on_data_analysis_stage_finished( - self, event: DataAnalysisStageFinishedEvent - ) -> None: - """Handle DataAnalysisStageFinishedEvent event.""" - - def on_advice_stage_started(self, event: AdviceStageStartedEvent) -> None: - """Handle AdviceStageStarted event.""" - - def on_advice_stage_finished(self, event: AdviceStageFinishedEvent) -> None: - """Handle AdviceStageFinished event.""" - - def on_collected_data(self, event: CollectedDataEvent) -> None: - """Handle CollectedData event.""" - - def on_analyzed_data(self, event: AnalyzedDataEvent) -> None: - """Handle AnalyzedData event.""" - - def on_action_started(self, event: ActionStartedEvent) -> None: - """Handle ActionStarted event.""" - - def on_action_finished(self, event: ActionFinishedEvent) -> None: - """Handle ActionFinished event.""" diff --git a/src/mlia/core/handlers.py b/src/mlia/core/handlers.py new file mode 100644 index 0000000..e576f74 --- /dev/null +++ b/src/mlia/core/handlers.py @@ -0,0 +1,166 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Event handlers module.""" +import logging +from typing import Any +from typing import Callable +from typing import List +from typing import Optional + +from mlia.core._typing import PathOrFileLike +from mlia.core.advice_generation import Advice +from mlia.core.advice_generation import AdviceEvent +from mlia.core.events import ActionFinishedEvent +from mlia.core.events import ActionStartedEvent +from mlia.core.events import AdviceStageFinishedEvent +from mlia.core.events import AdviceStageStartedEvent +from mlia.core.events import AnalyzedDataEvent +from mlia.core.events import CollectedDataEvent +from mlia.core.events import DataAnalysisStageFinishedEvent +from mlia.core.events import DataAnalysisStageStartedEvent +from mlia.core.events import DataCollectionStageFinishedEvent +from mlia.core.events import DataCollectionStageStartedEvent +from mlia.core.events import DataCollectorSkippedEvent +from mlia.core.events import EventDispatcher +from mlia.core.events import ExecutionFailedEvent +from mlia.core.events import ExecutionFinishedEvent +from mlia.core.events import ExecutionStartedEvent +from mlia.core.reporting import Report +from mlia.core.reporting import Reporter +from mlia.core.reporting import resolve_output_format +from mlia.utils.console import create_section_header + + +logger = logging.getLogger(__name__) + + +class SystemEventsHandler(EventDispatcher): + """System events handler.""" + + def on_execution_started(self, event: ExecutionStartedEvent) -> None: + """Handle ExecutionStarted event.""" + + def on_execution_finished(self, event: ExecutionFinishedEvent) -> None: + """Handle ExecutionFinished event.""" + + def on_execution_failed(self, event: ExecutionFailedEvent) -> None: + """Handle ExecutionFailed event.""" + + def on_data_collection_stage_started( + self, event: DataCollectionStageStartedEvent + ) -> None: + """Handle DataCollectionStageStarted event.""" + + def on_data_collection_stage_finished( + self, event: DataCollectionStageFinishedEvent + ) -> None: + """Handle DataCollectionStageFinished event.""" + + def on_data_collector_skipped(self, event: DataCollectorSkippedEvent) -> None: + """Handle DataCollectorSkipped event.""" + + def on_data_analysis_stage_started( + self, event: DataAnalysisStageStartedEvent + ) -> None: + """Handle DataAnalysisStageStartedEvent event.""" + + def on_data_analysis_stage_finished( + self, event: DataAnalysisStageFinishedEvent + ) -> None: + """Handle DataAnalysisStageFinishedEvent event.""" + + def on_advice_stage_started(self, event: AdviceStageStartedEvent) -> None: + """Handle AdviceStageStarted event.""" + + def on_advice_stage_finished(self, event: AdviceStageFinishedEvent) -> None: + """Handle AdviceStageFinished event.""" + + def on_collected_data(self, event: CollectedDataEvent) -> None: + """Handle CollectedData event.""" + + def on_analyzed_data(self, event: AnalyzedDataEvent) -> None: + """Handle AnalyzedData event.""" + + def on_action_started(self, event: ActionStartedEvent) -> None: + """Handle ActionStarted event.""" + + def on_action_finished(self, event: ActionFinishedEvent) -> None: + """Handle ActionFinished event.""" + + +_ADV_EXECUTION_STARTED = create_section_header("ML Inference Advisor started") +_MODEL_ANALYSIS_MSG = create_section_header("Model Analysis") +_MODEL_ANALYSIS_RESULTS_MSG = create_section_header("Model Analysis Results") +_ADV_GENERATION_MSG = create_section_header("Advice Generation") +_REPORT_GENERATION_MSG = create_section_header("Report Generation") + + +class WorkflowEventsHandler(SystemEventsHandler): + """Event handler for the system events.""" + + def __init__( + self, + formatter_resolver: Callable[[Any], Callable[[Any], Report]], + output: Optional[PathOrFileLike] = None, + ) -> None: + """Init event handler.""" + output_format = resolve_output_format(output) + self.reporter = Reporter(formatter_resolver, output_format) + self.output = output + + self.advice: List[Advice] = [] + + def on_execution_started(self, event: ExecutionStartedEvent) -> None: + """Handle ExecutionStarted event.""" + logger.info(_ADV_EXECUTION_STARTED) + + def on_execution_failed(self, event: ExecutionFailedEvent) -> None: + """Handle ExecutionFailed event.""" + raise event.err + + def on_data_collection_stage_started( + self, event: DataCollectionStageStartedEvent + ) -> None: + """Handle DataCollectionStageStarted event.""" + logger.info(_MODEL_ANALYSIS_MSG) + + def on_advice_stage_started(self, event: AdviceStageStartedEvent) -> None: + """Handle AdviceStageStarted event.""" + logger.info(_ADV_GENERATION_MSG) + + def on_data_collector_skipped(self, event: DataCollectorSkippedEvent) -> None: + """Handle DataCollectorSkipped event.""" + logger.info("Skipped: %s", event.reason) + + @staticmethod + def report_generated(output: PathOrFileLike) -> None: + """Log report generation.""" + logger.info(_REPORT_GENERATION_MSG) + logger.info("Report(s) and advice list saved to: %s", output) + + def on_data_analysis_stage_finished( + self, event: DataAnalysisStageFinishedEvent + ) -> None: + """Handle DataAnalysisStageFinished event.""" + logger.info(_MODEL_ANALYSIS_RESULTS_MSG) + + self.reporter.print_delayed() + + def on_advice_event(self, event: AdviceEvent) -> None: + """Handle Advice event.""" + self.advice.append(event.advice) + + def on_advice_stage_finished(self, event: AdviceStageFinishedEvent) -> None: + """Handle AdviceStageFinishedEvent event.""" + self.reporter.submit( + self.advice, + show_title=False, + show_headers=False, + space="between", + table_style="no_borders", + ) + + self.reporter.generate_report(self.output) + + if self.output is not None: + self.report_generated(self.output) diff --git a/src/mlia/core/reporting.py b/src/mlia/core/reporting.py index 9006602..58a41d3 100644 --- a/src/mlia/core/reporting.py +++ b/src/mlia/core/reporting.py @@ -760,3 +760,14 @@ def _apply_format_parameters( return report return wrapper + + +def resolve_output_format(output: Optional[PathOrFileLike]) -> OutputFormat: + """Resolve output format based on the output name.""" + if isinstance(output, (str, Path)): + format_from_filename = Path(output).suffix.lstrip(".") + + if format_from_filename in ["json", "csv"]: + return cast(OutputFormat, format_from_filename) + + return "plain_text" diff --git a/src/mlia/core/workflow.py b/src/mlia/core/workflow.py index 0245087..03f3d1c 100644 --- a/src/mlia/core/workflow.py +++ b/src/mlia/core/workflow.py @@ -87,7 +87,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor): collectors: Sequence[DataCollector], analyzers: Sequence[DataAnalyzer], producers: Sequence[AdviceProducer], - before_start_events: Optional[Sequence[Event]] = None, + startup_events: Optional[Sequence[Event]] = None, ): """Init default workflow executor. @@ -95,14 +95,14 @@ class DefaultWorkflowExecutor(WorkflowExecutor): :param collectors: List of the data collectors :param analyzers: List of the data analyzers :param producers: List of the advice producers - :param before_start_events: Optional list of the custom events that + :param startup_events: Optional list of the custom events that should be published before start of the worfkow execution. """ self.context = context self.collectors = collectors self.analyzers = analyzers self.producers = producers - self.before_start_events = before_start_events + self.startup_events = startup_events def run(self) -> None: """Run the workflow.""" @@ -125,7 +125,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor): def before_start(self) -> None: """Run actions before start of the workflow execution.""" - events = self.before_start_events or [] + events = self.startup_events or [] for event in events: self.publish(event) diff --git a/src/mlia/devices/ethosu/advisor.py b/src/mlia/devices/ethosu/advisor.py index e93858f..b7b8305 100644 --- a/src/mlia/devices/ethosu/advisor.py +++ b/src/mlia/devices/ethosu/advisor.py @@ -2,18 +2,22 @@ # SPDX-License-Identifier: Apache-2.0 """Ethos-U MLIA module.""" from pathlib import Path +from typing import Any +from typing import Dict from typing import List from typing import Optional +from typing import Union +from mlia.core._typing import PathOrFileLike from mlia.core.advice_generation import AdviceProducer +from mlia.core.advisor import DefaultInferenceAdvisor from mlia.core.advisor import InferenceAdvisor from mlia.core.common import AdviceCategory from mlia.core.context import Context +from mlia.core.context import ExecutionContext from mlia.core.data_analysis import DataAnalyzer from mlia.core.data_collection import DataCollector -from mlia.core.mixins import ParameterResolverMixin -from mlia.core.workflow import DefaultWorkflowExecutor -from mlia.core.workflow import WorkflowExecutor +from mlia.core.events import Event from mlia.devices.ethosu.advice_generation import EthosUAdviceProducer from mlia.devices.ethosu.advice_generation import EthosUStaticAdviceProducer from mlia.devices.ethosu.config import EthosUConfiguration @@ -23,10 +27,12 @@ from mlia.devices.ethosu.data_collection import EthosUOperatorCompatibility from mlia.devices.ethosu.data_collection import EthosUOptimizationPerformance from mlia.devices.ethosu.data_collection import EthosUPerformance from mlia.devices.ethosu.events import EthosUAdvisorStartedEvent +from mlia.devices.ethosu.handlers import EthosUEventHandler from mlia.nn.tensorflow.utils import is_tflite_model +from mlia.utils.types import is_list_of -class EthosUInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin): +class EthosUInferenceAdvisor(DefaultInferenceAdvisor): """Ethos-U Inference Advisor.""" @classmethod @@ -34,34 +40,12 @@ class EthosUInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin): """Return name of the advisor.""" return "ethos_u_inference_advisor" - def configure(self, context: Context) -> WorkflowExecutor: - """Configure advisor execution.""" - model = self._get_model(context) + def get_collectors(self, context: Context) -> List[DataCollector]: + """Return list of the data collectors.""" + model = self.get_model(context) device = self._get_device(context) backends = self._get_backends(context) - collectors = self._get_collectors(context, model, device, backends) - analyzers = self._get_analyzers() - producers = self._get_advice_producers() - - return DefaultWorkflowExecutor( - context, - collectors, - analyzers, - producers, - before_start_events=[ - EthosUAdvisorStartedEvent(device=device, model=model), - ], - ) - - def _get_collectors( - self, - context: Context, - model: Path, - device: EthosUConfiguration, - backends: Optional[List[str]], - ) -> List[DataCollector]: - """Get collectors.""" collectors: List[DataCollector] = [] if AdviceCategory.OPERATORS in context.advice_category: @@ -91,51 +75,34 @@ class EthosUInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin): return collectors - @staticmethod - def _get_analyzers() -> List[DataAnalyzer]: - """Return data analyzers.""" + def get_analyzers(self, context: Context) -> List[DataAnalyzer]: + """Return list of the data analyzers.""" return [ EthosUDataAnalyzer(), ] - @staticmethod - def _get_advice_producers() -> List[AdviceProducer]: - """Return advice producers.""" + def get_producers(self, context: Context) -> List[AdviceProducer]: + """Return list of the advice producers.""" return [ EthosUAdviceProducer(), EthosUStaticAdviceProducer(), ] + def get_events(self, context: Context) -> List[Event]: + """Return list of the startup events.""" + model = self.get_model(context) + device = self._get_device(context) + + return [ + EthosUAdvisorStartedEvent(device=device, model=model), + ] + def _get_device(self, context: Context) -> EthosUConfiguration: """Get device.""" - device_params = self.get_parameter( - self.name(), - "device", - expected_type=dict, - context=context, - ) - - try: - target_profile = device_params["target_profile"] - except KeyError as err: - raise Exception("Unable to get device details") from err + target_profile = self.get_target_profile(context) return get_target(target_profile) - def _get_model(self, context: Context) -> Path: - """Get path to the model.""" - model_param = self.get_parameter( - self.name(), - "model", - expected_type=str, - context=context, - ) - - if not (model := Path(model_param)).exists(): - raise Exception(f"Path {model} does not exist") - - return model - def _get_optimization_settings(self, context: Context) -> List[List[dict]]: """Get optimization settings.""" return self.get_parameter( # type: ignore @@ -155,3 +122,75 @@ class EthosUInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin): expected=False, context=context, ) + + +def configure_and_get_ethosu_advisor( + context: ExecutionContext, + target_profile: str, + model: Union[Path, str], + output: Optional[PathOrFileLike] = None, + **extra_args: Any, +) -> InferenceAdvisor: + """Create and configure Ethos-U advisor.""" + if context.event_handlers is None: + context.event_handlers = [EthosUEventHandler(output)] + + if context.config_parameters is None: + context.config_parameters = _get_config_parameters( + model, target_profile, **extra_args + ) + + return EthosUInferenceAdvisor() + + +_DEFAULT_OPTIMIZATION_TARGETS = [ + { + "optimization_type": "pruning", + "optimization_target": 0.5, + "layers_to_optimize": None, + }, + { + "optimization_type": "clustering", + "optimization_target": 32, + "layers_to_optimize": None, + }, +] + + +def _get_config_parameters( + model: Union[Path, str], + target_profile: str, + **extra_args: Any, +) -> Dict[str, Any]: + """Get configuration parameters for the advisor.""" + advisor_parameters: Dict[str, Any] = { + "ethos_u_inference_advisor": { + "model": model, + "target_profile": target_profile, + }, + } + + # Specifying backends is optional (default is used) + backends = extra_args.get("backends") + if backends is not None: + if not is_list_of(backends, str): + raise Exception("Backends value has wrong format") + + advisor_parameters["ethos_u_inference_advisor"]["backends"] = backends + + optimization_targets = extra_args.get("optimization_targets") + if not optimization_targets: + optimization_targets = _DEFAULT_OPTIMIZATION_TARGETS + + if not is_list_of(optimization_targets, dict): + raise Exception("Optimization targets value has wrong format") + + advisor_parameters.update( + { + "ethos_u_model_optimizations": { + "optimizations": [optimization_targets], + }, + } + ) + + return advisor_parameters diff --git a/src/mlia/devices/ethosu/handlers.py b/src/mlia/devices/ethosu/handlers.py index 7a0c31c..ee0b809 100644 --- a/src/mlia/devices/ethosu/handlers.py +++ b/src/mlia/devices/ethosu/handlers.py @@ -2,101 +2,27 @@ # SPDX-License-Identifier: Apache-2.0 """Event handler.""" import logging -from pathlib import Path -from typing import Dict -from typing import List from typing import Optional -from mlia.core._typing import OutputFormat from mlia.core._typing import PathOrFileLike -from mlia.core.advice_generation import Advice -from mlia.core.advice_generation import AdviceEvent -from mlia.core.events import AdviceStageFinishedEvent -from mlia.core.events import AdviceStageStartedEvent from mlia.core.events import CollectedDataEvent -from mlia.core.events import DataAnalysisStageFinishedEvent -from mlia.core.events import DataCollectionStageStartedEvent -from mlia.core.events import DataCollectorSkippedEvent -from mlia.core.events import ExecutionFailedEvent -from mlia.core.events import ExecutionStartedEvent -from mlia.core.events import SystemEventsHandler -from mlia.core.reporting import Reporter +from mlia.core.handlers import WorkflowEventsHandler from mlia.devices.ethosu.events import EthosUAdvisorEventHandler from mlia.devices.ethosu.events import EthosUAdvisorStartedEvent from mlia.devices.ethosu.performance import OptimizationPerformanceMetrics from mlia.devices.ethosu.performance import PerformanceMetrics -from mlia.devices.ethosu.reporters import find_appropriate_formatter +from mlia.devices.ethosu.reporters import ethos_u_formatters from mlia.tools.vela_wrapper import Operators -from mlia.utils.console import create_section_header logger = logging.getLogger(__name__) -ADV_EXECUTION_STARTED = create_section_header("ML Inference Advisor started") -MODEL_ANALYSIS_MSG = create_section_header("Model Analysis") -MODEL_ANALYSIS_RESULTS_MSG = create_section_header("Model Analysis Results") -ADV_GENERATION_MSG = create_section_header("Advice Generation") -REPORT_GENERATION_MSG = create_section_header("Report Generation") - - -class WorkflowEventsHandler(SystemEventsHandler): - """Event handler for the system events.""" - - def on_execution_started(self, event: ExecutionStartedEvent) -> None: - """Handle ExecutionStarted event.""" - logger.info(ADV_EXECUTION_STARTED) - - def on_execution_failed(self, event: ExecutionFailedEvent) -> None: - """Handle ExecutionFailed event.""" - raise event.err - - def on_data_collection_stage_started( - self, event: DataCollectionStageStartedEvent - ) -> None: - """Handle DataCollectionStageStarted event.""" - logger.info(MODEL_ANALYSIS_MSG) - - def on_advice_stage_started(self, event: AdviceStageStartedEvent) -> None: - """Handle AdviceStageStarted event.""" - logger.info(ADV_GENERATION_MSG) - - def on_data_collector_skipped(self, event: DataCollectorSkippedEvent) -> None: - """Handle DataCollectorSkipped event.""" - logger.info("Skipped: %s", event.reason) - class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler): """CLI event handler.""" def __init__(self, output: Optional[PathOrFileLike] = None) -> None: """Init event handler.""" - output_format = self.resolve_output_format(output) - - self.reporter = Reporter(find_appropriate_formatter, output_format) - self.output = output - self.advice: List[Advice] = [] - - def on_advice_stage_finished(self, event: AdviceStageFinishedEvent) -> None: - """Handle AdviceStageFinishedEvent event.""" - self.reporter.submit( - self.advice, - show_title=False, - show_headers=False, - space="between", - table_style="no_borders", - ) - - self.reporter.generate_report(self.output) - - if self.output is not None: - logger.info(REPORT_GENERATION_MSG) - logger.info("Report(s) and advice list saved to: %s", self.output) - - def on_data_analysis_stage_finished( - self, event: DataAnalysisStageFinishedEvent - ) -> None: - """Handle DataAnalysisStageFinished event.""" - logger.info(MODEL_ANALYSIS_RESULTS_MSG) - self.reporter.print_delayed() + super().__init__(ethos_u_formatters, output) def on_collected_data(self, event: CollectedDataEvent) -> None: """Handle CollectedDataEvent event.""" @@ -106,7 +32,7 @@ class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler): self.reporter.submit([data_item.ops, data_item], delay_print=True) if isinstance(data_item, PerformanceMetrics): - self.reporter.submit(data_item, delay_print=True) + self.reporter.submit(data_item, delay_print=True, space=True) if isinstance(data_item, OptimizationPerformanceMetrics): original_metrics = data_item.original_perf_metrics @@ -123,24 +49,6 @@ class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler): space=True, ) - def on_advice_event(self, event: AdviceEvent) -> None: - """Handle Advice event.""" - self.advice.append(event.advice) - def on_ethos_u_advisor_started(self, event: EthosUAdvisorStartedEvent) -> None: """Handle EthosUAdvisorStarted event.""" self.reporter.submit(event.device) - - @staticmethod - def resolve_output_format(output: Optional[PathOrFileLike]) -> OutputFormat: - """Resolve output format based on the output name.""" - output_format: OutputFormat = "plain_text" - - if isinstance(output, str): - output_path = Path(output) - output_formats: Dict[str, OutputFormat] = {".csv": "csv", ".json": "json"} - - if (suffix := output_path.suffix) in output_formats: - return output_formats[suffix] - - return output_format diff --git a/src/mlia/devices/ethosu/reporters.py b/src/mlia/devices/ethosu/reporters.py index d28c68f..b3aea24 100644 --- a/src/mlia/devices/ethosu/reporters.py +++ b/src/mlia/devices/ethosu/reporters.py @@ -374,7 +374,7 @@ def report_advice(advice: List[Advice]) -> Report: ) -def find_appropriate_formatter(data: Any) -> Callable[[Any], Report]: +def ethos_u_formatters(data: Any) -> Callable[[Any], Report]: """Find appropriate formatter for the provided data.""" if isinstance(data, PerformanceMetrics) or is_list_of(data, PerformanceMetrics, 2): return report_perf_metrics @@ -392,7 +392,7 @@ def find_appropriate_formatter(data: Any) -> Callable[[Any], Report]: return report_device_details if isinstance(data, (list, tuple)): - formatters = [find_appropriate_formatter(item) for item in data] + formatters = [ethos_u_formatters(item) for item in data] return CompoundFormatter(formatters) raise Exception(f"Unable to find appropriate formatter for {data}") diff --git a/src/mlia/utils/filesystem.py b/src/mlia/utils/filesystem.py index 7975905..0c28d35 100644 --- a/src/mlia/utils/filesystem.py +++ b/src/mlia/utils/filesystem.py @@ -11,6 +11,7 @@ from pathlib import Path from tempfile import mkstemp from tempfile import TemporaryDirectory from typing import Any +from typing import cast from typing import Dict from typing import Generator from typing import Iterable @@ -32,12 +33,12 @@ def get_vela_config() -> Path: def get_profiles_file() -> Path: - """Get the Ethos-U profiles file.""" + """Get the profiles file.""" return get_mlia_resources() / "profiles.json" def get_profiles_data() -> Dict[str, Dict[str, Any]]: - """Get the Ethos-U profile values as a dictionary.""" + """Get the profile values as a dictionary.""" with open(get_profiles_file(), encoding="utf-8") as json_file: profiles = json.load(json_file) @@ -47,14 +48,17 @@ def get_profiles_data() -> Dict[str, Dict[str, Any]]: return profiles -def get_profile(target: str) -> Dict[str, Any]: +def get_profile(target_profile: str) -> Dict[str, Any]: """Get settings for the provided target profile.""" - profiles = get_profiles_data() + if not target_profile: + raise Exception("Target profile is not provided") - if target not in profiles: - raise Exception(f"Unable to find target profile {target}") + profiles = get_profiles_data() - return profiles[target] + try: + return profiles[target_profile] + except KeyError as err: + raise Exception(f"Unable to find target profile {target_profile}") from err def get_supported_profile_names() -> List[str]: @@ -62,6 +66,12 @@ def get_supported_profile_names() -> List[str]: return list(get_profiles_data().keys()) +def get_target(target_profile: str) -> str: + """Return target for the provided target_profile.""" + profile_data = get_profile(target_profile) + return cast(str, profile_data["target"]) + + @contextmanager def temp_file(suffix: Optional[str] = None) -> Generator[Path, None, None]: """Create temp file and remove it after.""" diff --git a/tests/test_api.py b/tests/test_api.py index 09bc509..e8df7af 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -7,14 +7,16 @@ from unittest.mock import MagicMock import pytest from mlia.api import get_advice +from mlia.api import get_advisor from mlia.core.common import AdviceCategory from mlia.core.context import Context from mlia.core.context import ExecutionContext +from mlia.devices.ethosu.advisor import EthosUInferenceAdvisor def test_get_advice_no_target_provided(test_keras_model: Path) -> None: """Test getting advice when no target provided.""" - with pytest.raises(Exception, match="Target is not provided"): + with pytest.raises(Exception, match="Target profile is not provided"): get_advice(None, test_keras_model, "all") # type: ignore @@ -78,7 +80,7 @@ def test_get_advice( ) -> None: """Test getting advice with valid parameters.""" advisor_mock = MagicMock() - monkeypatch.setattr("mlia.api._get_advisor", MagicMock(return_value=advisor_mock)) + monkeypatch.setattr("mlia.api.get_advisor", MagicMock(return_value=advisor_mock)) get_advice( "ethos-u55-256", @@ -92,5 +94,12 @@ def test_get_advice( assert isinstance(context, Context) assert context.advice_category == expected_category - assert context.event_handlers is not None - assert context.config_parameters is not None + +def test_get_advisor( + test_keras_model: Path, +) -> None: + """Test function for getting the advisor.""" + ethos_u55_advisor = get_advisor( + ExecutionContext(), "ethos-u55-256", str(test_keras_model) + ) + assert isinstance(ethos_u55_advisor, EthosUInferenceAdvisor) diff --git a/tests/test_core_events.py b/tests/test_core_events.py index faaab7c..a531bab 100644 --- a/tests/test_core_events.py +++ b/tests/test_core_events.py @@ -18,7 +18,7 @@ from mlia.core.events import EventHandler from mlia.core.events import ExecutionFinishedEvent from mlia.core.events import ExecutionStartedEvent from mlia.core.events import stage -from mlia.core.events import SystemEventsHandler +from mlia.core.handlers import SystemEventsHandler @dataclass diff --git a/tests/test_core_reporting.py b/tests/test_core_reporting.py index 2f7ec22..d7a6ade 100644 --- a/tests/test_core_reporting.py +++ b/tests/test_core_reporting.py @@ -2,9 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 """Tests for reporting module.""" from typing import List +from typing import Optional import pytest +from mlia.core._typing import OutputFormat +from mlia.core._typing import PathOrFileLike from mlia.core.reporting import BytesCell from mlia.core.reporting import Cell from mlia.core.reporting import ClockCell @@ -13,6 +16,7 @@ from mlia.core.reporting import CyclesCell from mlia.core.reporting import Format from mlia.core.reporting import NestedReport from mlia.core.reporting import ReportItem +from mlia.core.reporting import resolve_output_format from mlia.core.reporting import SingleRow from mlia.core.reporting import Table from mlia.utils.console import remove_ascii_codes @@ -411,3 +415,21 @@ Single row example: alias="simple_row_example", ) wrong_single_row.to_plain_text() + + +@pytest.mark.parametrize( + "output, expected_output_format", + [ + [None, "plain_text"], + ["", "plain_text"], + ["some_file", "plain_text"], + ["some_format.some_ext", "plain_text"], + ["output.csv", "csv"], + ["output.json", "json"], + ], +) +def test_resolve_output_format( + output: Optional[PathOrFileLike], expected_output_format: OutputFormat +) -> None: + """Test function resolve_output_format.""" + assert resolve_output_format(output) == expected_output_format diff --git a/tests/test_devices_ethosu_reporters.py b/tests/test_devices_ethosu_reporters.py index 0da50e0..a63db1c 100644 --- a/tests/test_devices_ethosu_reporters.py +++ b/tests/test_devices_ethosu_reporters.py @@ -22,7 +22,7 @@ from mlia.devices.ethosu.config import EthosUConfiguration from mlia.devices.ethosu.performance import MemoryUsage from mlia.devices.ethosu.performance import NPUCycles from mlia.devices.ethosu.performance import PerformanceMetrics -from mlia.devices.ethosu.reporters import find_appropriate_formatter +from mlia.devices.ethosu.reporters import ethos_u_formatters from mlia.devices.ethosu.reporters import report_device_details from mlia.devices.ethosu.reporters import report_operators from mlia.devices.ethosu.reporters import report_perf_metrics @@ -410,7 +410,7 @@ def test_get_reporter(tmp_path: Path) -> None: ) output = tmp_path / "output.json" - with get_reporter("json", output, find_appropriate_formatter) as reporter: + with get_reporter("json", output, ethos_u_formatters) as reporter: assert isinstance(reporter, Reporter) with pytest.raises( -- cgit v1.2.1