From dcd0bd31985c27e1d07333351b26cf8ad12ad1fd Mon Sep 17 00:00:00 2001 From: Benjamin Klimczak Date: Wed, 14 Dec 2022 11:20:11 +0000 Subject: MLIA-589 Create an API to get target information Change-Id: Ieeaa9188ea1e29e2ccaad7475d457bce71e3140d --- tests/test_backend_config.py | 42 +++++++++++++ tests/test_backend_registry.py | 81 ++++++++++++++++++++++++ tests/test_target_config.py | 53 ++++++++++++++++ tests/test_target_cortex_a_advice_generation.py | 4 +- tests/test_target_cortex_a_data_analysis.py | 4 +- tests/test_target_cortex_a_operators.py | 6 +- tests/test_target_registry.py | 82 +++++++++++++++++++++++++ tests/test_utils_registry.py | 19 ++++++ 8 files changed, 286 insertions(+), 5 deletions(-) create mode 100644 tests/test_backend_config.py create mode 100644 tests/test_backend_registry.py create mode 100644 tests/test_target_config.py create mode 100644 tests/test_target_registry.py create mode 100644 tests/test_utils_registry.py (limited to 'tests') diff --git a/tests/test_backend_config.py b/tests/test_backend_config.py new file mode 100644 index 0000000..bd50945 --- /dev/null +++ b/tests/test_backend_config.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the backend config module.""" +from mlia.backend.config import BackendConfiguration +from mlia.backend.config import BackendType +from mlia.backend.config import System +from mlia.core.common import AdviceCategory + +UNSUPPORTED_SYSTEM = next(sys for sys in System if not sys.is_compatible()) + + +def test_system() -> None: + """Test the class 'System'.""" + assert System.CURRENT.is_compatible() + assert not UNSUPPORTED_SYSTEM.is_compatible() + assert UNSUPPORTED_SYSTEM != System.CURRENT + assert System.LINUX_AMD64 != System.LINUX_AARCH64 + + +def test_backend_config() -> None: + """Test the class 'BackendConfiguration'.""" + cfg = BackendConfiguration( + [AdviceCategory.OPERATORS], [System.CURRENT], BackendType.CUSTOM + ) + assert cfg.supported_advice == [AdviceCategory.OPERATORS] + assert cfg.supported_systems == [System.CURRENT] + assert cfg.type == BackendType.CUSTOM + assert str(cfg) + assert cfg.is_supported() + assert cfg.is_supported(advice=AdviceCategory.OPERATORS) + assert not cfg.is_supported(advice=AdviceCategory.PERFORMANCE) + assert cfg.is_supported(check_system=True) + assert cfg.is_supported(check_system=False) + cfg.supported_systems = None + assert cfg.is_supported(check_system=True) + assert cfg.is_supported(check_system=False) + cfg.supported_systems = [UNSUPPORTED_SYSTEM] + assert not cfg.is_supported(check_system=True) + assert cfg.is_supported(check_system=False) + assert not cfg.is_supported(advice=AdviceCategory.OPERATORS, check_system=True) + assert cfg.is_supported(advice=AdviceCategory.OPERATORS, check_system=False) + assert not cfg.is_supported(advice=AdviceCategory.PERFORMANCE, check_system=False) diff --git a/tests/test_backend_registry.py b/tests/test_backend_registry.py new file mode 100644 index 0000000..31a20a0 --- /dev/null +++ b/tests/test_backend_registry.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the backend registry module.""" +from __future__ import annotations + +from functools import partial + +import pytest + +from mlia.backend.config import BackendType +from mlia.backend.config import System +from mlia.backend.registry import registry +from mlia.core.common import AdviceCategory + + +@pytest.mark.parametrize( + ("backend", "advices", "systems", "type_"), + ( + ( + "ArmNNTFLiteDelegate", + [AdviceCategory.OPERATORS], + None, + BackendType.BUILTIN, + ), + ( + "Corstone-300", + [AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION], + [System.LINUX_AMD64], + BackendType.CUSTOM, + ), + ( + "Corstone-310", + [AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION], + [System.LINUX_AMD64], + BackendType.CUSTOM, + ), + ( + "TOSA-Checker", + [AdviceCategory.OPERATORS], + [System.LINUX_AMD64], + BackendType.WHEEL, + ), + ( + "Vela", + [ + AdviceCategory.OPERATORS, + AdviceCategory.PERFORMANCE, + AdviceCategory.OPTIMIZATION, + ], + [ + System.LINUX_AMD64, + System.LINUX_AARCH64, + System.WINDOWS_AMD64, + System.WINDOWS_AARCH64, + ], + BackendType.BUILTIN, + ), + ), +) +def test_backend_registry( + backend: str, + advices: list[AdviceCategory], + systems: list[System] | None, + type_: BackendType, +) -> None: + """Test the backend registry.""" + sorted_by_name = partial(sorted, key=lambda x: x.name) + + assert backend in registry.items + cfg = registry.items[backend] + assert sorted_by_name(advices) == sorted_by_name( + cfg.supported_advice + ), f"Advices differs: {advices} != {cfg.supported_advice}" + if systems is None: + assert cfg.supported_systems is None + else: + assert cfg.supported_systems is not None + assert sorted_by_name(systems) == sorted_by_name( + cfg.supported_systems + ), f"Supported systems differs: {advices} != {cfg.supported_advice}" + assert cfg.type == type_ diff --git a/tests/test_target_config.py b/tests/test_target_config.py new file mode 100644 index 0000000..66ebed6 --- /dev/null +++ b/tests/test_target_config.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the backend config module.""" +from __future__ import annotations + +import pytest + +from mlia.backend.config import BackendConfiguration +from mlia.backend.config import BackendType +from mlia.backend.config import System +from mlia.core.common import AdviceCategory +from mlia.target.config import IPConfiguration +from mlia.target.config import TargetInfo +from mlia.utils.registry import Registry + + +def test_ip_config() -> None: + """Test the class 'IPConfiguration'.""" + cfg = IPConfiguration("AnyTarget") + assert cfg.target == "AnyTarget" + + +@pytest.mark.parametrize( + ("advice", "check_system", "supported"), + ( + (None, False, True), + (None, True, True), + (AdviceCategory.OPERATORS, True, True), + (AdviceCategory.OPTIMIZATION, True, False), + ), +) +def test_target_info( + monkeypatch: pytest.MonkeyPatch, + advice: AdviceCategory | None, + check_system: bool, + supported: bool, +) -> None: + """Test the class 'TargetInfo'.""" + info = TargetInfo(["backend"]) + + backend_registry = Registry[BackendConfiguration]() + backend_registry.register( + "backend", + BackendConfiguration( + [AdviceCategory.OPERATORS], + [System.CURRENT], + BackendType.BUILTIN, + ), + ) + monkeypatch.setattr("mlia.target.config.backend_registry", backend_registry) + + assert info.is_supported(advice, check_system) == supported + assert bool(info.filter_supported_backends(advice, check_system)) == supported diff --git a/tests/test_target_cortex_a_advice_generation.py b/tests/test_target_cortex_a_advice_generation.py index 02a2b14..6effe4c 100644 --- a/tests/test_target_cortex_a_advice_generation.py +++ b/tests/test_target_cortex_a_advice_generation.py @@ -5,6 +5,9 @@ from __future__ import annotations import pytest +from mlia.backend.armnn_tflite_delegate.compat import ( + ARMNN_TFLITE_DELEGATE, +) from mlia.core.advice_generation import Advice from mlia.core.common import AdviceCategory from mlia.core.common import DataItem @@ -16,7 +19,6 @@ from mlia.target.cortex_a.data_analysis import ModelIsCortexACompatible from mlia.target.cortex_a.data_analysis import ModelIsNotCortexACompatible from mlia.target.cortex_a.data_analysis import ModelIsNotTFLiteCompatible from mlia.target.cortex_a.data_analysis import TFLiteCompatibilityCheckFailed -from mlia.target.cortex_a.operator_compatibility import ARMNN_TFLITE_DELEGATE BACKEND_INFO = ( f"{ARMNN_TFLITE_DELEGATE['metadata']['backend']} " diff --git a/tests/test_target_cortex_a_data_analysis.py b/tests/test_target_cortex_a_data_analysis.py index b223b01..e9fc8bc 100644 --- a/tests/test_target_cortex_a_data_analysis.py +++ b/tests/test_target_cortex_a_data_analysis.py @@ -5,6 +5,9 @@ from __future__ import annotations import pytest +from mlia.backend.armnn_tflite_delegate.compat import ( + ARMNN_TFLITE_DELEGATE, +) from mlia.core.common import DataItem from mlia.core.data_analysis import Fact from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo @@ -18,7 +21,6 @@ from mlia.target.cortex_a.data_analysis import ModelIsCortexACompatible from mlia.target.cortex_a.data_analysis import ModelIsNotCortexACompatible from mlia.target.cortex_a.data_analysis import ModelIsNotTFLiteCompatible from mlia.target.cortex_a.data_analysis import TFLiteCompatibilityCheckFailed -from mlia.target.cortex_a.operator_compatibility import ARMNN_TFLITE_DELEGATE from mlia.target.cortex_a.operators import CortexACompatibilityInfo from mlia.target.cortex_a.operators import Operator diff --git a/tests/test_target_cortex_a_operators.py b/tests/test_target_cortex_a_operators.py index 94eb890..262ebc8 100644 --- a/tests/test_target_cortex_a_operators.py +++ b/tests/test_target_cortex_a_operators.py @@ -6,18 +6,18 @@ from pathlib import Path import pytest import tensorflow as tf +from mlia.backend.armnn_tflite_delegate import compat from mlia.nn.tensorflow.tflite_graph import TFL_OP from mlia.nn.tensorflow.utils import convert_to_tflite -from mlia.target.cortex_a import operator_compatibility as op_compat from mlia.target.cortex_a.operators import CortexACompatibilityInfo from mlia.target.cortex_a.operators import get_cortex_a_compatibility_info from mlia.target.cortex_a.operators import Operator -def test_op_compat_data() -> None: +def test_compat_data() -> None: """Make sure all data contains the necessary items.""" builtin_tfl_ops = {op.name for op in TFL_OP} - for data in [op_compat.ARMNN_TFLITE_DELEGATE]: + for data in [compat.ARMNN_TFLITE_DELEGATE]: assert "metadata" in data assert "backend" in data["metadata"] assert "version" in data["metadata"] diff --git a/tests/test_target_registry.py b/tests/test_target_registry.py new file mode 100644 index 0000000..e6ee296 --- /dev/null +++ b/tests/test_target_registry.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the target registry module.""" +from __future__ import annotations + +import pytest + +from mlia.core.common import AdviceCategory +from mlia.target.registry import registry +from mlia.target.registry import supported_advice +from mlia.target.registry import supported_backends +from mlia.target.registry import supported_targets + + +@pytest.mark.parametrize( + "expected_target", ("Cortex-A", "Ethos-U55", "Ethos-U65", "TOSA") +) +def test_target_registry(expected_target: str) -> None: + """Test the target registry.""" + assert expected_target in registry.items, ( + f"Expected target '{expected_target}' not contained in registered " + f"targets '{registry.items.keys()}'." + ) + + +@pytest.mark.parametrize( + ("target_name", "expected_advices"), + ( + ("Cortex-A", [AdviceCategory.OPERATORS]), + ( + "Ethos-U55", + [ + AdviceCategory.OPERATORS, + AdviceCategory.OPTIMIZATION, + AdviceCategory.PERFORMANCE, + ], + ), + ( + "Ethos-U65", + [ + AdviceCategory.OPERATORS, + AdviceCategory.OPTIMIZATION, + AdviceCategory.PERFORMANCE, + ], + ), + ("TOSA", [AdviceCategory.OPERATORS]), + ), +) +def test_supported_advice( + target_name: str, expected_advices: list[AdviceCategory] +) -> None: + """Test function supported_advice().""" + supported = supported_advice(target_name) + assert all(advice in expected_advices for advice in supported) + assert all(advice in supported for advice in expected_advices) + + +@pytest.mark.parametrize( + ("target_name", "expected_backends"), + ( + ("Cortex-A", ["ArmNNTFLiteDelegate"]), + ("Ethos-U55", ["Corstone-300", "Corstone-310", "Vela"]), + ("Ethos-U65", ["Corstone-300", "Corstone-310", "Vela"]), + ("TOSA", ["TOSA-Checker"]), + ), +) +def test_supported_backends(target_name: str, expected_backends: list[str]) -> None: + """Test function supported_backends().""" + assert sorted(expected_backends) == sorted(supported_backends(target_name)) + + +@pytest.mark.parametrize( + ("advice", "expected_targets"), + ( + (AdviceCategory.OPERATORS, ["Cortex-A", "Ethos-U55", "Ethos-U65", "TOSA"]), + (AdviceCategory.OPTIMIZATION, ["Ethos-U55", "Ethos-U65"]), + (AdviceCategory.PERFORMANCE, ["Ethos-U55", "Ethos-U65"]), + ), +) +def test_supported_targets(advice: AdviceCategory, expected_targets: list[str]) -> None: + """Test function supported_targets().""" + assert sorted(expected_targets) == sorted(supported_targets(advice)) diff --git a/tests/test_utils_registry.py b/tests/test_utils_registry.py new file mode 100644 index 0000000..95721fc --- /dev/null +++ b/tests/test_utils_registry.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Test the Registry base class.""" +from mlia.utils.registry import Registry + + +def test_registry() -> None: + """Test Registry class.""" + reg = Registry[str]() + assert not str(reg) + assert reg.register("name", "value") + assert not reg.register("name", "value") + assert "name" in reg.items + assert reg.items["name"] == "value" + assert str(reg) + assert reg.register("other_name", "value_2") + assert len(reg.items) == 2 + assert "other_name" in reg.items + assert reg.items["other_name"] == "value_2" -- cgit v1.2.1