aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/target/ethos_u
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2023-02-02 14:02:05 +0000
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-02-10 13:45:18 +0000
commit7a661257b6adad0c8f53e32b42ced56a1e7d952f (patch)
tree938ad8578c5b9edc0573e810ce64ce0a5bda3d8c /src/mlia/target/ethos_u
parent50271dee0a84bfc481ce798184f07b5b0b4bc64d (diff)
downloadmlia-7a661257b6adad0c8f53e32b42ced56a1e7d952f.tar.gz
MLIA-769 Expand use of target/backend registries
- Use the target/backend registries to avoid hard-coded names. - Cache target profiles to avoid re-loading them Change-Id: I474b7c9ef23894e1d8a3ea06d13a37652054c62e
Diffstat (limited to 'src/mlia/target/ethos_u')
-rw-r--r--src/mlia/target/ethos_u/__init__.py19
-rw-r--r--src/mlia/target/ethos_u/advisor.py4
-rw-r--r--src/mlia/target/ethos_u/config.py22
3 files changed, 41 insertions, 4 deletions
diff --git a/src/mlia/target/ethos_u/__init__.py b/src/mlia/target/ethos_u/__init__.py
index d53be53..6b6777d 100644
--- a/src/mlia/target/ethos_u/__init__.py
+++ b/src/mlia/target/ethos_u/__init__.py
@@ -1,8 +1,23 @@
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Ethos-U target module."""
+from mlia.backend.corstone import CORSTONE_PRIORITY
+from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor
+from mlia.target.ethos_u.config import EthosUConfiguration
+from mlia.target.ethos_u.config import get_default_ethos_u_backends
from mlia.target.registry import registry
from mlia.target.registry import TargetInfo
-registry.register("ethos-u55", TargetInfo(["Vela", "Corstone-300", "Corstone-310"]))
-registry.register("ethos-u65", TargetInfo(["Vela", "Corstone-300", "Corstone-310"]))
+SUPPORTED_BACKENDS_PRIORITY = ["Vela", *CORSTONE_PRIORITY]
+
+
+for ethos_u in ("ethos-u55", "ethos-u65"):
+ registry.register(
+ ethos_u,
+ TargetInfo(
+ supported_backends=SUPPORTED_BACKENDS_PRIORITY,
+ default_backends=get_default_ethos_u_backends(SUPPORTED_BACKENDS_PRIORITY),
+ advisor_factory_func=configure_and_get_ethosu_advisor,
+ target_profile_cls=EthosUConfiguration,
+ ),
+ )
diff --git a/src/mlia/target/ethos_u/advisor.py b/src/mlia/target/ethos_u/advisor.py
index 225fd87..5f23fdd 100644
--- a/src/mlia/target/ethos_u/advisor.py
+++ b/src/mlia/target/ethos_u/advisor.py
@@ -5,6 +5,7 @@ from __future__ import annotations
from pathlib import Path
from typing import Any
+from typing import cast
from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
@@ -25,6 +26,7 @@ from mlia.target.ethos_u.data_collection import EthosUOptimizationPerformance
from mlia.target.ethos_u.data_collection import EthosUPerformance
from mlia.target.ethos_u.events import EthosUAdvisorStartedEvent
from mlia.target.ethos_u.handlers import EthosUEventHandler
+from mlia.target.registry import profile
from mlia.utils.types import is_list_of
@@ -96,7 +98,7 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
def _get_device_cfg(self, context: Context) -> EthosUConfiguration:
"""Get device configuration."""
target_profile = self.get_target_profile(context)
- return EthosUConfiguration.load_profile(target_profile)
+ return cast(EthosUConfiguration, profile(target_profile))
def _get_optimization_settings(self, context: Context) -> list[list[dict]]:
"""Get optimization settings."""
diff --git a/src/mlia/target/ethos_u/config.py b/src/mlia/target/ethos_u/config.py
index eb5691d..d1a2c7a 100644
--- a/src/mlia/target/ethos_u/config.py
+++ b/src/mlia/target/ethos_u/config.py
@@ -6,12 +6,13 @@ from __future__ import annotations
import logging
from typing import Any
+from mlia.backend.corstone import is_corstone_backend
+from mlia.backend.manager import get_available_backends
from mlia.backend.vela.compiler import resolve_compiler_config
from mlia.backend.vela.compiler import VelaCompilerOptions
from mlia.target.config import TargetProfile
from mlia.utils.filesystem import get_vela_config
-
logger = logging.getLogger(__name__)
@@ -67,3 +68,22 @@ class EthosUConfiguration(TargetProfile):
def __repr__(self) -> str:
"""Return string representation."""
return f"<Ethos-U configuration target={self.target}>"
+
+
+def get_default_ethos_u_backends(
+ supported_backends_priority_order: list[str],
+) -> list[str]:
+ """Return default backends for Ethos-U targets."""
+ available_backends = get_available_backends()
+
+ default_backends = []
+ corstone_added = False
+ for backend in supported_backends_priority_order:
+ if backend not in available_backends:
+ continue
+ if is_corstone_backend(backend):
+ if corstone_added:
+ continue # only add one Corstone backend
+ corstone_added = True
+ default_backends.append(backend)
+ return default_backends