aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/devices/ethosu/advisor.py
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2022-07-21 14:06:03 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2022-08-19 10:23:23 +0100
commita8ee1aee3e674c78a77801d1bf2256881ab6b4b9 (patch)
tree8463b24ba0446a49b3e012477b0834c3b5415b86 /src/mlia/devices/ethosu/advisor.py
parent76ec769ad8f8ed53ec3ff829fdd34d53db8229fd (diff)
downloadmlia-a8ee1aee3e674c78a77801d1bf2256881ab6b4b9.tar.gz
MLIA-549 Refactor API module to support several target profiles
- Move target specific details out of API module - Move common logic for workflow event handler into a separate class Change-Id: Ic4a22657b722af1c1fead1d478f606ac57325788
Diffstat (limited to 'src/mlia/devices/ethosu/advisor.py')
-rw-r--r--src/mlia/devices/ethosu/advisor.py159
1 files changed, 99 insertions, 60 deletions
diff --git a/src/mlia/devices/ethosu/advisor.py b/src/mlia/devices/ethosu/advisor.py
index e93858f..b7b8305 100644
--- a/src/mlia/devices/ethosu/advisor.py
+++ b/src/mlia/devices/ethosu/advisor.py
@@ -2,18 +2,22 @@
# SPDX-License-Identifier: Apache-2.0
"""Ethos-U MLIA module."""
from pathlib import Path
+from typing import Any
+from typing import Dict
from typing import List
from typing import Optional
+from typing import Union
+from mlia.core._typing import PathOrFileLike
from mlia.core.advice_generation import AdviceProducer
+from mlia.core.advisor import DefaultInferenceAdvisor
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
from mlia.core.context import Context
+from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import DataAnalyzer
from mlia.core.data_collection import DataCollector
-from mlia.core.mixins import ParameterResolverMixin
-from mlia.core.workflow import DefaultWorkflowExecutor
-from mlia.core.workflow import WorkflowExecutor
+from mlia.core.events import Event
from mlia.devices.ethosu.advice_generation import EthosUAdviceProducer
from mlia.devices.ethosu.advice_generation import EthosUStaticAdviceProducer
from mlia.devices.ethosu.config import EthosUConfiguration
@@ -23,10 +27,12 @@ from mlia.devices.ethosu.data_collection import EthosUOperatorCompatibility
from mlia.devices.ethosu.data_collection import EthosUOptimizationPerformance
from mlia.devices.ethosu.data_collection import EthosUPerformance
from mlia.devices.ethosu.events import EthosUAdvisorStartedEvent
+from mlia.devices.ethosu.handlers import EthosUEventHandler
from mlia.nn.tensorflow.utils import is_tflite_model
+from mlia.utils.types import is_list_of
-class EthosUInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin):
+class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
"""Ethos-U Inference Advisor."""
@classmethod
@@ -34,34 +40,12 @@ class EthosUInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin):
"""Return name of the advisor."""
return "ethos_u_inference_advisor"
- def configure(self, context: Context) -> WorkflowExecutor:
- """Configure advisor execution."""
- model = self._get_model(context)
+ def get_collectors(self, context: Context) -> List[DataCollector]:
+ """Return list of the data collectors."""
+ model = self.get_model(context)
device = self._get_device(context)
backends = self._get_backends(context)
- collectors = self._get_collectors(context, model, device, backends)
- analyzers = self._get_analyzers()
- producers = self._get_advice_producers()
-
- return DefaultWorkflowExecutor(
- context,
- collectors,
- analyzers,
- producers,
- before_start_events=[
- EthosUAdvisorStartedEvent(device=device, model=model),
- ],
- )
-
- def _get_collectors(
- self,
- context: Context,
- model: Path,
- device: EthosUConfiguration,
- backends: Optional[List[str]],
- ) -> List[DataCollector]:
- """Get collectors."""
collectors: List[DataCollector] = []
if AdviceCategory.OPERATORS in context.advice_category:
@@ -91,51 +75,34 @@ class EthosUInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin):
return collectors
- @staticmethod
- def _get_analyzers() -> List[DataAnalyzer]:
- """Return data analyzers."""
+ def get_analyzers(self, context: Context) -> List[DataAnalyzer]:
+ """Return list of the data analyzers."""
return [
EthosUDataAnalyzer(),
]
- @staticmethod
- def _get_advice_producers() -> List[AdviceProducer]:
- """Return advice producers."""
+ def get_producers(self, context: Context) -> List[AdviceProducer]:
+ """Return list of the advice producers."""
return [
EthosUAdviceProducer(),
EthosUStaticAdviceProducer(),
]
+ def get_events(self, context: Context) -> List[Event]:
+ """Return list of the startup events."""
+ model = self.get_model(context)
+ device = self._get_device(context)
+
+ return [
+ EthosUAdvisorStartedEvent(device=device, model=model),
+ ]
+
def _get_device(self, context: Context) -> EthosUConfiguration:
"""Get device."""
- device_params = self.get_parameter(
- self.name(),
- "device",
- expected_type=dict,
- context=context,
- )
-
- try:
- target_profile = device_params["target_profile"]
- except KeyError as err:
- raise Exception("Unable to get device details") from err
+ target_profile = self.get_target_profile(context)
return get_target(target_profile)
- def _get_model(self, context: Context) -> Path:
- """Get path to the model."""
- model_param = self.get_parameter(
- self.name(),
- "model",
- expected_type=str,
- context=context,
- )
-
- if not (model := Path(model_param)).exists():
- raise Exception(f"Path {model} does not exist")
-
- return model
-
def _get_optimization_settings(self, context: Context) -> List[List[dict]]:
"""Get optimization settings."""
return self.get_parameter( # type: ignore
@@ -155,3 +122,75 @@ class EthosUInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin):
expected=False,
context=context,
)
+
+
+def configure_and_get_ethosu_advisor(
+ context: ExecutionContext,
+ target_profile: str,
+ model: Union[Path, str],
+ output: Optional[PathOrFileLike] = None,
+ **extra_args: Any,
+) -> InferenceAdvisor:
+ """Create and configure Ethos-U advisor."""
+ if context.event_handlers is None:
+ context.event_handlers = [EthosUEventHandler(output)]
+
+ if context.config_parameters is None:
+ context.config_parameters = _get_config_parameters(
+ model, target_profile, **extra_args
+ )
+
+ return EthosUInferenceAdvisor()
+
+
+_DEFAULT_OPTIMIZATION_TARGETS = [
+ {
+ "optimization_type": "pruning",
+ "optimization_target": 0.5,
+ "layers_to_optimize": None,
+ },
+ {
+ "optimization_type": "clustering",
+ "optimization_target": 32,
+ "layers_to_optimize": None,
+ },
+]
+
+
+def _get_config_parameters(
+ model: Union[Path, str],
+ target_profile: str,
+ **extra_args: Any,
+) -> Dict[str, Any]:
+ """Get configuration parameters for the advisor."""
+ advisor_parameters: Dict[str, Any] = {
+ "ethos_u_inference_advisor": {
+ "model": model,
+ "target_profile": target_profile,
+ },
+ }
+
+ # Specifying backends is optional (default is used)
+ backends = extra_args.get("backends")
+ if backends is not None:
+ if not is_list_of(backends, str):
+ raise Exception("Backends value has wrong format")
+
+ advisor_parameters["ethos_u_inference_advisor"]["backends"] = backends
+
+ optimization_targets = extra_args.get("optimization_targets")
+ if not optimization_targets:
+ optimization_targets = _DEFAULT_OPTIMIZATION_TARGETS
+
+ if not is_list_of(optimization_targets, dict):
+ raise Exception("Optimization targets value has wrong format")
+
+ advisor_parameters.update(
+ {
+ "ethos_u_model_optimizations": {
+ "optimizations": [optimization_targets],
+ },
+ }
+ )
+
+ return advisor_parameters