author     Dmitrii Agibov <dmitrii.agibov@arm.com>          2022-07-21 14:06:03 +0100
committer  Benjamin Klimczak <benjamin.klimczak@arm.com>    2022-08-19 10:23:23 +0100
commit     a8ee1aee3e674c78a77801d1bf2256881ab6b4b9 (patch)
tree       8463b24ba0446a49b3e012477b0834c3b5415b86
parent     76ec769ad8f8ed53ec3ff829fdd34d53db8229fd (diff)
download   mlia-a8ee1aee3e674c78a77801d1bf2256881ab6b4b9.tar.gz
MLIA-549 Refactor API module to support several target profiles
- Move target specific details out of API module
- Move common logic for workflow event handler into a separate class

Change-Id: Ic4a22657b722af1c1fead1d478f606ac57325788
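
For orientation, a minimal sketch of how the refactored entry point is driven (the profile and model names below are assumed examples, and logging must be configured by the caller before launching MLIA):

    import logging

    from mlia.api import get_advice

    logging.basicConfig(level=logging.INFO)

    # "ethos-u55-256" is assumed to be a profile defined in profiles.json; the
    # target itself is now resolved internally via get_target() instead of
    # being hard-coded in the API module.
    get_advice("ethos-u55-256", "model.tflite", "all", output="advice.json")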
-rw-r--r--  src/mlia/api.py                          | 108
-rw-r--r--  src/mlia/core/advisor.py                 |  64
-rw-r--r--  src/mlia/core/events.py                  |  54
-rw-r--r--  src/mlia/core/handlers.py                | 166
-rw-r--r--  src/mlia/core/reporting.py               |  11
-rw-r--r--  src/mlia/core/workflow.py                |   8
-rw-r--r--  src/mlia/devices/ethosu/advisor.py       | 159
-rw-r--r--  src/mlia/devices/ethosu/handlers.py      | 100
-rw-r--r--  src/mlia/devices/ethosu/reporters.py     |   4
-rw-r--r--  src/mlia/utils/filesystem.py             |  24
-rw-r--r--  tests/test_api.py                        |  17
-rw-r--r--  tests/test_core_events.py                |   2
-rw-r--r--  tests/test_core_reporting.py             |  22
-rw-r--r--  tests/test_devices_ethosu_reporters.py   |   4
14 files changed, 439 insertions, 304 deletions
diff --git a/src/mlia/api.py b/src/mlia/api.py
index 0f950db..024bc98 100644
--- a/src/mlia/api.py
+++ b/src/mlia/api.py
@@ -14,28 +14,13 @@ from mlia.core._typing import PathOrFileLike
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext
-from mlia.core.events import EventHandler
-from mlia.devices.ethosu.advisor import EthosUInferenceAdvisor
-from mlia.devices.ethosu.handlers import EthosUEventHandler
+from mlia.devices.ethosu.advisor import configure_and_get_ethosu_advisor
+from mlia.utils.filesystem import get_target
logger = logging.getLogger(__name__)
-_DEFAULT_OPTIMIZATION_TARGETS = [
- {
- "optimization_type": "pruning",
- "optimization_target": 0.5,
- "layers_to_optimize": None,
- },
- {
- "optimization_type": "clustering",
- "optimization_target": 32,
- "layers_to_optimize": None,
- },
-]
-
-
def get_advice(
target_profile: str,
model: Union[Path, str],
@@ -71,7 +56,6 @@ def get_advice(
:param backends: A list of backends that should be used for the given
target. Default settings will be used if None.
-
Examples:
NB: Before launching MLIA, the logging functionality should be configured!
@@ -87,75 +71,51 @@ def get_advice(
"""
advice_category = AdviceCategory.from_string(category)
- config_parameters = _get_config_parameters(
- model, target_profile, backends, optimization_targets
- )
- event_handlers = _get_event_handlers(output)
if context is not None:
context.advice_category = advice_category
- if context.config_parameters is None:
- context.config_parameters = config_parameters
-
- if context.event_handlers is None:
- context.event_handlers = event_handlers
-
if context is None:
context = ExecutionContext(
advice_category=advice_category,
working_dir=working_dir,
- config_parameters=config_parameters,
- event_handlers=event_handlers,
)
- advisor = _get_advisor(target_profile)
- advisor.run(context)
-
-
-def _get_advisor(target: Optional[str]) -> InferenceAdvisor:
- """Find appropriate advisor for the target."""
- if not target:
- raise Exception("Target is not provided")
+ advisor = get_advisor(
+ context,
+ target_profile,
+ model,
+ output,
+ optimization_targets=optimization_targets,
+ backends=backends,
+ )
- return EthosUInferenceAdvisor()
+ advisor.run(context)
-def _get_config_parameters(
- model: Union[Path, str],
+def get_advisor(
+ context: ExecutionContext,
target_profile: str,
- backends: Optional[List[str]],
- optimization_targets: Optional[List[Dict[str, Any]]],
-) -> Dict[str, Any]:
- """Get configuration parameters for the advisor."""
- advisor_parameters: Dict[str, Any] = {
- "ethos_u_inference_advisor": {
- "model": model,
- "device": {
- "target_profile": target_profile,
- },
- },
+ model: Union[Path, str],
+ output: Optional[PathOrFileLike] = None,
+ **extra_args: Any,
+) -> InferenceAdvisor:
+ """Find appropriate advisor for the target."""
+ target_factories = {
+ "ethos-u55": configure_and_get_ethosu_advisor,
+ "ethos-u65": configure_and_get_ethosu_advisor,
}
- # Specifying backends is optional (default is used)
- if backends is not None:
- advisor_parameters["ethos_u_inference_advisor"]["backends"] = backends
-
- if not optimization_targets:
- optimization_targets = _DEFAULT_OPTIMIZATION_TARGETS
-
- advisor_parameters.update(
- {
- "ethos_u_model_optimizations": {
- "optimizations": [
- optimization_targets,
- ],
- },
- }
- )
-
- return advisor_parameters
-
-def _get_event_handlers(output: Optional[PathOrFileLike]) -> List[EventHandler]:
- """Return list of the event handlers."""
- return [EthosUEventHandler(output)]
+ try:
+ target = get_target(target_profile)
+ factory_function = target_factories[target]
+ except KeyError as err:
+ raise Exception(f"Unsupported profile {target_profile}") from err
+
+ return factory_function(
+ context,
+ target_profile,
+ model,
+ output,
+ **extra_args,
+ )
diff --git a/src/mlia/core/advisor.py b/src/mlia/core/advisor.py
index 868d0c7..13689fa 100644
--- a/src/mlia/core/advisor.py
+++ b/src/mlia/core/advisor.py
@@ -2,9 +2,18 @@
# SPDX-License-Identifier: Apache-2.0
"""Inference advisor module."""
from abc import abstractmethod
+from pathlib import Path
+from typing import cast
+from typing import List
+from mlia.core.advice_generation import AdviceProducer
from mlia.core.common import NamedEntity
from mlia.core.context import Context
+from mlia.core.data_analysis import DataAnalyzer
+from mlia.core.data_collection import DataCollector
+from mlia.core.events import Event
+from mlia.core.mixins import ParameterResolverMixin
+from mlia.core.workflow import DefaultWorkflowExecutor
from mlia.core.workflow import WorkflowExecutor
@@ -19,3 +28,58 @@ class InferenceAdvisor(NamedEntity):
"""Run inference advisor."""
executor = self.configure(context)
executor.run()
+
+
+class DefaultInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin):
+ """Default implementation for the advisor."""
+
+ def configure(self, context: Context) -> WorkflowExecutor:
+ """Configure advisor."""
+ return DefaultWorkflowExecutor(
+ context,
+ self.get_collectors(context),
+ self.get_analyzers(context),
+ self.get_producers(context),
+ self.get_events(context),
+ )
+
+ @abstractmethod
+ def get_collectors(self, context: Context) -> List[DataCollector]:
+ """Return list of the data collectors."""
+
+ @abstractmethod
+ def get_analyzers(self, context: Context) -> List[DataAnalyzer]:
+ """Return list of the data analyzers."""
+
+ @abstractmethod
+ def get_producers(self, context: Context) -> List[AdviceProducer]:
+ """Return list of the advice producers."""
+
+ @abstractmethod
+ def get_events(self, context: Context) -> List[Event]:
+ """Return list of the startup events."""
+
+ def get_string_parameter(self, context: Context, param: str) -> str:
+ """Get string parameter value."""
+ value = self.get_parameter(
+ self.name(),
+ param,
+ expected_type=str,
+ context=context,
+ )
+
+ return cast(str, value)
+
+ def get_model(self, context: Context) -> Path:
+ """Get path to the model."""
+ model_param = self.get_string_parameter(context, "model")
+
+ model = Path(model_param)
+ if not model.exists():
+ raise Exception(f"Path {model} does not exist")
+
+ return model
+
+ def get_target_profile(self, context: Context) -> str:
+ """Get target profile."""
+ return self.get_string_parameter(context, "target_profile")
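
With the shared plumbing now in DefaultInferenceAdvisor, a target-specific advisor only has to supply its collectors, analyzers, producers and startup events. A hedged sketch of a minimal subclass (the class and target name below are hypothetical, not part of this change):

    from typing import List

    from mlia.core.advice_generation import AdviceProducer
    from mlia.core.advisor import DefaultInferenceAdvisor
    from mlia.core.context import Context
    from mlia.core.data_analysis import DataAnalyzer
    from mlia.core.data_collection import DataCollector
    from mlia.core.events import Event

    class MyTargetInferenceAdvisor(DefaultInferenceAdvisor):
        """Hypothetical advisor for some new target profile."""

        @classmethod
        def name(cls) -> str:
            """Return name of the advisor."""
            return "my_target_inference_advisor"

        def get_collectors(self, context: Context) -> List[DataCollector]:
            """Return list of the data collectors."""
            self.get_model(context)  # base-class helper resolves and validates the model path
            return []  # placeholder: target-specific collectors go here

        def get_analyzers(self, context: Context) -> List[DataAnalyzer]:
            """Return list of the data analyzers."""
            return []

        def get_producers(self, context: Context) -> List[AdviceProducer]:
            """Return list of the advice producers."""
            return []

        def get_events(self, context: Context) -> List[Event]:
            """Return list of the startup events."""
            return []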
diff --git a/src/mlia/core/events.py b/src/mlia/core/events.py
index 10aec86..0b8461b 100644
--- a/src/mlia/core/events.py
+++ b/src/mlia/core/events.py
@@ -399,57 +399,3 @@ def action(
publisher.publish_event(action_started)
yield
publisher.publish_event(action_finished)
-
-
-class SystemEventsHandler(EventDispatcher):
- """System events handler."""
-
- def on_execution_started(self, event: ExecutionStartedEvent) -> None:
- """Handle ExecutionStarted event."""
-
- def on_execution_finished(self, event: ExecutionFinishedEvent) -> None:
- """Handle ExecutionFinished event."""
-
- def on_execution_failed(self, event: ExecutionFailedEvent) -> None:
- """Handle ExecutionFailed event."""
-
- def on_data_collection_stage_started(
- self, event: DataCollectionStageStartedEvent
- ) -> None:
- """Handle DataCollectionStageStarted event."""
-
- def on_data_collection_stage_finished(
- self, event: DataCollectionStageFinishedEvent
- ) -> None:
- """Handle DataCollectionStageFinished event."""
-
- def on_data_collector_skipped(self, event: DataCollectorSkippedEvent) -> None:
- """Handle DataCollectorSkipped event."""
-
- def on_data_analysis_stage_started(
- self, event: DataAnalysisStageStartedEvent
- ) -> None:
- """Handle DataAnalysisStageStartedEvent event."""
-
- def on_data_analysis_stage_finished(
- self, event: DataAnalysisStageFinishedEvent
- ) -> None:
- """Handle DataAnalysisStageFinishedEvent event."""
-
- def on_advice_stage_started(self, event: AdviceStageStartedEvent) -> None:
- """Handle AdviceStageStarted event."""
-
- def on_advice_stage_finished(self, event: AdviceStageFinishedEvent) -> None:
- """Handle AdviceStageFinished event."""
-
- def on_collected_data(self, event: CollectedDataEvent) -> None:
- """Handle CollectedData event."""
-
- def on_analyzed_data(self, event: AnalyzedDataEvent) -> None:
- """Handle AnalyzedData event."""
-
- def on_action_started(self, event: ActionStartedEvent) -> None:
- """Handle ActionStarted event."""
-
- def on_action_finished(self, event: ActionFinishedEvent) -> None:
- """Handle ActionFinished event."""
diff --git a/src/mlia/core/handlers.py b/src/mlia/core/handlers.py
new file mode 100644
index 0000000..e576f74
--- /dev/null
+++ b/src/mlia/core/handlers.py
@@ -0,0 +1,166 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Event handlers module."""
+import logging
+from typing import Any
+from typing import Callable
+from typing import List
+from typing import Optional
+
+from mlia.core._typing import PathOrFileLike
+from mlia.core.advice_generation import Advice
+from mlia.core.advice_generation import AdviceEvent
+from mlia.core.events import ActionFinishedEvent
+from mlia.core.events import ActionStartedEvent
+from mlia.core.events import AdviceStageFinishedEvent
+from mlia.core.events import AdviceStageStartedEvent
+from mlia.core.events import AnalyzedDataEvent
+from mlia.core.events import CollectedDataEvent
+from mlia.core.events import DataAnalysisStageFinishedEvent
+from mlia.core.events import DataAnalysisStageStartedEvent
+from mlia.core.events import DataCollectionStageFinishedEvent
+from mlia.core.events import DataCollectionStageStartedEvent
+from mlia.core.events import DataCollectorSkippedEvent
+from mlia.core.events import EventDispatcher
+from mlia.core.events import ExecutionFailedEvent
+from mlia.core.events import ExecutionFinishedEvent
+from mlia.core.events import ExecutionStartedEvent
+from mlia.core.reporting import Report
+from mlia.core.reporting import Reporter
+from mlia.core.reporting import resolve_output_format
+from mlia.utils.console import create_section_header
+
+
+logger = logging.getLogger(__name__)
+
+
+class SystemEventsHandler(EventDispatcher):
+ """System events handler."""
+
+ def on_execution_started(self, event: ExecutionStartedEvent) -> None:
+ """Handle ExecutionStarted event."""
+
+ def on_execution_finished(self, event: ExecutionFinishedEvent) -> None:
+ """Handle ExecutionFinished event."""
+
+ def on_execution_failed(self, event: ExecutionFailedEvent) -> None:
+ """Handle ExecutionFailed event."""
+
+ def on_data_collection_stage_started(
+ self, event: DataCollectionStageStartedEvent
+ ) -> None:
+ """Handle DataCollectionStageStarted event."""
+
+ def on_data_collection_stage_finished(
+ self, event: DataCollectionStageFinishedEvent
+ ) -> None:
+ """Handle DataCollectionStageFinished event."""
+
+ def on_data_collector_skipped(self, event: DataCollectorSkippedEvent) -> None:
+ """Handle DataCollectorSkipped event."""
+
+ def on_data_analysis_stage_started(
+ self, event: DataAnalysisStageStartedEvent
+ ) -> None:
+ """Handle DataAnalysisStageStartedEvent event."""
+
+ def on_data_analysis_stage_finished(
+ self, event: DataAnalysisStageFinishedEvent
+ ) -> None:
+ """Handle DataAnalysisStageFinishedEvent event."""
+
+ def on_advice_stage_started(self, event: AdviceStageStartedEvent) -> None:
+ """Handle AdviceStageStarted event."""
+
+ def on_advice_stage_finished(self, event: AdviceStageFinishedEvent) -> None:
+ """Handle AdviceStageFinished event."""
+
+ def on_collected_data(self, event: CollectedDataEvent) -> None:
+ """Handle CollectedData event."""
+
+ def on_analyzed_data(self, event: AnalyzedDataEvent) -> None:
+ """Handle AnalyzedData event."""
+
+ def on_action_started(self, event: ActionStartedEvent) -> None:
+ """Handle ActionStarted event."""
+
+ def on_action_finished(self, event: ActionFinishedEvent) -> None:
+ """Handle ActionFinished event."""
+
+
+_ADV_EXECUTION_STARTED = create_section_header("ML Inference Advisor started")
+_MODEL_ANALYSIS_MSG = create_section_header("Model Analysis")
+_MODEL_ANALYSIS_RESULTS_MSG = create_section_header("Model Analysis Results")
+_ADV_GENERATION_MSG = create_section_header("Advice Generation")
+_REPORT_GENERATION_MSG = create_section_header("Report Generation")
+
+
+class WorkflowEventsHandler(SystemEventsHandler):
+ """Event handler for the system events."""
+
+ def __init__(
+ self,
+ formatter_resolver: Callable[[Any], Callable[[Any], Report]],
+ output: Optional[PathOrFileLike] = None,
+ ) -> None:
+ """Init event handler."""
+ output_format = resolve_output_format(output)
+ self.reporter = Reporter(formatter_resolver, output_format)
+ self.output = output
+
+ self.advice: List[Advice] = []
+
+ def on_execution_started(self, event: ExecutionStartedEvent) -> None:
+ """Handle ExecutionStarted event."""
+ logger.info(_ADV_EXECUTION_STARTED)
+
+ def on_execution_failed(self, event: ExecutionFailedEvent) -> None:
+ """Handle ExecutionFailed event."""
+ raise event.err
+
+ def on_data_collection_stage_started(
+ self, event: DataCollectionStageStartedEvent
+ ) -> None:
+ """Handle DataCollectionStageStarted event."""
+ logger.info(_MODEL_ANALYSIS_MSG)
+
+ def on_advice_stage_started(self, event: AdviceStageStartedEvent) -> None:
+ """Handle AdviceStageStarted event."""
+ logger.info(_ADV_GENERATION_MSG)
+
+ def on_data_collector_skipped(self, event: DataCollectorSkippedEvent) -> None:
+ """Handle DataCollectorSkipped event."""
+ logger.info("Skipped: %s", event.reason)
+
+ @staticmethod
+ def report_generated(output: PathOrFileLike) -> None:
+ """Log report generation."""
+ logger.info(_REPORT_GENERATION_MSG)
+ logger.info("Report(s) and advice list saved to: %s", output)
+
+ def on_data_analysis_stage_finished(
+ self, event: DataAnalysisStageFinishedEvent
+ ) -> None:
+ """Handle DataAnalysisStageFinished event."""
+ logger.info(_MODEL_ANALYSIS_RESULTS_MSG)
+
+ self.reporter.print_delayed()
+
+ def on_advice_event(self, event: AdviceEvent) -> None:
+ """Handle Advice event."""
+ self.advice.append(event.advice)
+
+ def on_advice_stage_finished(self, event: AdviceStageFinishedEvent) -> None:
+ """Handle AdviceStageFinishedEvent event."""
+ self.reporter.submit(
+ self.advice,
+ show_title=False,
+ show_headers=False,
+ space="between",
+ table_style="no_borders",
+ )
+
+ self.reporter.generate_report(self.output)
+
+ if self.output is not None:
+ self.report_generated(self.output)
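
WorkflowEventsHandler now owns the reporter and the common progress logging, so a target-specific handler only passes in its formatter resolver (as EthosUEventHandler does further down). A sketch with a placeholder resolver that is not part of this change:

    from typing import Any, Callable, Optional

    from mlia.core._typing import PathOrFileLike
    from mlia.core.handlers import WorkflowEventsHandler
    from mlia.core.reporting import Report

    def my_formatters(data: Any) -> Callable[[Any], Report]:
        """Hypothetical resolver mapping data items to report formatters."""
        raise Exception(f"Unable to find appropriate formatter for {data}")

    class MyEventHandler(WorkflowEventsHandler):
        """Event handler wired up with target-specific formatters."""

        def __init__(self, output: Optional[PathOrFileLike] = None) -> None:
            """Init event handler."""
            super().__init__(my_formatters, output)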
diff --git a/src/mlia/core/reporting.py b/src/mlia/core/reporting.py
index 9006602..58a41d3 100644
--- a/src/mlia/core/reporting.py
+++ b/src/mlia/core/reporting.py
@@ -760,3 +760,14 @@ def _apply_format_parameters(
return report
return wrapper
+
+
+def resolve_output_format(output: Optional[PathOrFileLike]) -> OutputFormat:
+ """Resolve output format based on the output name."""
+ if isinstance(output, (str, Path)):
+ format_from_filename = Path(output).suffix.lstrip(".")
+
+ if format_from_filename in ["json", "csv"]:
+ return cast(OutputFormat, format_from_filename)
+
+ return "plain_text"
diff --git a/src/mlia/core/workflow.py b/src/mlia/core/workflow.py
index 0245087..03f3d1c 100644
--- a/src/mlia/core/workflow.py
+++ b/src/mlia/core/workflow.py
@@ -87,7 +87,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor):
collectors: Sequence[DataCollector],
analyzers: Sequence[DataAnalyzer],
producers: Sequence[AdviceProducer],
- before_start_events: Optional[Sequence[Event]] = None,
+ startup_events: Optional[Sequence[Event]] = None,
):
"""Init default workflow executor.
@@ -95,14 +95,14 @@ class DefaultWorkflowExecutor(WorkflowExecutor):
:param collectors: List of the data collectors
:param analyzers: List of the data analyzers
:param producers: List of the advice producers
- :param before_start_events: Optional list of the custom events that
+ :param startup_events: Optional list of the custom events that
should be published before start of the worfkow execution.
"""
self.context = context
self.collectors = collectors
self.analyzers = analyzers
self.producers = producers
- self.before_start_events = before_start_events
+ self.startup_events = startup_events
def run(self) -> None:
"""Run the workflow."""
@@ -125,7 +125,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor):
def before_start(self) -> None:
"""Run actions before start of the workflow execution."""
- events = self.before_start_events or []
+ events = self.startup_events or []
for event in events:
self.publish(event)
diff --git a/src/mlia/devices/ethosu/advisor.py b/src/mlia/devices/ethosu/advisor.py
index e93858f..b7b8305 100644
--- a/src/mlia/devices/ethosu/advisor.py
+++ b/src/mlia/devices/ethosu/advisor.py
@@ -2,18 +2,22 @@
# SPDX-License-Identifier: Apache-2.0
"""Ethos-U MLIA module."""
from pathlib import Path
+from typing import Any
+from typing import Dict
from typing import List
from typing import Optional
+from typing import Union
+from mlia.core._typing import PathOrFileLike
from mlia.core.advice_generation import AdviceProducer
+from mlia.core.advisor import DefaultInferenceAdvisor
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
from mlia.core.context import Context
+from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import DataAnalyzer
from mlia.core.data_collection import DataCollector
-from mlia.core.mixins import ParameterResolverMixin
-from mlia.core.workflow import DefaultWorkflowExecutor
-from mlia.core.workflow import WorkflowExecutor
+from mlia.core.events import Event
from mlia.devices.ethosu.advice_generation import EthosUAdviceProducer
from mlia.devices.ethosu.advice_generation import EthosUStaticAdviceProducer
from mlia.devices.ethosu.config import EthosUConfiguration
@@ -23,10 +27,12 @@ from mlia.devices.ethosu.data_collection import EthosUOperatorCompatibility
from mlia.devices.ethosu.data_collection import EthosUOptimizationPerformance
from mlia.devices.ethosu.data_collection import EthosUPerformance
from mlia.devices.ethosu.events import EthosUAdvisorStartedEvent
+from mlia.devices.ethosu.handlers import EthosUEventHandler
from mlia.nn.tensorflow.utils import is_tflite_model
+from mlia.utils.types import is_list_of
-class EthosUInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin):
+class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
"""Ethos-U Inference Advisor."""
@classmethod
@@ -34,34 +40,12 @@ class EthosUInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin):
"""Return name of the advisor."""
return "ethos_u_inference_advisor"
- def configure(self, context: Context) -> WorkflowExecutor:
- """Configure advisor execution."""
- model = self._get_model(context)
+ def get_collectors(self, context: Context) -> List[DataCollector]:
+ """Return list of the data collectors."""
+ model = self.get_model(context)
device = self._get_device(context)
backends = self._get_backends(context)
- collectors = self._get_collectors(context, model, device, backends)
- analyzers = self._get_analyzers()
- producers = self._get_advice_producers()
-
- return DefaultWorkflowExecutor(
- context,
- collectors,
- analyzers,
- producers,
- before_start_events=[
- EthosUAdvisorStartedEvent(device=device, model=model),
- ],
- )
-
- def _get_collectors(
- self,
- context: Context,
- model: Path,
- device: EthosUConfiguration,
- backends: Optional[List[str]],
- ) -> List[DataCollector]:
- """Get collectors."""
collectors: List[DataCollector] = []
if AdviceCategory.OPERATORS in context.advice_category:
@@ -91,51 +75,34 @@ class EthosUInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin):
return collectors
- @staticmethod
- def _get_analyzers() -> List[DataAnalyzer]:
- """Return data analyzers."""
+ def get_analyzers(self, context: Context) -> List[DataAnalyzer]:
+ """Return list of the data analyzers."""
return [
EthosUDataAnalyzer(),
]
- @staticmethod
- def _get_advice_producers() -> List[AdviceProducer]:
- """Return advice producers."""
+ def get_producers(self, context: Context) -> List[AdviceProducer]:
+ """Return list of the advice producers."""
return [
EthosUAdviceProducer(),
EthosUStaticAdviceProducer(),
]
+ def get_events(self, context: Context) -> List[Event]:
+ """Return list of the startup events."""
+ model = self.get_model(context)
+ device = self._get_device(context)
+
+ return [
+ EthosUAdvisorStartedEvent(device=device, model=model),
+ ]
+
def _get_device(self, context: Context) -> EthosUConfiguration:
"""Get device."""
- device_params = self.get_parameter(
- self.name(),
- "device",
- expected_type=dict,
- context=context,
- )
-
- try:
- target_profile = device_params["target_profile"]
- except KeyError as err:
- raise Exception("Unable to get device details") from err
+ target_profile = self.get_target_profile(context)
return get_target(target_profile)
- def _get_model(self, context: Context) -> Path:
- """Get path to the model."""
- model_param = self.get_parameter(
- self.name(),
- "model",
- expected_type=str,
- context=context,
- )
-
- if not (model := Path(model_param)).exists():
- raise Exception(f"Path {model} does not exist")
-
- return model
-
def _get_optimization_settings(self, context: Context) -> List[List[dict]]:
"""Get optimization settings."""
return self.get_parameter( # type: ignore
@@ -155,3 +122,75 @@ class EthosUInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin):
expected=False,
context=context,
)
+
+
+def configure_and_get_ethosu_advisor(
+ context: ExecutionContext,
+ target_profile: str,
+ model: Union[Path, str],
+ output: Optional[PathOrFileLike] = None,
+ **extra_args: Any,
+) -> InferenceAdvisor:
+ """Create and configure Ethos-U advisor."""
+ if context.event_handlers is None:
+ context.event_handlers = [EthosUEventHandler(output)]
+
+ if context.config_parameters is None:
+ context.config_parameters = _get_config_parameters(
+ model, target_profile, **extra_args
+ )
+
+ return EthosUInferenceAdvisor()
+
+
+_DEFAULT_OPTIMIZATION_TARGETS = [
+ {
+ "optimization_type": "pruning",
+ "optimization_target": 0.5,
+ "layers_to_optimize": None,
+ },
+ {
+ "optimization_type": "clustering",
+ "optimization_target": 32,
+ "layers_to_optimize": None,
+ },
+]
+
+
+def _get_config_parameters(
+ model: Union[Path, str],
+ target_profile: str,
+ **extra_args: Any,
+) -> Dict[str, Any]:
+ """Get configuration parameters for the advisor."""
+ advisor_parameters: Dict[str, Any] = {
+ "ethos_u_inference_advisor": {
+ "model": model,
+ "target_profile": target_profile,
+ },
+ }
+
+ # Specifying backends is optional (default is used)
+ backends = extra_args.get("backends")
+ if backends is not None:
+ if not is_list_of(backends, str):
+ raise Exception("Backends value has wrong format")
+
+ advisor_parameters["ethos_u_inference_advisor"]["backends"] = backends
+
+ optimization_targets = extra_args.get("optimization_targets")
+ if not optimization_targets:
+ optimization_targets = _DEFAULT_OPTIMIZATION_TARGETS
+
+ if not is_list_of(optimization_targets, dict):
+ raise Exception("Optimization targets value has wrong format")
+
+ advisor_parameters.update(
+ {
+ "ethos_u_model_optimizations": {
+ "optimizations": [optimization_targets],
+ },
+ }
+ )
+
+ return advisor_parameters
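
For reference, a call to configure_and_get_ethosu_advisor without extra arguments would populate the context configuration roughly as follows (model path and profile name are illustrative; note the flat "target_profile" key replacing the old "device" sub-dictionary):

    config_parameters = {
        "ethos_u_inference_advisor": {
            "model": "model.tflite",            # assumed example path
            "target_profile": "ethos-u55-256",
        },
        "ethos_u_model_optimizations": {
            "optimizations": [
                [
                    {"optimization_type": "pruning", "optimization_target": 0.5,
                     "layers_to_optimize": None},
                    {"optimization_type": "clustering", "optimization_target": 32,
                     "layers_to_optimize": None},
                ]
            ],
        },
    }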
diff --git a/src/mlia/devices/ethosu/handlers.py b/src/mlia/devices/ethosu/handlers.py
index 7a0c31c..ee0b809 100644
--- a/src/mlia/devices/ethosu/handlers.py
+++ b/src/mlia/devices/ethosu/handlers.py
@@ -2,101 +2,27 @@
# SPDX-License-Identifier: Apache-2.0
"""Event handler."""
import logging
-from pathlib import Path
-from typing import Dict
-from typing import List
from typing import Optional
-from mlia.core._typing import OutputFormat
from mlia.core._typing import PathOrFileLike
-from mlia.core.advice_generation import Advice
-from mlia.core.advice_generation import AdviceEvent
-from mlia.core.events import AdviceStageFinishedEvent
-from mlia.core.events import AdviceStageStartedEvent
from mlia.core.events import CollectedDataEvent
-from mlia.core.events import DataAnalysisStageFinishedEvent
-from mlia.core.events import DataCollectionStageStartedEvent
-from mlia.core.events import DataCollectorSkippedEvent
-from mlia.core.events import ExecutionFailedEvent
-from mlia.core.events import ExecutionStartedEvent
-from mlia.core.events import SystemEventsHandler
-from mlia.core.reporting import Reporter
+from mlia.core.handlers import WorkflowEventsHandler
from mlia.devices.ethosu.events import EthosUAdvisorEventHandler
from mlia.devices.ethosu.events import EthosUAdvisorStartedEvent
from mlia.devices.ethosu.performance import OptimizationPerformanceMetrics
from mlia.devices.ethosu.performance import PerformanceMetrics
-from mlia.devices.ethosu.reporters import find_appropriate_formatter
+from mlia.devices.ethosu.reporters import ethos_u_formatters
from mlia.tools.vela_wrapper import Operators
-from mlia.utils.console import create_section_header
logger = logging.getLogger(__name__)
-ADV_EXECUTION_STARTED = create_section_header("ML Inference Advisor started")
-MODEL_ANALYSIS_MSG = create_section_header("Model Analysis")
-MODEL_ANALYSIS_RESULTS_MSG = create_section_header("Model Analysis Results")
-ADV_GENERATION_MSG = create_section_header("Advice Generation")
-REPORT_GENERATION_MSG = create_section_header("Report Generation")
-
-
-class WorkflowEventsHandler(SystemEventsHandler):
- """Event handler for the system events."""
-
- def on_execution_started(self, event: ExecutionStartedEvent) -> None:
- """Handle ExecutionStarted event."""
- logger.info(ADV_EXECUTION_STARTED)
-
- def on_execution_failed(self, event: ExecutionFailedEvent) -> None:
- """Handle ExecutionFailed event."""
- raise event.err
-
- def on_data_collection_stage_started(
- self, event: DataCollectionStageStartedEvent
- ) -> None:
- """Handle DataCollectionStageStarted event."""
- logger.info(MODEL_ANALYSIS_MSG)
-
- def on_advice_stage_started(self, event: AdviceStageStartedEvent) -> None:
- """Handle AdviceStageStarted event."""
- logger.info(ADV_GENERATION_MSG)
-
- def on_data_collector_skipped(self, event: DataCollectorSkippedEvent) -> None:
- """Handle DataCollectorSkipped event."""
- logger.info("Skipped: %s", event.reason)
-
class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler):
"""CLI event handler."""
def __init__(self, output: Optional[PathOrFileLike] = None) -> None:
"""Init event handler."""
- output_format = self.resolve_output_format(output)
-
- self.reporter = Reporter(find_appropriate_formatter, output_format)
- self.output = output
- self.advice: List[Advice] = []
-
- def on_advice_stage_finished(self, event: AdviceStageFinishedEvent) -> None:
- """Handle AdviceStageFinishedEvent event."""
- self.reporter.submit(
- self.advice,
- show_title=False,
- show_headers=False,
- space="between",
- table_style="no_borders",
- )
-
- self.reporter.generate_report(self.output)
-
- if self.output is not None:
- logger.info(REPORT_GENERATION_MSG)
- logger.info("Report(s) and advice list saved to: %s", self.output)
-
- def on_data_analysis_stage_finished(
- self, event: DataAnalysisStageFinishedEvent
- ) -> None:
- """Handle DataAnalysisStageFinished event."""
- logger.info(MODEL_ANALYSIS_RESULTS_MSG)
- self.reporter.print_delayed()
+ super().__init__(ethos_u_formatters, output)
def on_collected_data(self, event: CollectedDataEvent) -> None:
"""Handle CollectedDataEvent event."""
@@ -106,7 +32,7 @@ class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler):
self.reporter.submit([data_item.ops, data_item], delay_print=True)
if isinstance(data_item, PerformanceMetrics):
- self.reporter.submit(data_item, delay_print=True)
+ self.reporter.submit(data_item, delay_print=True, space=True)
if isinstance(data_item, OptimizationPerformanceMetrics):
original_metrics = data_item.original_perf_metrics
@@ -123,24 +49,6 @@ class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler):
space=True,
)
- def on_advice_event(self, event: AdviceEvent) -> None:
- """Handle Advice event."""
- self.advice.append(event.advice)
-
def on_ethos_u_advisor_started(self, event: EthosUAdvisorStartedEvent) -> None:
"""Handle EthosUAdvisorStarted event."""
self.reporter.submit(event.device)
-
- @staticmethod
- def resolve_output_format(output: Optional[PathOrFileLike]) -> OutputFormat:
- """Resolve output format based on the output name."""
- output_format: OutputFormat = "plain_text"
-
- if isinstance(output, str):
- output_path = Path(output)
- output_formats: Dict[str, OutputFormat] = {".csv": "csv", ".json": "json"}
-
- if (suffix := output_path.suffix) in output_formats:
- return output_formats[suffix]
-
- return output_format
diff --git a/src/mlia/devices/ethosu/reporters.py b/src/mlia/devices/ethosu/reporters.py
index d28c68f..b3aea24 100644
--- a/src/mlia/devices/ethosu/reporters.py
+++ b/src/mlia/devices/ethosu/reporters.py
@@ -374,7 +374,7 @@ def report_advice(advice: List[Advice]) -> Report:
)
-def find_appropriate_formatter(data: Any) -> Callable[[Any], Report]:
+def ethos_u_formatters(data: Any) -> Callable[[Any], Report]:
"""Find appropriate formatter for the provided data."""
if isinstance(data, PerformanceMetrics) or is_list_of(data, PerformanceMetrics, 2):
return report_perf_metrics
@@ -392,7 +392,7 @@ def find_appropriate_formatter(data: Any) -> Callable[[Any], Report]:
return report_device_details
if isinstance(data, (list, tuple)):
- formatters = [find_appropriate_formatter(item) for item in data]
+ formatters = [ethos_u_formatters(item) for item in data]
return CompoundFormatter(formatters)
raise Exception(f"Unable to find appropriate formatter for {data}")
diff --git a/src/mlia/utils/filesystem.py b/src/mlia/utils/filesystem.py
index 7975905..0c28d35 100644
--- a/src/mlia/utils/filesystem.py
+++ b/src/mlia/utils/filesystem.py
@@ -11,6 +11,7 @@ from pathlib import Path
from tempfile import mkstemp
from tempfile import TemporaryDirectory
from typing import Any
+from typing import cast
from typing import Dict
from typing import Generator
from typing import Iterable
@@ -32,12 +33,12 @@ def get_vela_config() -> Path:
def get_profiles_file() -> Path:
- """Get the Ethos-U profiles file."""
+ """Get the profiles file."""
return get_mlia_resources() / "profiles.json"
def get_profiles_data() -> Dict[str, Dict[str, Any]]:
- """Get the Ethos-U profile values as a dictionary."""
+ """Get the profile values as a dictionary."""
with open(get_profiles_file(), encoding="utf-8") as json_file:
profiles = json.load(json_file)
@@ -47,14 +48,17 @@ def get_profiles_data() -> Dict[str, Dict[str, Any]]:
return profiles
-def get_profile(target: str) -> Dict[str, Any]:
+def get_profile(target_profile: str) -> Dict[str, Any]:
"""Get settings for the provided target profile."""
- profiles = get_profiles_data()
+ if not target_profile:
+ raise Exception("Target profile is not provided")
- if target not in profiles:
- raise Exception(f"Unable to find target profile {target}")
+ profiles = get_profiles_data()
- return profiles[target]
+ try:
+ return profiles[target_profile]
+ except KeyError as err:
+ raise Exception(f"Unable to find target profile {target_profile}") from err
def get_supported_profile_names() -> List[str]:
@@ -62,6 +66,12 @@ def get_supported_profile_names() -> List[str]:
return list(get_profiles_data().keys())
+def get_target(target_profile: str) -> str:
+ """Return target for the provided target_profile."""
+ profile_data = get_profile(target_profile)
+ return cast(str, profile_data["target"])
+
+
@contextmanager
def temp_file(suffix: Optional[str] = None) -> Generator[Path, None, None]:
"""Create temp file and remove it after."""
diff --git a/tests/test_api.py b/tests/test_api.py
index 09bc509..e8df7af 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -7,14 +7,16 @@ from unittest.mock import MagicMock
import pytest
from mlia.api import get_advice
+from mlia.api import get_advisor
from mlia.core.common import AdviceCategory
from mlia.core.context import Context
from mlia.core.context import ExecutionContext
+from mlia.devices.ethosu.advisor import EthosUInferenceAdvisor
def test_get_advice_no_target_provided(test_keras_model: Path) -> None:
"""Test getting advice when no target provided."""
- with pytest.raises(Exception, match="Target is not provided"):
+ with pytest.raises(Exception, match="Target profile is not provided"):
get_advice(None, test_keras_model, "all") # type: ignore
@@ -78,7 +80,7 @@ def test_get_advice(
) -> None:
"""Test getting advice with valid parameters."""
advisor_mock = MagicMock()
- monkeypatch.setattr("mlia.api._get_advisor", MagicMock(return_value=advisor_mock))
+ monkeypatch.setattr("mlia.api.get_advisor", MagicMock(return_value=advisor_mock))
get_advice(
"ethos-u55-256",
@@ -92,5 +94,12 @@ def test_get_advice(
assert isinstance(context, Context)
assert context.advice_category == expected_category
- assert context.event_handlers is not None
- assert context.config_parameters is not None
+
+def test_get_advisor(
+ test_keras_model: Path,
+) -> None:
+ """Test function for getting the advisor."""
+ ethos_u55_advisor = get_advisor(
+ ExecutionContext(), "ethos-u55-256", str(test_keras_model)
+ )
+ assert isinstance(ethos_u55_advisor, EthosUInferenceAdvisor)
diff --git a/tests/test_core_events.py b/tests/test_core_events.py
index faaab7c..a531bab 100644
--- a/tests/test_core_events.py
+++ b/tests/test_core_events.py
@@ -18,7 +18,7 @@ from mlia.core.events import EventHandler
from mlia.core.events import ExecutionFinishedEvent
from mlia.core.events import ExecutionStartedEvent
from mlia.core.events import stage
-from mlia.core.events import SystemEventsHandler
+from mlia.core.handlers import SystemEventsHandler
@dataclass
diff --git a/tests/test_core_reporting.py b/tests/test_core_reporting.py
index 2f7ec22..d7a6ade 100644
--- a/tests/test_core_reporting.py
+++ b/tests/test_core_reporting.py
@@ -2,9 +2,12 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for reporting module."""
from typing import List
+from typing import Optional
import pytest
+from mlia.core._typing import OutputFormat
+from mlia.core._typing import PathOrFileLike
from mlia.core.reporting import BytesCell
from mlia.core.reporting import Cell
from mlia.core.reporting import ClockCell
@@ -13,6 +16,7 @@ from mlia.core.reporting import CyclesCell
from mlia.core.reporting import Format
from mlia.core.reporting import NestedReport
from mlia.core.reporting import ReportItem
+from mlia.core.reporting import resolve_output_format
from mlia.core.reporting import SingleRow
from mlia.core.reporting import Table
from mlia.utils.console import remove_ascii_codes
@@ -411,3 +415,21 @@ Single row example:
alias="simple_row_example",
)
wrong_single_row.to_plain_text()
+
+
+@pytest.mark.parametrize(
+ "output, expected_output_format",
+ [
+ [None, "plain_text"],
+ ["", "plain_text"],
+ ["some_file", "plain_text"],
+ ["some_format.some_ext", "plain_text"],
+ ["output.csv", "csv"],
+ ["output.json", "json"],
+ ],
+)
+def test_resolve_output_format(
+ output: Optional[PathOrFileLike], expected_output_format: OutputFormat
+) -> None:
+ """Test function resolve_output_format."""
+ assert resolve_output_format(output) == expected_output_format
diff --git a/tests/test_devices_ethosu_reporters.py b/tests/test_devices_ethosu_reporters.py
index 0da50e0..a63db1c 100644
--- a/tests/test_devices_ethosu_reporters.py
+++ b/tests/test_devices_ethosu_reporters.py
@@ -22,7 +22,7 @@ from mlia.devices.ethosu.config import EthosUConfiguration
from mlia.devices.ethosu.performance import MemoryUsage
from mlia.devices.ethosu.performance import NPUCycles
from mlia.devices.ethosu.performance import PerformanceMetrics
-from mlia.devices.ethosu.reporters import find_appropriate_formatter
+from mlia.devices.ethosu.reporters import ethos_u_formatters
from mlia.devices.ethosu.reporters import report_device_details
from mlia.devices.ethosu.reporters import report_operators
from mlia.devices.ethosu.reporters import report_perf_metrics
@@ -410,7 +410,7 @@ def test_get_reporter(tmp_path: Path) -> None:
)
output = tmp_path / "output.json"
- with get_reporter("json", output, find_appropriate_formatter) as reporter:
+ with get_reporter("json", output, ethos_u_formatters) as reporter:
assert isinstance(reporter, Reporter)
with pytest.raises(