diff options
author | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2022-12-14 11:20:11 +0000 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-01-04 10:11:33 +0000 |
commit | dcd0bd31985c27e1d07333351b26cf8ad12ad1fd (patch) | |
tree | a3388ff5f91e7cdc7ec41271a1a76cdbfae38ece /src | |
parent | 4b4cf29cb1e7d917ae001e258ff01f7846c34778 (diff) | |
download | mlia-dcd0bd31985c27e1d07333351b26cf8ad12ad1fd.tar.gz |
MLIA-589 Create an API to get target information
Change-Id: Ieeaa9188ea1e29e2ccaad7475d457bce71e3140d
Diffstat (limited to 'src')
-rw-r--r-- | src/mlia/backend/__init__.py | 9 | ||||
-rw-r--r-- | src/mlia/backend/armnn_tflite_delegate/__init__.py | 16 | ||||
-rw-r--r-- | src/mlia/backend/armnn_tflite_delegate/compat.py (renamed from src/mlia/target/cortex_a/operator_compatibility.py) | 0 | ||||
-rw-r--r-- | src/mlia/backend/config.py | 82 | ||||
-rw-r--r-- | src/mlia/backend/corstone/__init__.py | 22 | ||||
-rw-r--r-- | src/mlia/backend/registry.py | 8 | ||||
-rw-r--r-- | src/mlia/backend/tosa_checker/__init__.py | 14 | ||||
-rw-r--r-- | src/mlia/backend/vela/__init__.py | 23 | ||||
-rw-r--r-- | src/mlia/cli/main.py | 9 | ||||
-rw-r--r-- | src/mlia/cli/options.py | 2 | ||||
-rw-r--r-- | src/mlia/target/__init__.py | 6 | ||||
-rw-r--r-- | src/mlia/target/config.py | 36 | ||||
-rw-r--r-- | src/mlia/target/cortex_a/__init__.py | 4 | ||||
-rw-r--r-- | src/mlia/target/cortex_a/operators.py | 6 | ||||
-rw-r--r-- | src/mlia/target/ethos_u/__init__.py | 5 | ||||
-rw-r--r-- | src/mlia/target/registry.py | 34 | ||||
-rw-r--r-- | src/mlia/target/tosa/__init__.py | 4 | ||||
-rw-r--r-- | src/mlia/utils/registry.py | 31 |
18 files changed, 302 insertions, 9 deletions
diff --git a/src/mlia/backend/__init__.py b/src/mlia/backend/__init__.py index 745aa1b..2d1da70 100644 --- a/src/mlia/backend/__init__.py +++ b/src/mlia/backend/__init__.py @@ -1,3 +1,10 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 -"""Backends module.""" +"""Backend module.""" +# Make sure all targets are registered with the registry by importing the +# sub-modules +# flake8: noqa +from mlia.backend import armnn_tflite_delegate +from mlia.backend import corstone +from mlia.backend import tosa_checker +from mlia.backend import vela diff --git a/src/mlia/backend/armnn_tflite_delegate/__init__.py b/src/mlia/backend/armnn_tflite_delegate/__init__.py new file mode 100644 index 0000000..6d5af42 --- /dev/null +++ b/src/mlia/backend/armnn_tflite_delegate/__init__.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Arm NN TensorFlow Lite delegate backend module.""" +from mlia.backend.config import BackendConfiguration +from mlia.backend.config import BackendType +from mlia.backend.registry import registry +from mlia.core.common import AdviceCategory + +registry.register( + "ArmNNTFLiteDelegate", + BackendConfiguration( + supported_advice=[AdviceCategory.OPERATORS], + supported_systems=None, + backend_type=BackendType.BUILTIN, + ), +) diff --git a/src/mlia/target/cortex_a/operator_compatibility.py b/src/mlia/backend/armnn_tflite_delegate/compat.py index c474e75..c474e75 100644 --- a/src/mlia/target/cortex_a/operator_compatibility.py +++ b/src/mlia/backend/armnn_tflite_delegate/compat.py diff --git a/src/mlia/backend/config.py b/src/mlia/backend/config.py new file mode 100644 index 0000000..8d14b28 --- /dev/null +++ b/src/mlia/backend/config.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Backend config module.""" +from __future__ import annotations + +import platform +from enum import auto +from enum import Enum +from typing import cast + +from mlia.core.common import AdviceCategory + + +class System(Enum): + """Enum of system configurations (e.g. OS and architecture).""" + + LINUX_AMD64 = ("Linux", "x86_64") + LINUX_AARCH64 = ("Linux", "aarch64") + WINDOWS_AMD64 = ("Windows", "AMD64") + WINDOWS_AARCH64 = ("Windows", "ARM64") + DARWIN_AARCH64 = ("Darwin", "arm64") + CURRENT = (platform.system(), platform.machine()) + + def __init__( + self, system: str = platform.system(), machine: str = platform.machine() + ) -> None: + """Set the system parameters (defaults to the current system).""" + self.system = system.lower() + self.machine = machine.lower() + + def __eq__(self, other: object) -> bool: + """ + Compare two System instances for equality. + + Raises a TypeError if the input is not a System. + """ + if isinstance(other, self.__class__): + return self.system == other.system and self.machine == other.machine + return False + + def is_compatible(self) -> bool: + """Check if this system is compatible with the current system.""" + return self == self.CURRENT + + +class BackendType(Enum): + """Define the type of the backend (builtin, wheel file etc).""" + + BUILTIN = auto() + WHEEL = auto() + CUSTOM = auto() + + +class BackendConfiguration: + """Base class for backend configurations.""" + + def __init__( + self, + supported_advice: list[AdviceCategory], + supported_systems: list[System] | None, + backend_type: BackendType, + ) -> None: + """Set up basic information about the backend.""" + self.supported_advice = supported_advice + self.supported_systems = supported_systems + self.type = backend_type + + def __str__(self) -> str: + """List supported advice.""" + return ", ".join(cast(str, adv.name).lower() for adv in self.supported_advice) + + def is_supported( + self, advice: AdviceCategory | None = None, check_system: bool = False + ) -> bool: + """Check backend supports the current system and advice.""" + is_system_supported = ( + not self.supported_systems + or not check_system + or any(sys.is_compatible() for sys in self.supported_systems) + ) + is_advice_supported = advice is None or advice in self.supported_advice + return is_system_supported and is_advice_supported diff --git a/src/mlia/backend/corstone/__init__.py b/src/mlia/backend/corstone/__init__.py index a1eac14..f89da63 100644 --- a/src/mlia/backend/corstone/__init__.py +++ b/src/mlia/backend/corstone/__init__.py @@ -1,3 +1,25 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Corstone backend module.""" +from mlia.backend.config import BackendConfiguration +from mlia.backend.config import BackendType +from mlia.backend.config import System +from mlia.backend.registry import registry +from mlia.core.common import AdviceCategory + +registry.register( + "Corstone-300", + BackendConfiguration( + supported_advice=[AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION], + supported_systems=[System.LINUX_AMD64], + backend_type=BackendType.CUSTOM, + ), +) +registry.register( + "Corstone-310", + BackendConfiguration( + supported_advice=[AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION], + supported_systems=[System.LINUX_AMD64], + backend_type=BackendType.CUSTOM, + ), +) diff --git a/src/mlia/backend/registry.py b/src/mlia/backend/registry.py new file mode 100644 index 0000000..6a0da74 --- /dev/null +++ b/src/mlia/backend/registry.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Backend module.""" +from mlia.backend.config import BackendConfiguration +from mlia.utils.registry import Registry + +# All supported targets are required to be registered here. +registry = Registry[BackendConfiguration]() diff --git a/src/mlia/backend/tosa_checker/__init__.py b/src/mlia/backend/tosa_checker/__init__.py index cec210d..19fc8be 100644 --- a/src/mlia/backend/tosa_checker/__init__.py +++ b/src/mlia/backend/tosa_checker/__init__.py @@ -1,3 +1,17 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """TOSA checker backend module.""" +from mlia.backend.config import BackendConfiguration +from mlia.backend.config import BackendType +from mlia.backend.config import System +from mlia.backend.registry import registry +from mlia.core.common import AdviceCategory + +registry.register( + "TOSA-Checker", + BackendConfiguration( + supported_advice=[AdviceCategory.OPERATORS], + supported_systems=[System.LINUX_AMD64], + backend_type=BackendType.WHEEL, + ), +) diff --git a/src/mlia/backend/vela/__init__.py b/src/mlia/backend/vela/__init__.py index 6ea0c21..38a623e 100644 --- a/src/mlia/backend/vela/__init__.py +++ b/src/mlia/backend/vela/__init__.py @@ -1,3 +1,26 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Vela backend module.""" +from mlia.backend.config import BackendConfiguration +from mlia.backend.config import BackendType +from mlia.backend.config import System +from mlia.backend.registry import registry +from mlia.core.common import AdviceCategory + +registry.register( + "Vela", + BackendConfiguration( + supported_advice=[ + AdviceCategory.OPERATORS, + AdviceCategory.PERFORMANCE, + AdviceCategory.OPTIMIZATION, + ], + supported_systems=[ + System.LINUX_AMD64, + System.LINUX_AARCH64, + System.WINDOWS_AMD64, + System.WINDOWS_AARCH64, + ], + backend_type=BackendType.BUILTIN, + ), +) diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py index 98fdb63..ac60308 100644 --- a/src/mlia/cli/main.py +++ b/src/mlia/cli/main.py @@ -12,6 +12,7 @@ from pathlib import Path from mlia import __version__ from mlia.backend.errors import BackendUnavailableError +from mlia.backend.registry import registry as backend_registry from mlia.cli.commands import all_tests from mlia.cli.commands import backend_install from mlia.cli.commands import backend_list @@ -36,6 +37,7 @@ from mlia.cli.options import add_tflite_model_options from mlia.core.context import ExecutionContext from mlia.core.errors import ConfigurationError from mlia.core.errors import InternalError +from mlia.target.registry import registry as target_registry logger = logging.getLogger(__name__) @@ -46,11 +48,10 @@ ML Inference Advisor {__version__} Help the design and optimization of neural network models for efficient inference on a target CPU and NPU Supported targets: +{target_registry} - - Cortex-A <op compatibility> - - Ethos-U55 <op compatibility, perf estimation, model opt> - - Ethos-U65 <op compatibility, perf estimation, model opt> - - TOSA <op compatibility> +Supported backends: +{backend_registry} """.strip() diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py index 5eab9aa..8ea4250 100644 --- a/src/mlia/cli/options.py +++ b/src/mlia/cli/options.py @@ -38,7 +38,7 @@ def add_target_options( help="Target profile that will set the target options " "such as target, mac value, memory mode, etc. " f"For the values associated with each target profile " - f" please refer to the documenation {default_help}.", + f" please refer to the documentation {default_help}.", ) diff --git a/src/mlia/target/__init__.py b/src/mlia/target/__init__.py index 2370221..a9979c6 100644 --- a/src/mlia/target/__init__.py +++ b/src/mlia/target/__init__.py @@ -1,3 +1,9 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Target module.""" +# Make sure all targets are registered with the registry by importing the +# sub-modules +# flake8: noqa +from mlia.target import cortex_a +from mlia.target import ethos_u +from mlia.target import tosa diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py index 7ab6b43..f257784 100644 --- a/src/mlia/target/config.py +++ b/src/mlia/target/config.py @@ -1,6 +1,12 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """IP configuration module.""" +from __future__ import annotations + +from dataclasses import dataclass + +from mlia.backend.registry import registry as backend_registry +from mlia.core.common import AdviceCategory class IPConfiguration: # pylint: disable=too-few-public-methods @@ -9,3 +15,33 @@ class IPConfiguration: # pylint: disable=too-few-public-methods def __init__(self, target: str) -> None: """Init IP configuration instance.""" self.target = target + + +@dataclass +class TargetInfo: + """Collect information about supported targets.""" + + supported_backends: list[str] + + def __str__(self) -> str: + """List supported backends.""" + return ", ".join(sorted(self.supported_backends)) + + def is_supported( + self, advice: AdviceCategory | None = None, check_system: bool = False + ) -> bool: + """Check if any of the supported backends support this kind of advice.""" + return any( + backend_registry.items[name].is_supported(advice, check_system) + for name in self.supported_backends + ) + + def filter_supported_backends( + self, advice: AdviceCategory | None = None, check_system: bool = False + ) -> list[str]: + """Get the list of supported backends filtered by the given arguments.""" + return [ + name + for name in self.supported_backends + if backend_registry.items[name].is_supported(advice, check_system) + ] diff --git a/src/mlia/target/cortex_a/__init__.py b/src/mlia/target/cortex_a/__init__.py index fe01835..9b0e611 100644 --- a/src/mlia/target/cortex_a/__init__.py +++ b/src/mlia/target/cortex_a/__init__.py @@ -1,3 +1,7 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Cortex-A target module.""" +from mlia.target.registry import registry +from mlia.target.registry import TargetInfo + +registry.register("Cortex-A", TargetInfo(["ArmNNTFLiteDelegate"])) diff --git a/src/mlia/target/cortex_a/operators.py b/src/mlia/target/cortex_a/operators.py index 91f1886..ae611e5 100644 --- a/src/mlia/target/cortex_a/operators.py +++ b/src/mlia/target/cortex_a/operators.py @@ -9,12 +9,12 @@ from pathlib import Path from typing import Any from typing import ClassVar +from mlia.backend.armnn_tflite_delegate.compat import ( + ARMNN_TFLITE_DELEGATE as TFLITE_DELEGATE_COMPAT, +) from mlia.nn.tensorflow.tflite_graph import Op from mlia.nn.tensorflow.tflite_graph import parse_subgraphs from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION -from mlia.target.cortex_a.operator_compatibility import ( - ARMNN_TFLITE_DELEGATE as TFLITE_DELEGATE_COMPAT, -) @dataclass diff --git a/src/mlia/target/ethos_u/__init__.py b/src/mlia/target/ethos_u/__init__.py index 503919d..3c92ae5 100644 --- a/src/mlia/target/ethos_u/__init__.py +++ b/src/mlia/target/ethos_u/__init__.py @@ -1,3 +1,8 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Ethos-U target module.""" +from mlia.target.registry import registry +from mlia.target.registry import TargetInfo + +registry.register("Ethos-U55", TargetInfo(["Vela", "Corstone-300", "Corstone-310"])) +registry.register("Ethos-U65", TargetInfo(["Vela", "Corstone-300", "Corstone-310"])) diff --git a/src/mlia/target/registry.py b/src/mlia/target/registry.py new file mode 100644 index 0000000..6b33084 --- /dev/null +++ b/src/mlia/target/registry.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Target module.""" +from __future__ import annotations + +from mlia.backend.registry import registry as backend_registry +from mlia.core.common import AdviceCategory +from mlia.target.config import TargetInfo +from mlia.utils.registry import Registry + +# All supported targets are required to be registered here. +registry = Registry[TargetInfo]() + + +def supported_advice(target: str) -> list[AdviceCategory]: + """Get a list of supported advice for the given target.""" + advice: set[AdviceCategory] = set() + for supported_backend in registry.items[target].supported_backends: + advice.update(backend_registry.items[supported_backend].supported_advice) + return list(advice) + + +def supported_backends(target: str) -> list[str]: + """Get a list of backends supported by the given target.""" + return registry.items[target].filter_supported_backends(check_system=False) + + +def supported_targets(advice: AdviceCategory) -> list[str]: + """Get a list of all targets supporting the given advice category.""" + return [ + name + for name, info in registry.items.items() + if info.is_supported(advice, check_system=False) + ] diff --git a/src/mlia/target/tosa/__init__.py b/src/mlia/target/tosa/__init__.py index 762c831..33c9cf2 100644 --- a/src/mlia/target/tosa/__init__.py +++ b/src/mlia/target/tosa/__init__.py @@ -1,3 +1,7 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """TOSA target module.""" +from mlia.target.registry import registry +from mlia.target.registry import TargetInfo + +registry.register("TOSA", TargetInfo(["TOSA-Checker"])) diff --git a/src/mlia/utils/registry.py b/src/mlia/utils/registry.py new file mode 100644 index 0000000..9b25a81 --- /dev/null +++ b/src/mlia/utils/registry.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Generic registry class.""" +from __future__ import annotations + +from typing import Generic +from typing import TypeVar + +T = TypeVar("T") + + +class Registry(Generic[T]): + """Generic registry for name-config pairs.""" + + def __init__(self) -> None: + """Create an empty registry.""" + self.items: dict[str, T] = {} + + def __str__(self) -> str: + """List all registered items.""" + return "\n".join( + f"- {name}: {item}" + for name, item in sorted(self.items.items(), key=lambda v: v[0]) + ) + + def register(self, name: str, item: T) -> bool: + """Register an item: returns `False` if already registered.""" + if name in self.items: + return False # already registered + self.items[name] = item + return True |