aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/devices/ethosu/advisor.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/devices/ethosu/advisor.py')
-rw-r--r--src/mlia/devices/ethosu/advisor.py151
1 files changed, 151 insertions, 0 deletions
diff --git a/src/mlia/devices/ethosu/advisor.py b/src/mlia/devices/ethosu/advisor.py
new file mode 100644
index 0000000..802826b
--- /dev/null
+++ b/src/mlia/devices/ethosu/advisor.py
@@ -0,0 +1,151 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Ethos-U MLIA module."""
+from pathlib import Path
+from typing import List
+from typing import Optional
+
+from mlia.core.advice_generation import AdviceProducer
+from mlia.core.advisor import InferenceAdvisor
+from mlia.core.common import AdviceCategory
+from mlia.core.context import Context
+from mlia.core.data_analysis import DataAnalyzer
+from mlia.core.data_collection import DataCollector
+from mlia.core.mixins import ParameterResolverMixin
+from mlia.core.workflow import DefaultWorkflowExecutor
+from mlia.core.workflow import WorkflowExecutor
+from mlia.devices.ethosu.advice_generation import EthosUAdviceProducer
+from mlia.devices.ethosu.advice_generation import EthosUStaticAdviceProducer
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.devices.ethosu.config import get_target
+from mlia.devices.ethosu.data_analysis import EthosUDataAnalyzer
+from mlia.devices.ethosu.data_collection import EthosUOperatorCompatibility
+from mlia.devices.ethosu.data_collection import EthosUOptimizationPerformance
+from mlia.devices.ethosu.data_collection import EthosUPerformance
+from mlia.devices.ethosu.events import EthosUAdvisorStartedEvent
+
+
+class EthosUInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin):
+ """Ethos-U Inference Advisor."""
+
+ @classmethod
+ def name(cls) -> str:
+ """Return name of the advisor."""
+ return "ethos_u_inference_advisor"
+
+ def configure(self, context: Context) -> WorkflowExecutor:
+ """Configure advisor execution."""
+ model = self._get_model(context)
+ device = self._get_device(context)
+ backends = self._get_backends(context)
+
+ collectors = self._get_collectors(context, model, device, backends)
+ analyzers = self._get_analyzers()
+ producers = self._get_advice_producers()
+
+ return DefaultWorkflowExecutor(
+ context,
+ collectors,
+ analyzers,
+ producers,
+ before_start_events=[
+ EthosUAdvisorStartedEvent(device=device, model=model),
+ ],
+ )
+
+ def _get_collectors(
+ self,
+ context: Context,
+ model: Path,
+ device: EthosUConfiguration,
+ backends: Optional[List[str]],
+ ) -> List[DataCollector]:
+ """Get collectors."""
+ collectors: List[DataCollector] = []
+
+ if context.any_category_enabled(
+ AdviceCategory.OPERATORS,
+ AdviceCategory.ALL,
+ ):
+ collectors.append(EthosUOperatorCompatibility(model, device))
+
+ if context.category_enabled(AdviceCategory.PERFORMANCE):
+ collectors.append(EthosUPerformance(model, device, backends))
+
+ if context.any_category_enabled(
+ AdviceCategory.OPTIMIZATION,
+ AdviceCategory.ALL,
+ ):
+ optimization_settings = self._get_optimization_settings(context)
+ collectors.append(
+ EthosUOptimizationPerformance(
+ model, device, optimization_settings, backends
+ )
+ )
+
+ return collectors
+
+ @staticmethod
+ def _get_analyzers() -> List[DataAnalyzer]:
+ """Return data analyzers."""
+ return [
+ EthosUDataAnalyzer(),
+ ]
+
+ @staticmethod
+ def _get_advice_producers() -> List[AdviceProducer]:
+ """Return advice producers."""
+ return [
+ EthosUAdviceProducer(),
+ EthosUStaticAdviceProducer(),
+ ]
+
+ def _get_device(self, context: Context) -> EthosUConfiguration:
+ """Get device."""
+ device_params = self.get_parameter(
+ self.name(),
+ "device",
+ expected_type=dict,
+ context=context,
+ )
+
+ try:
+ target_profile = device_params["target_profile"]
+ except KeyError as err:
+ raise Exception("Unable to get device details") from err
+
+ return get_target(target_profile)
+
+ def _get_model(self, context: Context) -> Path:
+ """Get path to the model."""
+ model_param = self.get_parameter(
+ self.name(),
+ "model",
+ expected_type=str,
+ context=context,
+ )
+
+ if not (model := Path(model_param)).exists():
+ raise Exception(f"Path {model} does not exist")
+
+ return model
+
+ def _get_optimization_settings(self, context: Context) -> List[List[dict]]:
+ """Get optimization settings."""
+ return self.get_parameter( # type: ignore
+ EthosUOptimizationPerformance.name(),
+ "optimizations",
+ expected_type=list,
+ expected=False,
+ context=context,
+ )
+
+ def _get_backends(self, context: Context) -> Optional[List[str]]:
+ """Get list of backends."""
+ return self.get_parameter( # type: ignore
+ self.name(),
+ "backends",
+ expected_type=list,
+ expected=False,
+ context=context,
+ )