diff options
author | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2022-07-21 14:06:03 +0100 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2022-08-19 10:23:23 +0100 |
commit | a8ee1aee3e674c78a77801d1bf2256881ab6b4b9 (patch) | |
tree | 8463b24ba0446a49b3e012477b0834c3b5415b86 /src/mlia/api.py | |
parent | 76ec769ad8f8ed53ec3ff829fdd34d53db8229fd (diff) | |
download | mlia-a8ee1aee3e674c78a77801d1bf2256881ab6b4b9.tar.gz |
MLIA-549 Refactor API module to support several target profiles
- Move target specific details out of API module
- Move common logic for workflow event handler into a
separate class
Change-Id: Ic4a22657b722af1c1fead1d478f606ac57325788
Diffstat (limited to 'src/mlia/api.py')
-rw-r--r-- | src/mlia/api.py | 108 |
1 files changed, 34 insertions, 74 deletions
diff --git a/src/mlia/api.py b/src/mlia/api.py index 0f950db..024bc98 100644 --- a/src/mlia/api.py +++ b/src/mlia/api.py @@ -14,28 +14,13 @@ from mlia.core._typing import PathOrFileLike from mlia.core.advisor import InferenceAdvisor from mlia.core.common import AdviceCategory from mlia.core.context import ExecutionContext -from mlia.core.events import EventHandler -from mlia.devices.ethosu.advisor import EthosUInferenceAdvisor -from mlia.devices.ethosu.handlers import EthosUEventHandler +from mlia.devices.ethosu.advisor import configure_and_get_ethosu_advisor +from mlia.utils.filesystem import get_target logger = logging.getLogger(__name__) -_DEFAULT_OPTIMIZATION_TARGETS = [ - { - "optimization_type": "pruning", - "optimization_target": 0.5, - "layers_to_optimize": None, - }, - { - "optimization_type": "clustering", - "optimization_target": 32, - "layers_to_optimize": None, - }, -] - - def get_advice( target_profile: str, model: Union[Path, str], @@ -71,7 +56,6 @@ def get_advice( :param backends: A list of backends that should be used for the given target. Default settings will be used if None. - Examples: NB: Before launching MLIA, the logging functionality should be configured! @@ -87,75 +71,51 @@ def get_advice( """ advice_category = AdviceCategory.from_string(category) - config_parameters = _get_config_parameters( - model, target_profile, backends, optimization_targets - ) - event_handlers = _get_event_handlers(output) if context is not None: context.advice_category = advice_category - if context.config_parameters is None: - context.config_parameters = config_parameters - - if context.event_handlers is None: - context.event_handlers = event_handlers - if context is None: context = ExecutionContext( advice_category=advice_category, working_dir=working_dir, - config_parameters=config_parameters, - event_handlers=event_handlers, ) - advisor = _get_advisor(target_profile) - advisor.run(context) - - -def _get_advisor(target: Optional[str]) -> InferenceAdvisor: - """Find appropriate advisor for the target.""" - if not target: - raise Exception("Target is not provided") + advisor = get_advisor( + context, + target_profile, + model, + output, + optimization_targets=optimization_targets, + backends=backends, + ) - return EthosUInferenceAdvisor() + advisor.run(context) -def _get_config_parameters( - model: Union[Path, str], +def get_advisor( + context: ExecutionContext, target_profile: str, - backends: Optional[List[str]], - optimization_targets: Optional[List[Dict[str, Any]]], -) -> Dict[str, Any]: - """Get configuration parameters for the advisor.""" - advisor_parameters: Dict[str, Any] = { - "ethos_u_inference_advisor": { - "model": model, - "device": { - "target_profile": target_profile, - }, - }, + model: Union[Path, str], + output: Optional[PathOrFileLike] = None, + **extra_args: Any, +) -> InferenceAdvisor: + """Find appropriate advisor for the target.""" + target_factories = { + "ethos-u55": configure_and_get_ethosu_advisor, + "ethos-u65": configure_and_get_ethosu_advisor, } - # Specifying backends is optional (default is used) - if backends is not None: - advisor_parameters["ethos_u_inference_advisor"]["backends"] = backends - - if not optimization_targets: - optimization_targets = _DEFAULT_OPTIMIZATION_TARGETS - - advisor_parameters.update( - { - "ethos_u_model_optimizations": { - "optimizations": [ - optimization_targets, - ], - }, - } - ) - - return advisor_parameters - -def _get_event_handlers(output: Optional[PathOrFileLike]) -> List[EventHandler]: - """Return list of the event handlers.""" - return [EthosUEventHandler(output)] + try: + target = get_target(target_profile) + factory_function = target_factories[target] + except KeyError as err: + raise Exception(f"Unsupported profile {target_profile}") from err + + return factory_function( + context, + target_profile, + model, + output, + **extra_args, + ) |