aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/target
diff options
context:
space:
mode:
authorAnnie Tallund <annie.tallund@arm.com>2023-01-12 07:49:06 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-02-08 15:23:29 +0000
commit836efd40317a397761ec8b66e3f4398faac43ad0 (patch)
tree5133ffd51d8d6772551333a4b337d36a501a8a91 /src/mlia/target
parenta4fb8c72f15146c95df16c25e75f03344e9814fd (diff)
downloadmlia-836efd40317a397761ec8b66e3f4398faac43ad0.tar.gz
MLIA-770 List all available backends
- Rely on target and backend registry for support information - Make above information less Ethos(TM)-U specific Change-Id: I8dbfb84401016412a3d719a84eb592f21d79c46b
Diffstat (limited to 'src/mlia/target')
-rw-r--r--src/mlia/target/ethos_u/performance.py10
-rw-r--r--src/mlia/target/registry.py24
2 files changed, 29 insertions, 5 deletions
diff --git a/src/mlia/target/ethos_u/performance.py b/src/mlia/target/ethos_u/performance.py
index e39f4d9..0d791a1 100644
--- a/src/mlia/target/ethos_u/performance.py
+++ b/src/mlia/target/ethos_u/performance.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Performance estimation."""
from __future__ import annotations
@@ -14,14 +14,14 @@ import mlia.backend.vela.performance as vela_perf
from mlia.backend.corstone.performance import DeviceInfo
from mlia.backend.corstone.performance import estimate_performance
from mlia.backend.corstone.performance import ModelInfo
-from mlia.backend.install import is_supported
-from mlia.backend.install import supported_backends
+from mlia.backend.registry import get_supported_backends
from mlia.core.context import Context
from mlia.core.performance import PerformanceEstimator
from mlia.nn.tensorflow.config import get_tflite_model
from mlia.nn.tensorflow.config import ModelConfiguration
from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
from mlia.target.ethos_u.config import EthosUConfiguration
+from mlia.target.registry import is_supported
from mlia.utils.logging import log_action
@@ -226,7 +226,7 @@ class EthosUPerformanceEstimator(
if backend != "Vela" and not is_supported(backend):
raise ValueError(
f"Unsupported backend '{backend}'. "
- f"Only 'Vela' and {supported_backends()} "
+ f"Only 'Vela' and {get_supported_backends()} "
"are supported."
)
self.backends = set(backends)
@@ -246,7 +246,7 @@ class EthosUPerformanceEstimator(
if backend == "Vela":
vela_estimator = VelaPerformanceEstimator(self.context, self.device)
memory_usage = vela_estimator.estimate(tflite_model)
- elif backend in supported_backends():
+ elif backend in get_supported_backends():
corstone_estimator = CorstonePerformanceEstimator(
self.context, self.device, backend
)
diff --git a/src/mlia/target/registry.py b/src/mlia/target/registry.py
index 2d29f1b..325dd04 100644
--- a/src/mlia/target/registry.py
+++ b/src/mlia/target/registry.py
@@ -32,6 +32,30 @@ def supported_backends(target: str) -> list[str]:
return registry.items[target].filter_supported_backends(check_system=False)
+def get_backend_to_supported_targets() -> dict[str, list]:
+ """Get a dict that maps a list of supported targets given backend."""
+ targets = dict(registry.items)
+ supported_backends_dict: dict[str, list] = {}
+ for target, info in targets.items():
+ target_backends = info.supported_backends
+ for backend in target_backends:
+ supported_backends_dict.setdefault(backend, []).append(target)
+ return supported_backends_dict
+
+
+def is_supported(backend: str, target: str | None = None) -> bool:
+ """Check if the backend (and optionally target) is supported."""
+ backends = get_backend_to_supported_targets()
+ if target is None:
+ if backend in backends:
+ return True
+ return False
+ try:
+ return target in backends[backend]
+ except KeyError:
+ return False
+
+
def supported_targets(advice: AdviceCategory) -> list[str]:
"""Get a list of all targets supporting the given advice category."""
return [