From 836efd40317a397761ec8b66e3f4398faac43ad0 Mon Sep 17 00:00:00 2001 From: Annie Tallund Date: Thu, 12 Jan 2023 07:49:06 +0100 Subject: MLIA-770 List all available backends - Rely on target and backend registry for support information - Make above information less Ethos(TM)-U specific Change-Id: I8dbfb84401016412a3d719a84eb592f21d79c46b --- src/mlia/target/ethos_u/performance.py | 10 +++++----- src/mlia/target/registry.py | 24 ++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 5 deletions(-) (limited to 'src/mlia/target') diff --git a/src/mlia/target/ethos_u/performance.py b/src/mlia/target/ethos_u/performance.py index e39f4d9..0d791a1 100644 --- a/src/mlia/target/ethos_u/performance.py +++ b/src/mlia/target/ethos_u/performance.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Performance estimation.""" from __future__ import annotations @@ -14,14 +14,14 @@ import mlia.backend.vela.performance as vela_perf from mlia.backend.corstone.performance import DeviceInfo from mlia.backend.corstone.performance import estimate_performance from mlia.backend.corstone.performance import ModelInfo -from mlia.backend.install import is_supported -from mlia.backend.install import supported_backends +from mlia.backend.registry import get_supported_backends from mlia.core.context import Context from mlia.core.performance import PerformanceEstimator from mlia.nn.tensorflow.config import get_tflite_model from mlia.nn.tensorflow.config import ModelConfiguration from mlia.nn.tensorflow.optimizations.select import OptimizationSettings from mlia.target.ethos_u.config import EthosUConfiguration +from mlia.target.registry import is_supported from mlia.utils.logging import log_action @@ -226,7 +226,7 @@ class EthosUPerformanceEstimator( if backend != "Vela" and not is_supported(backend): raise ValueError( f"Unsupported backend '{backend}'. " - f"Only 'Vela' and {supported_backends()} " + f"Only 'Vela' and {get_supported_backends()} " "are supported." ) self.backends = set(backends) @@ -246,7 +246,7 @@ class EthosUPerformanceEstimator( if backend == "Vela": vela_estimator = VelaPerformanceEstimator(self.context, self.device) memory_usage = vela_estimator.estimate(tflite_model) - elif backend in supported_backends(): + elif backend in get_supported_backends(): corstone_estimator = CorstonePerformanceEstimator( self.context, self.device, backend ) diff --git a/src/mlia/target/registry.py b/src/mlia/target/registry.py index 2d29f1b..325dd04 100644 --- a/src/mlia/target/registry.py +++ b/src/mlia/target/registry.py @@ -32,6 +32,30 @@ def supported_backends(target: str) -> list[str]: return registry.items[target].filter_supported_backends(check_system=False) +def get_backend_to_supported_targets() -> dict[str, list]: + """Get a dict that maps a list of supported targets given backend.""" + targets = dict(registry.items) + supported_backends_dict: dict[str, list] = {} + for target, info in targets.items(): + target_backends = info.supported_backends + for backend in target_backends: + supported_backends_dict.setdefault(backend, []).append(target) + return supported_backends_dict + + +def is_supported(backend: str, target: str | None = None) -> bool: + """Check if the backend (and optionally target) is supported.""" + backends = get_backend_to_supported_targets() + if target is None: + if backend in backends: + return True + return False + try: + return target in backends[backend] + except KeyError: + return False + + def supported_targets(advice: AdviceCategory) -> list[str]: """Get a list of all targets supporting the given advice category.""" return [ -- cgit v1.2.1