diff options
Diffstat (limited to 'src/mlia/api.py')
-rw-r--r-- | src/mlia/api.py | 21 |
1 files changed, 4 insertions, 17 deletions
diff --git a/src/mlia/api.py b/src/mlia/api.py index fd5fc13..7adae48 100644 --- a/src/mlia/api.py +++ b/src/mlia/api.py @@ -10,10 +10,8 @@ from typing import Any from mlia.core.advisor import InferenceAdvisor from mlia.core.common import AdviceCategory from mlia.core.context import ExecutionContext -from mlia.target.config import get_target -from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor -from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor -from mlia.target.tosa.advisor import configure_and_get_tosa_advisor +from mlia.target.registry import profile +from mlia.target.registry import registry as target_registry logger = logging.getLogger(__name__) @@ -84,19 +82,8 @@ def get_advisor( **extra_args: Any, ) -> InferenceAdvisor: """Find appropriate advisor for the target.""" - target_factories = { - "ethos-u55": configure_and_get_ethosu_advisor, - "ethos-u65": configure_and_get_ethosu_advisor, - "tosa": configure_and_get_tosa_advisor, - "cortex-a": configure_and_get_cortexa_advisor, - } - - try: - target = get_target(target_profile) - factory_function = target_factories[target] - except KeyError as err: - raise Exception(f"Unsupported profile {target_profile}") from err - + target = profile(target_profile).target + factory_function = target_registry.items[target].advisor_factory_func return factory_function( context, target_profile, |