diff options
Diffstat (limited to 'src/mlia/target/ethos_u')
-rw-r--r-- | src/mlia/target/ethos_u/__init__.py | 19 | ||||
-rw-r--r-- | src/mlia/target/ethos_u/advisor.py | 4 | ||||
-rw-r--r-- | src/mlia/target/ethos_u/config.py | 22 |
3 files changed, 41 insertions, 4 deletions
diff --git a/src/mlia/target/ethos_u/__init__.py b/src/mlia/target/ethos_u/__init__.py index d53be53..6b6777d 100644 --- a/src/mlia/target/ethos_u/__init__.py +++ b/src/mlia/target/ethos_u/__init__.py @@ -1,8 +1,23 @@ # SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Ethos-U target module.""" +from mlia.backend.corstone import CORSTONE_PRIORITY +from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor +from mlia.target.ethos_u.config import EthosUConfiguration +from mlia.target.ethos_u.config import get_default_ethos_u_backends from mlia.target.registry import registry from mlia.target.registry import TargetInfo -registry.register("ethos-u55", TargetInfo(["Vela", "Corstone-300", "Corstone-310"])) -registry.register("ethos-u65", TargetInfo(["Vela", "Corstone-300", "Corstone-310"])) +SUPPORTED_BACKENDS_PRIORITY = ["Vela", *CORSTONE_PRIORITY] + + +for ethos_u in ("ethos-u55", "ethos-u65"): + registry.register( + ethos_u, + TargetInfo( + supported_backends=SUPPORTED_BACKENDS_PRIORITY, + default_backends=get_default_ethos_u_backends(SUPPORTED_BACKENDS_PRIORITY), + advisor_factory_func=configure_and_get_ethosu_advisor, + target_profile_cls=EthosUConfiguration, + ), + ) diff --git a/src/mlia/target/ethos_u/advisor.py b/src/mlia/target/ethos_u/advisor.py index 225fd87..5f23fdd 100644 --- a/src/mlia/target/ethos_u/advisor.py +++ b/src/mlia/target/ethos_u/advisor.py @@ -5,6 +5,7 @@ from __future__ import annotations from pathlib import Path from typing import Any +from typing import cast from mlia.core.advice_generation import AdviceProducer from mlia.core.advisor import DefaultInferenceAdvisor @@ -25,6 +26,7 @@ from mlia.target.ethos_u.data_collection import EthosUOptimizationPerformance from mlia.target.ethos_u.data_collection import EthosUPerformance from mlia.target.ethos_u.events import EthosUAdvisorStartedEvent from mlia.target.ethos_u.handlers import EthosUEventHandler +from mlia.target.registry import profile from mlia.utils.types import is_list_of @@ -96,7 +98,7 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor): def _get_device_cfg(self, context: Context) -> EthosUConfiguration: """Get device configuration.""" target_profile = self.get_target_profile(context) - return EthosUConfiguration.load_profile(target_profile) + return cast(EthosUConfiguration, profile(target_profile)) def _get_optimization_settings(self, context: Context) -> list[list[dict]]: """Get optimization settings.""" diff --git a/src/mlia/target/ethos_u/config.py b/src/mlia/target/ethos_u/config.py index eb5691d..d1a2c7a 100644 --- a/src/mlia/target/ethos_u/config.py +++ b/src/mlia/target/ethos_u/config.py @@ -6,12 +6,13 @@ from __future__ import annotations import logging from typing import Any +from mlia.backend.corstone import is_corstone_backend +from mlia.backend.manager import get_available_backends from mlia.backend.vela.compiler import resolve_compiler_config from mlia.backend.vela.compiler import VelaCompilerOptions from mlia.target.config import TargetProfile from mlia.utils.filesystem import get_vela_config - logger = logging.getLogger(__name__) @@ -67,3 +68,22 @@ class EthosUConfiguration(TargetProfile): def __repr__(self) -> str: """Return string representation.""" return f"<Ethos-U configuration target={self.target}>" + + +def get_default_ethos_u_backends( + supported_backends_priority_order: list[str], +) -> list[str]: + """Return default backends for Ethos-U targets.""" + available_backends = get_available_backends() + + default_backends = [] + corstone_added = False + for backend in supported_backends_priority_order: + if backend not in available_backends: + continue + if is_corstone_backend(backend): + if corstone_added: + continue # only add one Corstone backend + corstone_added = True + default_backends.append(backend) + return default_backends |