diff options
author | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2022-12-14 11:20:11 +0000 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-01-04 10:11:33 +0000 |
commit | dcd0bd31985c27e1d07333351b26cf8ad12ad1fd (patch) | |
tree | a3388ff5f91e7cdc7ec41271a1a76cdbfae38ece | |
parent | 4b4cf29cb1e7d917ae001e258ff01f7846c34778 (diff) | |
download | mlia-dcd0bd31985c27e1d07333351b26cf8ad12ad1fd.tar.gz |
MLIA-589 Create an API to get target information
Change-Id: Ieeaa9188ea1e29e2ccaad7475d457bce71e3140d
26 files changed, 588 insertions, 14 deletions
diff --git a/src/mlia/backend/__init__.py b/src/mlia/backend/__init__.py index 745aa1b..2d1da70 100644 --- a/src/mlia/backend/__init__.py +++ b/src/mlia/backend/__init__.py @@ -1,3 +1,10 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 -"""Backends module.""" +"""Backend module.""" +# Make sure all targets are registered with the registry by importing the +# sub-modules +# flake8: noqa +from mlia.backend import armnn_tflite_delegate +from mlia.backend import corstone +from mlia.backend import tosa_checker +from mlia.backend import vela diff --git a/src/mlia/backend/armnn_tflite_delegate/__init__.py b/src/mlia/backend/armnn_tflite_delegate/__init__.py new file mode 100644 index 0000000..6d5af42 --- /dev/null +++ b/src/mlia/backend/armnn_tflite_delegate/__init__.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Arm NN TensorFlow Lite delegate backend module.""" +from mlia.backend.config import BackendConfiguration +from mlia.backend.config import BackendType +from mlia.backend.registry import registry +from mlia.core.common import AdviceCategory + +registry.register( + "ArmNNTFLiteDelegate", + BackendConfiguration( + supported_advice=[AdviceCategory.OPERATORS], + supported_systems=None, + backend_type=BackendType.BUILTIN, + ), +) diff --git a/src/mlia/target/cortex_a/operator_compatibility.py b/src/mlia/backend/armnn_tflite_delegate/compat.py index c474e75..c474e75 100644 --- a/src/mlia/target/cortex_a/operator_compatibility.py +++ b/src/mlia/backend/armnn_tflite_delegate/compat.py diff --git a/src/mlia/backend/config.py b/src/mlia/backend/config.py new file mode 100644 index 0000000..8d14b28 --- /dev/null +++ b/src/mlia/backend/config.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Backend config module.""" +from __future__ import annotations + +import platform +from enum import auto +from enum import Enum +from typing import cast + +from mlia.core.common import AdviceCategory + + +class System(Enum): + """Enum of system configurations (e.g. OS and architecture).""" + + LINUX_AMD64 = ("Linux", "x86_64") + LINUX_AARCH64 = ("Linux", "aarch64") + WINDOWS_AMD64 = ("Windows", "AMD64") + WINDOWS_AARCH64 = ("Windows", "ARM64") + DARWIN_AARCH64 = ("Darwin", "arm64") + CURRENT = (platform.system(), platform.machine()) + + def __init__( + self, system: str = platform.system(), machine: str = platform.machine() + ) -> None: + """Set the system parameters (defaults to the current system).""" + self.system = system.lower() + self.machine = machine.lower() + + def __eq__(self, other: object) -> bool: + """ + Compare two System instances for equality. + + Raises a TypeError if the input is not a System. + """ + if isinstance(other, self.__class__): + return self.system == other.system and self.machine == other.machine + return False + + def is_compatible(self) -> bool: + """Check if this system is compatible with the current system.""" + return self == self.CURRENT + + +class BackendType(Enum): + """Define the type of the backend (builtin, wheel file etc).""" + + BUILTIN = auto() + WHEEL = auto() + CUSTOM = auto() + + +class BackendConfiguration: + """Base class for backend configurations.""" + + def __init__( + self, + supported_advice: list[AdviceCategory], + supported_systems: list[System] | None, + backend_type: BackendType, + ) -> None: + """Set up basic information about the backend.""" + self.supported_advice = supported_advice + self.supported_systems = supported_systems + self.type = backend_type + + def __str__(self) -> str: + """List supported advice.""" + return ", ".join(cast(str, adv.name).lower() for adv in self.supported_advice) + + def is_supported( + self, advice: AdviceCategory | None = None, check_system: bool = False + ) -> bool: + """Check backend supports the current system and advice.""" + is_system_supported = ( + not self.supported_systems + or not check_system + or any(sys.is_compatible() for sys in self.supported_systems) + ) + is_advice_supported = advice is None or advice in self.supported_advice + return is_system_supported and is_advice_supported diff --git a/src/mlia/backend/corstone/__init__.py b/src/mlia/backend/corstone/__init__.py index a1eac14..f89da63 100644 --- a/src/mlia/backend/corstone/__init__.py +++ b/src/mlia/backend/corstone/__init__.py @@ -1,3 +1,25 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Corstone backend module.""" +from mlia.backend.config import BackendConfiguration +from mlia.backend.config import BackendType +from mlia.backend.config import System +from mlia.backend.registry import registry +from mlia.core.common import AdviceCategory + +registry.register( + "Corstone-300", + BackendConfiguration( + supported_advice=[AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION], + supported_systems=[System.LINUX_AMD64], + backend_type=BackendType.CUSTOM, + ), +) +registry.register( + "Corstone-310", + BackendConfiguration( + supported_advice=[AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION], + supported_systems=[System.LINUX_AMD64], + backend_type=BackendType.CUSTOM, + ), +) diff --git a/src/mlia/backend/registry.py b/src/mlia/backend/registry.py new file mode 100644 index 0000000..6a0da74 --- /dev/null +++ b/src/mlia/backend/registry.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Backend module.""" +from mlia.backend.config import BackendConfiguration +from mlia.utils.registry import Registry + +# All supported targets are required to be registered here. +registry = Registry[BackendConfiguration]() diff --git a/src/mlia/backend/tosa_checker/__init__.py b/src/mlia/backend/tosa_checker/__init__.py index cec210d..19fc8be 100644 --- a/src/mlia/backend/tosa_checker/__init__.py +++ b/src/mlia/backend/tosa_checker/__init__.py @@ -1,3 +1,17 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """TOSA checker backend module.""" +from mlia.backend.config import BackendConfiguration +from mlia.backend.config import BackendType +from mlia.backend.config import System +from mlia.backend.registry import registry +from mlia.core.common import AdviceCategory + +registry.register( + "TOSA-Checker", + BackendConfiguration( + supported_advice=[AdviceCategory.OPERATORS], + supported_systems=[System.LINUX_AMD64], + backend_type=BackendType.WHEEL, + ), +) diff --git a/src/mlia/backend/vela/__init__.py b/src/mlia/backend/vela/__init__.py index 6ea0c21..38a623e 100644 --- a/src/mlia/backend/vela/__init__.py +++ b/src/mlia/backend/vela/__init__.py @@ -1,3 +1,26 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Vela backend module.""" +from mlia.backend.config import BackendConfiguration +from mlia.backend.config import BackendType +from mlia.backend.config import System +from mlia.backend.registry import registry +from mlia.core.common import AdviceCategory + +registry.register( + "Vela", + BackendConfiguration( + supported_advice=[ + AdviceCategory.OPERATORS, + AdviceCategory.PERFORMANCE, + AdviceCategory.OPTIMIZATION, + ], + supported_systems=[ + System.LINUX_AMD64, + System.LINUX_AARCH64, + System.WINDOWS_AMD64, + System.WINDOWS_AARCH64, + ], + backend_type=BackendType.BUILTIN, + ), +) diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py index 98fdb63..ac60308 100644 --- a/src/mlia/cli/main.py +++ b/src/mlia/cli/main.py @@ -12,6 +12,7 @@ from pathlib import Path from mlia import __version__ from mlia.backend.errors import BackendUnavailableError +from mlia.backend.registry import registry as backend_registry from mlia.cli.commands import all_tests from mlia.cli.commands import backend_install from mlia.cli.commands import backend_list @@ -36,6 +37,7 @@ from mlia.cli.options import add_tflite_model_options from mlia.core.context import ExecutionContext from mlia.core.errors import ConfigurationError from mlia.core.errors import InternalError +from mlia.target.registry import registry as target_registry logger = logging.getLogger(__name__) @@ -46,11 +48,10 @@ ML Inference Advisor {__version__} Help the design and optimization of neural network models for efficient inference on a target CPU and NPU Supported targets: +{target_registry} - - Cortex-A <op compatibility> - - Ethos-U55 <op compatibility, perf estimation, model opt> - - Ethos-U65 <op compatibility, perf estimation, model opt> - - TOSA <op compatibility> +Supported backends: +{backend_registry} """.strip() diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py index 5eab9aa..8ea4250 100644 --- a/src/mlia/cli/options.py +++ b/src/mlia/cli/options.py @@ -38,7 +38,7 @@ def add_target_options( help="Target profile that will set the target options " "such as target, mac value, memory mode, etc. " f"For the values associated with each target profile " - f" please refer to the documenation {default_help}.", + f" please refer to the documentation {default_help}.", ) diff --git a/src/mlia/target/__init__.py b/src/mlia/target/__init__.py index 2370221..a9979c6 100644 --- a/src/mlia/target/__init__.py +++ b/src/mlia/target/__init__.py @@ -1,3 +1,9 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Target module.""" +# Make sure all targets are registered with the registry by importing the +# sub-modules +# flake8: noqa +from mlia.target import cortex_a +from mlia.target import ethos_u +from mlia.target import tosa diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py index 7ab6b43..f257784 100644 --- a/src/mlia/target/config.py +++ b/src/mlia/target/config.py @@ -1,6 +1,12 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """IP configuration module.""" +from __future__ import annotations + +from dataclasses import dataclass + +from mlia.backend.registry import registry as backend_registry +from mlia.core.common import AdviceCategory class IPConfiguration: # pylint: disable=too-few-public-methods @@ -9,3 +15,33 @@ class IPConfiguration: # pylint: disable=too-few-public-methods def __init__(self, target: str) -> None: """Init IP configuration instance.""" self.target = target + + +@dataclass +class TargetInfo: + """Collect information about supported targets.""" + + supported_backends: list[str] + + def __str__(self) -> str: + """List supported backends.""" + return ", ".join(sorted(self.supported_backends)) + + def is_supported( + self, advice: AdviceCategory | None = None, check_system: bool = False + ) -> bool: + """Check if any of the supported backends support this kind of advice.""" + return any( + backend_registry.items[name].is_supported(advice, check_system) + for name in self.supported_backends + ) + + def filter_supported_backends( + self, advice: AdviceCategory | None = None, check_system: bool = False + ) -> list[str]: + """Get the list of supported backends filtered by the given arguments.""" + return [ + name + for name in self.supported_backends + if backend_registry.items[name].is_supported(advice, check_system) + ] diff --git a/src/mlia/target/cortex_a/__init__.py b/src/mlia/target/cortex_a/__init__.py index fe01835..9b0e611 100644 --- a/src/mlia/target/cortex_a/__init__.py +++ b/src/mlia/target/cortex_a/__init__.py @@ -1,3 +1,7 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Cortex-A target module.""" +from mlia.target.registry import registry +from mlia.target.registry import TargetInfo + +registry.register("Cortex-A", TargetInfo(["ArmNNTFLiteDelegate"])) diff --git a/src/mlia/target/cortex_a/operators.py b/src/mlia/target/cortex_a/operators.py index 91f1886..ae611e5 100644 --- a/src/mlia/target/cortex_a/operators.py +++ b/src/mlia/target/cortex_a/operators.py @@ -9,12 +9,12 @@ from pathlib import Path from typing import Any from typing import ClassVar +from mlia.backend.armnn_tflite_delegate.compat import ( + ARMNN_TFLITE_DELEGATE as TFLITE_DELEGATE_COMPAT, +) from mlia.nn.tensorflow.tflite_graph import Op from mlia.nn.tensorflow.tflite_graph import parse_subgraphs from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION -from mlia.target.cortex_a.operator_compatibility import ( - ARMNN_TFLITE_DELEGATE as TFLITE_DELEGATE_COMPAT, -) @dataclass diff --git a/src/mlia/target/ethos_u/__init__.py b/src/mlia/target/ethos_u/__init__.py index 503919d..3c92ae5 100644 --- a/src/mlia/target/ethos_u/__init__.py +++ b/src/mlia/target/ethos_u/__init__.py @@ -1,3 +1,8 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Ethos-U target module.""" +from mlia.target.registry import registry +from mlia.target.registry import TargetInfo + +registry.register("Ethos-U55", TargetInfo(["Vela", "Corstone-300", "Corstone-310"])) +registry.register("Ethos-U65", TargetInfo(["Vela", "Corstone-300", "Corstone-310"])) diff --git a/src/mlia/target/registry.py b/src/mlia/target/registry.py new file mode 100644 index 0000000..6b33084 --- /dev/null +++ b/src/mlia/target/registry.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Target module.""" +from __future__ import annotations + +from mlia.backend.registry import registry as backend_registry +from mlia.core.common import AdviceCategory +from mlia.target.config import TargetInfo +from mlia.utils.registry import Registry + +# All supported targets are required to be registered here. +registry = Registry[TargetInfo]() + + +def supported_advice(target: str) -> list[AdviceCategory]: + """Get a list of supported advice for the given target.""" + advice: set[AdviceCategory] = set() + for supported_backend in registry.items[target].supported_backends: + advice.update(backend_registry.items[supported_backend].supported_advice) + return list(advice) + + +def supported_backends(target: str) -> list[str]: + """Get a list of backends supported by the given target.""" + return registry.items[target].filter_supported_backends(check_system=False) + + +def supported_targets(advice: AdviceCategory) -> list[str]: + """Get a list of all targets supporting the given advice category.""" + return [ + name + for name, info in registry.items.items() + if info.is_supported(advice, check_system=False) + ] diff --git a/src/mlia/target/tosa/__init__.py b/src/mlia/target/tosa/__init__.py index 762c831..33c9cf2 100644 --- a/src/mlia/target/tosa/__init__.py +++ b/src/mlia/target/tosa/__init__.py @@ -1,3 +1,7 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """TOSA target module.""" +from mlia.target.registry import registry +from mlia.target.registry import TargetInfo + +registry.register("TOSA", TargetInfo(["TOSA-Checker"])) diff --git a/src/mlia/utils/registry.py b/src/mlia/utils/registry.py new file mode 100644 index 0000000..9b25a81 --- /dev/null +++ b/src/mlia/utils/registry.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Generic registry class.""" +from __future__ import annotations + +from typing import Generic +from typing import TypeVar + +T = TypeVar("T") + + +class Registry(Generic[T]): + """Generic registry for name-config pairs.""" + + def __init__(self) -> None: + """Create an empty registry.""" + self.items: dict[str, T] = {} + + def __str__(self) -> str: + """List all registered items.""" + return "\n".join( + f"- {name}: {item}" + for name, item in sorted(self.items.items(), key=lambda v: v[0]) + ) + + def register(self, name: str, item: T) -> bool: + """Register an item: returns `False` if already registered.""" + if name in self.items: + return False # already registered + self.items[name] = item + return True diff --git a/tests/test_backend_config.py b/tests/test_backend_config.py new file mode 100644 index 0000000..bd50945 --- /dev/null +++ b/tests/test_backend_config.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the backend config module.""" +from mlia.backend.config import BackendConfiguration +from mlia.backend.config import BackendType +from mlia.backend.config import System +from mlia.core.common import AdviceCategory + +UNSUPPORTED_SYSTEM = next(sys for sys in System if not sys.is_compatible()) + + +def test_system() -> None: + """Test the class 'System'.""" + assert System.CURRENT.is_compatible() + assert not UNSUPPORTED_SYSTEM.is_compatible() + assert UNSUPPORTED_SYSTEM != System.CURRENT + assert System.LINUX_AMD64 != System.LINUX_AARCH64 + + +def test_backend_config() -> None: + """Test the class 'BackendConfiguration'.""" + cfg = BackendConfiguration( + [AdviceCategory.OPERATORS], [System.CURRENT], BackendType.CUSTOM + ) + assert cfg.supported_advice == [AdviceCategory.OPERATORS] + assert cfg.supported_systems == [System.CURRENT] + assert cfg.type == BackendType.CUSTOM + assert str(cfg) + assert cfg.is_supported() + assert cfg.is_supported(advice=AdviceCategory.OPERATORS) + assert not cfg.is_supported(advice=AdviceCategory.PERFORMANCE) + assert cfg.is_supported(check_system=True) + assert cfg.is_supported(check_system=False) + cfg.supported_systems = None + assert cfg.is_supported(check_system=True) + assert cfg.is_supported(check_system=False) + cfg.supported_systems = [UNSUPPORTED_SYSTEM] + assert not cfg.is_supported(check_system=True) + assert cfg.is_supported(check_system=False) + assert not cfg.is_supported(advice=AdviceCategory.OPERATORS, check_system=True) + assert cfg.is_supported(advice=AdviceCategory.OPERATORS, check_system=False) + assert not cfg.is_supported(advice=AdviceCategory.PERFORMANCE, check_system=False) diff --git a/tests/test_backend_registry.py b/tests/test_backend_registry.py new file mode 100644 index 0000000..31a20a0 --- /dev/null +++ b/tests/test_backend_registry.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the backend registry module.""" +from __future__ import annotations + +from functools import partial + +import pytest + +from mlia.backend.config import BackendType +from mlia.backend.config import System +from mlia.backend.registry import registry +from mlia.core.common import AdviceCategory + + +@pytest.mark.parametrize( + ("backend", "advices", "systems", "type_"), + ( + ( + "ArmNNTFLiteDelegate", + [AdviceCategory.OPERATORS], + None, + BackendType.BUILTIN, + ), + ( + "Corstone-300", + [AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION], + [System.LINUX_AMD64], + BackendType.CUSTOM, + ), + ( + "Corstone-310", + [AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION], + [System.LINUX_AMD64], + BackendType.CUSTOM, + ), + ( + "TOSA-Checker", + [AdviceCategory.OPERATORS], + [System.LINUX_AMD64], + BackendType.WHEEL, + ), + ( + "Vela", + [ + AdviceCategory.OPERATORS, + AdviceCategory.PERFORMANCE, + AdviceCategory.OPTIMIZATION, + ], + [ + System.LINUX_AMD64, + System.LINUX_AARCH64, + System.WINDOWS_AMD64, + System.WINDOWS_AARCH64, + ], + BackendType.BUILTIN, + ), + ), +) +def test_backend_registry( + backend: str, + advices: list[AdviceCategory], + systems: list[System] | None, + type_: BackendType, +) -> None: + """Test the backend registry.""" + sorted_by_name = partial(sorted, key=lambda x: x.name) + + assert backend in registry.items + cfg = registry.items[backend] + assert sorted_by_name(advices) == sorted_by_name( + cfg.supported_advice + ), f"Advices differs: {advices} != {cfg.supported_advice}" + if systems is None: + assert cfg.supported_systems is None + else: + assert cfg.supported_systems is not None + assert sorted_by_name(systems) == sorted_by_name( + cfg.supported_systems + ), f"Supported systems differs: {advices} != {cfg.supported_advice}" + assert cfg.type == type_ diff --git a/tests/test_target_config.py b/tests/test_target_config.py new file mode 100644 index 0000000..66ebed6 --- /dev/null +++ b/tests/test_target_config.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the backend config module.""" +from __future__ import annotations + +import pytest + +from mlia.backend.config import BackendConfiguration +from mlia.backend.config import BackendType +from mlia.backend.config import System +from mlia.core.common import AdviceCategory +from mlia.target.config import IPConfiguration +from mlia.target.config import TargetInfo +from mlia.utils.registry import Registry + + +def test_ip_config() -> None: + """Test the class 'IPConfiguration'.""" + cfg = IPConfiguration("AnyTarget") + assert cfg.target == "AnyTarget" + + +@pytest.mark.parametrize( + ("advice", "check_system", "supported"), + ( + (None, False, True), + (None, True, True), + (AdviceCategory.OPERATORS, True, True), + (AdviceCategory.OPTIMIZATION, True, False), + ), +) +def test_target_info( + monkeypatch: pytest.MonkeyPatch, + advice: AdviceCategory | None, + check_system: bool, + supported: bool, +) -> None: + """Test the class 'TargetInfo'.""" + info = TargetInfo(["backend"]) + + backend_registry = Registry[BackendConfiguration]() + backend_registry.register( + "backend", + BackendConfiguration( + [AdviceCategory.OPERATORS], + [System.CURRENT], + BackendType.BUILTIN, + ), + ) + monkeypatch.setattr("mlia.target.config.backend_registry", backend_registry) + + assert info.is_supported(advice, check_system) == supported + assert bool(info.filter_supported_backends(advice, check_system)) == supported diff --git a/tests/test_target_cortex_a_advice_generation.py b/tests/test_target_cortex_a_advice_generation.py index 02a2b14..6effe4c 100644 --- a/tests/test_target_cortex_a_advice_generation.py +++ b/tests/test_target_cortex_a_advice_generation.py @@ -5,6 +5,9 @@ from __future__ import annotations import pytest +from mlia.backend.armnn_tflite_delegate.compat import ( + ARMNN_TFLITE_DELEGATE, +) from mlia.core.advice_generation import Advice from mlia.core.common import AdviceCategory from mlia.core.common import DataItem @@ -16,7 +19,6 @@ from mlia.target.cortex_a.data_analysis import ModelIsCortexACompatible from mlia.target.cortex_a.data_analysis import ModelIsNotCortexACompatible from mlia.target.cortex_a.data_analysis import ModelIsNotTFLiteCompatible from mlia.target.cortex_a.data_analysis import TFLiteCompatibilityCheckFailed -from mlia.target.cortex_a.operator_compatibility import ARMNN_TFLITE_DELEGATE BACKEND_INFO = ( f"{ARMNN_TFLITE_DELEGATE['metadata']['backend']} " diff --git a/tests/test_target_cortex_a_data_analysis.py b/tests/test_target_cortex_a_data_analysis.py index b223b01..e9fc8bc 100644 --- a/tests/test_target_cortex_a_data_analysis.py +++ b/tests/test_target_cortex_a_data_analysis.py @@ -5,6 +5,9 @@ from __future__ import annotations import pytest +from mlia.backend.armnn_tflite_delegate.compat import ( + ARMNN_TFLITE_DELEGATE, +) from mlia.core.common import DataItem from mlia.core.data_analysis import Fact from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo @@ -18,7 +21,6 @@ from mlia.target.cortex_a.data_analysis import ModelIsCortexACompatible from mlia.target.cortex_a.data_analysis import ModelIsNotCortexACompatible from mlia.target.cortex_a.data_analysis import ModelIsNotTFLiteCompatible from mlia.target.cortex_a.data_analysis import TFLiteCompatibilityCheckFailed -from mlia.target.cortex_a.operator_compatibility import ARMNN_TFLITE_DELEGATE from mlia.target.cortex_a.operators import CortexACompatibilityInfo from mlia.target.cortex_a.operators import Operator diff --git a/tests/test_target_cortex_a_operators.py b/tests/test_target_cortex_a_operators.py index 94eb890..262ebc8 100644 --- a/tests/test_target_cortex_a_operators.py +++ b/tests/test_target_cortex_a_operators.py @@ -6,18 +6,18 @@ from pathlib import Path import pytest import tensorflow as tf +from mlia.backend.armnn_tflite_delegate import compat from mlia.nn.tensorflow.tflite_graph import TFL_OP from mlia.nn.tensorflow.utils import convert_to_tflite -from mlia.target.cortex_a import operator_compatibility as op_compat from mlia.target.cortex_a.operators import CortexACompatibilityInfo from mlia.target.cortex_a.operators import get_cortex_a_compatibility_info from mlia.target.cortex_a.operators import Operator -def test_op_compat_data() -> None: +def test_compat_data() -> None: """Make sure all data contains the necessary items.""" builtin_tfl_ops = {op.name for op in TFL_OP} - for data in [op_compat.ARMNN_TFLITE_DELEGATE]: + for data in [compat.ARMNN_TFLITE_DELEGATE]: assert "metadata" in data assert "backend" in data["metadata"] assert "version" in data["metadata"] diff --git a/tests/test_target_registry.py b/tests/test_target_registry.py new file mode 100644 index 0000000..e6ee296 --- /dev/null +++ b/tests/test_target_registry.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the target registry module.""" +from __future__ import annotations + +import pytest + +from mlia.core.common import AdviceCategory +from mlia.target.registry import registry +from mlia.target.registry import supported_advice +from mlia.target.registry import supported_backends +from mlia.target.registry import supported_targets + + +@pytest.mark.parametrize( + "expected_target", ("Cortex-A", "Ethos-U55", "Ethos-U65", "TOSA") +) +def test_target_registry(expected_target: str) -> None: + """Test the target registry.""" + assert expected_target in registry.items, ( + f"Expected target '{expected_target}' not contained in registered " + f"targets '{registry.items.keys()}'." + ) + + +@pytest.mark.parametrize( + ("target_name", "expected_advices"), + ( + ("Cortex-A", [AdviceCategory.OPERATORS]), + ( + "Ethos-U55", + [ + AdviceCategory.OPERATORS, + AdviceCategory.OPTIMIZATION, + AdviceCategory.PERFORMANCE, + ], + ), + ( + "Ethos-U65", + [ + AdviceCategory.OPERATORS, + AdviceCategory.OPTIMIZATION, + AdviceCategory.PERFORMANCE, + ], + ), + ("TOSA", [AdviceCategory.OPERATORS]), + ), +) +def test_supported_advice( + target_name: str, expected_advices: list[AdviceCategory] +) -> None: + """Test function supported_advice().""" + supported = supported_advice(target_name) + assert all(advice in expected_advices for advice in supported) + assert all(advice in supported for advice in expected_advices) + + +@pytest.mark.parametrize( + ("target_name", "expected_backends"), + ( + ("Cortex-A", ["ArmNNTFLiteDelegate"]), + ("Ethos-U55", ["Corstone-300", "Corstone-310", "Vela"]), + ("Ethos-U65", ["Corstone-300", "Corstone-310", "Vela"]), + ("TOSA", ["TOSA-Checker"]), + ), +) +def test_supported_backends(target_name: str, expected_backends: list[str]) -> None: + """Test function supported_backends().""" + assert sorted(expected_backends) == sorted(supported_backends(target_name)) + + +@pytest.mark.parametrize( + ("advice", "expected_targets"), + ( + (AdviceCategory.OPERATORS, ["Cortex-A", "Ethos-U55", "Ethos-U65", "TOSA"]), + (AdviceCategory.OPTIMIZATION, ["Ethos-U55", "Ethos-U65"]), + (AdviceCategory.PERFORMANCE, ["Ethos-U55", "Ethos-U65"]), + ), +) +def test_supported_targets(advice: AdviceCategory, expected_targets: list[str]) -> None: + """Test function supported_targets().""" + assert sorted(expected_targets) == sorted(supported_targets(advice)) diff --git a/tests/test_utils_registry.py b/tests/test_utils_registry.py new file mode 100644 index 0000000..95721fc --- /dev/null +++ b/tests/test_utils_registry.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Test the Registry base class.""" +from mlia.utils.registry import Registry + + +def test_registry() -> None: + """Test Registry class.""" + reg = Registry[str]() + assert not str(reg) + assert reg.register("name", "value") + assert not reg.register("name", "value") + assert "name" in reg.items + assert reg.items["name"] == "value" + assert str(reg) + assert reg.register("other_name", "value_2") + assert len(reg.items) == 2 + assert "other_name" in reg.items + assert reg.items["other_name"] == "value_2" |