Diffstat (limited to 'src/mlia/target/ethos_u/advisor.py')
-rw-r--r--  src/mlia/target/ethos_u/advisor.py  194
1 file changed, 194 insertions(+), 0 deletions(-)
diff --git a/src/mlia/target/ethos_u/advisor.py b/src/mlia/target/ethos_u/advisor.py
new file mode 100644
index 0000000..b9d64ff
--- /dev/null
+++ b/src/mlia/target/ethos_u/advisor.py
@@ -0,0 +1,194 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Ethos-U MLIA module."""
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+
+from mlia.core.advice_generation import AdviceProducer
+from mlia.core.advisor import DefaultInferenceAdvisor
+from mlia.core.advisor import InferenceAdvisor
+from mlia.core.common import AdviceCategory
+from mlia.core.context import Context
+from mlia.core.context import ExecutionContext
+from mlia.core.data_analysis import DataAnalyzer
+from mlia.core.data_collection import DataCollector
+from mlia.core.events import Event
+from mlia.core.typing import PathOrFileLike
+from mlia.nn.tensorflow.utils import is_tflite_model
+from mlia.target.ethos_u.advice_generation import EthosUAdviceProducer
+from mlia.target.ethos_u.advice_generation import EthosUStaticAdviceProducer
+from mlia.target.ethos_u.config import EthosUConfiguration
+from mlia.target.ethos_u.config import get_target
+from mlia.target.ethos_u.data_analysis import EthosUDataAnalyzer
+from mlia.target.ethos_u.data_collection import EthosUOperatorCompatibility
+from mlia.target.ethos_u.data_collection import EthosUOptimizationPerformance
+from mlia.target.ethos_u.data_collection import EthosUPerformance
+from mlia.target.ethos_u.events import EthosUAdvisorStartedEvent
+from mlia.target.ethos_u.handlers import EthosUEventHandler
+from mlia.utils.types import is_list_of
+
+
+class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
+ """Ethos-U Inference Advisor."""
+
+ @classmethod
+ def name(cls) -> str:
+ """Return name of the advisor."""
+ return "ethos_u_inference_advisor"
+
+ def get_collectors(self, context: Context) -> list[DataCollector]:
+ """Return list of the data collectors."""
+ model = self.get_model(context)
+ device = self._get_device(context)
+ backends = self._get_backends(context)
+
+ collectors: list[DataCollector] = []
+
+ if AdviceCategory.OPERATORS in context.advice_category:
+ collectors.append(EthosUOperatorCompatibility(model, device))
+
+ # Performance and optimization are mutually exclusive.
+ # Decide which one to use (taking into account the model format).
+ if is_tflite_model(model):
+ # TensorFlow Lite models do not support optimization (only performance)!
+ if context.advice_category == AdviceCategory.OPTIMIZATION:
+ raise Exception(
+ "Command 'optimization' is not supported for TensorFlow Lite files."
+ )
+ if AdviceCategory.PERFORMANCE in context.advice_category:
+ collectors.append(EthosUPerformance(model, device, backends))
+ else:
+ # Keras/SavedModel: Prefer optimization
+ if AdviceCategory.OPTIMIZATION in context.advice_category:
+ optimization_settings = self._get_optimization_settings(context)
+ collectors.append(
+ EthosUOptimizationPerformance(
+ model, device, optimization_settings, backends
+ )
+ )
+ elif AdviceCategory.PERFORMANCE in context.advice_category:
+ collectors.append(EthosUPerformance(model, device, backends))
+
+ return collectors
+
+ def get_analyzers(self, context: Context) -> list[DataAnalyzer]:
+ """Return list of the data analyzers."""
+ return [
+ EthosUDataAnalyzer(),
+ ]
+
+ def get_producers(self, context: Context) -> list[AdviceProducer]:
+ """Return list of the advice producers."""
+ return [
+ EthosUAdviceProducer(),
+ EthosUStaticAdviceProducer(),
+ ]
+
+ def get_events(self, context: Context) -> list[Event]:
+ """Return list of the startup events."""
+ model = self.get_model(context)
+ device = self._get_device(context)
+
+ return [
+ EthosUAdvisorStartedEvent(device=device, model=model),
+ ]
+
+ def _get_device(self, context: Context) -> EthosUConfiguration:
+ """Get device."""
+ target_profile = self.get_target_profile(context)
+
+ return get_target(target_profile)
+
+ def _get_optimization_settings(self, context: Context) -> list[list[dict]]:
+ """Get optimization settings."""
+ return self.get_parameter( # type: ignore
+ EthosUOptimizationPerformance.name(),
+ "optimizations",
+ expected_type=list,
+ expected=False,
+ context=context,
+ )
+
+ def _get_backends(self, context: Context) -> list[str] | None:
+ """Get list of backends."""
+ return self.get_parameter( # type: ignore
+ self.name(),
+ "backends",
+ expected_type=list,
+ expected=False,
+ context=context,
+ )
+
+
+def configure_and_get_ethosu_advisor(
+ context: ExecutionContext,
+ target_profile: str,
+ model: str | Path,
+ output: PathOrFileLike | None = None,
+ **extra_args: Any,
+) -> InferenceAdvisor:
+ """Create and configure Ethos-U advisor."""
+ if context.event_handlers is None:
+ context.event_handlers = [EthosUEventHandler(output)]
+
+ if context.config_parameters is None:
+ context.config_parameters = _get_config_parameters(
+ model, target_profile, **extra_args
+ )
+
+ return EthosUInferenceAdvisor()
+
+
+_DEFAULT_OPTIMIZATION_TARGETS = [
+ {
+ "optimization_type": "pruning",
+ "optimization_target": 0.5,
+ "layers_to_optimize": None,
+ },
+ {
+ "optimization_type": "clustering",
+ "optimization_target": 32,
+ "layers_to_optimize": None,
+ },
+]
+
+
+def _get_config_parameters(
+ model: str | Path,
+ target_profile: str,
+ **extra_args: Any,
+) -> dict[str, Any]:
+ """Get configuration parameters for the advisor."""
+ advisor_parameters: dict[str, Any] = {
+ "ethos_u_inference_advisor": {
+ "model": model,
+ "target_profile": target_profile,
+ },
+ }
+
+    # Specifying backends is optional (a default is used if none are given)
+ backends = extra_args.get("backends")
+ if backends is not None:
+ if not is_list_of(backends, str):
+ raise Exception("Backends value has wrong format")
+
+ advisor_parameters["ethos_u_inference_advisor"]["backends"] = backends
+
+ optimization_targets = extra_args.get("optimization_targets")
+ if not optimization_targets:
+ optimization_targets = _DEFAULT_OPTIMIZATION_TARGETS
+
+ if not is_list_of(optimization_targets, dict):
+ raise Exception("Optimization targets value has wrong format")
+
+ advisor_parameters.update(
+ {
+ "ethos_u_model_optimizations": {
+ "optimizations": [optimization_targets],
+ },
+ }
+ )
+
+ return advisor_parameters
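
The sketch below shows how this factory might be driven end to end. It is illustrative rather than part of the change: ExecutionContext's default construction, the advisor's run() entry point, the "ethos-u55-256" profile name and the "Vela" backend are assumptions, not values taken from this file.

    # Minimal usage sketch (assumptions noted above and in the comments).
    from mlia.core.context import ExecutionContext
    from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor

    # Assumption: default construction is sufficient and the advice category
    # (operators/performance/optimization) is carried by this context.
    context = ExecutionContext()

    advisor = configure_and_get_ethosu_advisor(
        context,
        target_profile="ethos-u55-256",  # hypothetical target profile name
        model="model.h5",                # Keras model, i.e. not a TensorFlow Lite file
        backends=["Vela"],               # optional; must be a list of strings
    )

    # At this point the factory has filled context.config_parameters with the
    # "ethos_u_inference_advisor" and "ethos_u_model_optimizations" sections built
    # by _get_config_parameters above, falling back to the default pruning and
    # clustering targets because no optimization_targets were passed.
    advisor.run(context)  # assumption: InferenceAdvisor exposes a run() method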