aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/api.py
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2022-07-21 14:06:03 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2022-08-19 10:23:23 +0100
commita8ee1aee3e674c78a77801d1bf2256881ab6b4b9 (patch)
tree8463b24ba0446a49b3e012477b0834c3b5415b86 /src/mlia/api.py
parent76ec769ad8f8ed53ec3ff829fdd34d53db8229fd (diff)
downloadmlia-a8ee1aee3e674c78a77801d1bf2256881ab6b4b9.tar.gz
MLIA-549 Refactor API module to support several target profiles
- Move target specific details out of API module - Move common logic for workflow event handler into a separate class Change-Id: Ic4a22657b722af1c1fead1d478f606ac57325788
Diffstat (limited to 'src/mlia/api.py')
-rw-r--r--src/mlia/api.py108
1 files changed, 34 insertions, 74 deletions
diff --git a/src/mlia/api.py b/src/mlia/api.py
index 0f950db..024bc98 100644
--- a/src/mlia/api.py
+++ b/src/mlia/api.py
@@ -14,28 +14,13 @@ from mlia.core._typing import PathOrFileLike
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext
-from mlia.core.events import EventHandler
-from mlia.devices.ethosu.advisor import EthosUInferenceAdvisor
-from mlia.devices.ethosu.handlers import EthosUEventHandler
+from mlia.devices.ethosu.advisor import configure_and_get_ethosu_advisor
+from mlia.utils.filesystem import get_target
logger = logging.getLogger(__name__)
-_DEFAULT_OPTIMIZATION_TARGETS = [
- {
- "optimization_type": "pruning",
- "optimization_target": 0.5,
- "layers_to_optimize": None,
- },
- {
- "optimization_type": "clustering",
- "optimization_target": 32,
- "layers_to_optimize": None,
- },
-]
-
-
def get_advice(
target_profile: str,
model: Union[Path, str],
@@ -71,7 +56,6 @@ def get_advice(
:param backends: A list of backends that should be used for the given
target. Default settings will be used if None.
-
Examples:
NB: Before launching MLIA, the logging functionality should be configured!
@@ -87,75 +71,51 @@ def get_advice(
"""
advice_category = AdviceCategory.from_string(category)
- config_parameters = _get_config_parameters(
- model, target_profile, backends, optimization_targets
- )
- event_handlers = _get_event_handlers(output)
if context is not None:
context.advice_category = advice_category
- if context.config_parameters is None:
- context.config_parameters = config_parameters
-
- if context.event_handlers is None:
- context.event_handlers = event_handlers
-
if context is None:
context = ExecutionContext(
advice_category=advice_category,
working_dir=working_dir,
- config_parameters=config_parameters,
- event_handlers=event_handlers,
)
- advisor = _get_advisor(target_profile)
- advisor.run(context)
-
-
-def _get_advisor(target: Optional[str]) -> InferenceAdvisor:
- """Find appropriate advisor for the target."""
- if not target:
- raise Exception("Target is not provided")
+ advisor = get_advisor(
+ context,
+ target_profile,
+ model,
+ output,
+ optimization_targets=optimization_targets,
+ backends=backends,
+ )
- return EthosUInferenceAdvisor()
+ advisor.run(context)
-def _get_config_parameters(
- model: Union[Path, str],
+def get_advisor(
+ context: ExecutionContext,
target_profile: str,
- backends: Optional[List[str]],
- optimization_targets: Optional[List[Dict[str, Any]]],
-) -> Dict[str, Any]:
- """Get configuration parameters for the advisor."""
- advisor_parameters: Dict[str, Any] = {
- "ethos_u_inference_advisor": {
- "model": model,
- "device": {
- "target_profile": target_profile,
- },
- },
+ model: Union[Path, str],
+ output: Optional[PathOrFileLike] = None,
+ **extra_args: Any,
+) -> InferenceAdvisor:
+ """Find appropriate advisor for the target."""
+ target_factories = {
+ "ethos-u55": configure_and_get_ethosu_advisor,
+ "ethos-u65": configure_and_get_ethosu_advisor,
}
- # Specifying backends is optional (default is used)
- if backends is not None:
- advisor_parameters["ethos_u_inference_advisor"]["backends"] = backends
-
- if not optimization_targets:
- optimization_targets = _DEFAULT_OPTIMIZATION_TARGETS
-
- advisor_parameters.update(
- {
- "ethos_u_model_optimizations": {
- "optimizations": [
- optimization_targets,
- ],
- },
- }
- )
-
- return advisor_parameters
-
-def _get_event_handlers(output: Optional[PathOrFileLike]) -> List[EventHandler]:
- """Return list of the event handlers."""
- return [EthosUEventHandler(output)]
+ try:
+ target = get_target(target_profile)
+ factory_function = target_factories[target]
+ except KeyError as err:
+ raise Exception(f"Unsupported profile {target_profile}") from err
+
+ return factory_function(
+ context,
+ target_profile,
+ model,
+ output,
+ **extra_args,
+ )