aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2022-12-14 11:20:11 +0000
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-01-04 10:11:33 +0000
commitdcd0bd31985c27e1d07333351b26cf8ad12ad1fd (patch)
treea3388ff5f91e7cdc7ec41271a1a76cdbfae38ece /tests
parent4b4cf29cb1e7d917ae001e258ff01f7846c34778 (diff)
downloadmlia-dcd0bd31985c27e1d07333351b26cf8ad12ad1fd.tar.gz
MLIA-589 Create an API to get target information
Change-Id: Ieeaa9188ea1e29e2ccaad7475d457bce71e3140d
Diffstat (limited to 'tests')
-rw-r--r--tests/test_backend_config.py42
-rw-r--r--tests/test_backend_registry.py81
-rw-r--r--tests/test_target_config.py53
-rw-r--r--tests/test_target_cortex_a_advice_generation.py4
-rw-r--r--tests/test_target_cortex_a_data_analysis.py4
-rw-r--r--tests/test_target_cortex_a_operators.py6
-rw-r--r--tests/test_target_registry.py82
-rw-r--r--tests/test_utils_registry.py19
8 files changed, 286 insertions, 5 deletions
diff --git a/tests/test_backend_config.py b/tests/test_backend_config.py
new file mode 100644
index 0000000..bd50945
--- /dev/null
+++ b/tests/test_backend_config.py
@@ -0,0 +1,42 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the backend config module."""
+from mlia.backend.config import BackendConfiguration
+from mlia.backend.config import BackendType
+from mlia.backend.config import System
+from mlia.core.common import AdviceCategory
+
+UNSUPPORTED_SYSTEM = next(sys for sys in System if not sys.is_compatible())
+
+
+def test_system() -> None:
+ """Test the class 'System'."""
+ assert System.CURRENT.is_compatible()
+ assert not UNSUPPORTED_SYSTEM.is_compatible()
+ assert UNSUPPORTED_SYSTEM != System.CURRENT
+ assert System.LINUX_AMD64 != System.LINUX_AARCH64
+
+
+def test_backend_config() -> None:
+ """Test the class 'BackendConfiguration'."""
+ cfg = BackendConfiguration(
+ [AdviceCategory.OPERATORS], [System.CURRENT], BackendType.CUSTOM
+ )
+ assert cfg.supported_advice == [AdviceCategory.OPERATORS]
+ assert cfg.supported_systems == [System.CURRENT]
+ assert cfg.type == BackendType.CUSTOM
+ assert str(cfg)
+ assert cfg.is_supported()
+ assert cfg.is_supported(advice=AdviceCategory.OPERATORS)
+ assert not cfg.is_supported(advice=AdviceCategory.PERFORMANCE)
+ assert cfg.is_supported(check_system=True)
+ assert cfg.is_supported(check_system=False)
+ cfg.supported_systems = None
+ assert cfg.is_supported(check_system=True)
+ assert cfg.is_supported(check_system=False)
+ cfg.supported_systems = [UNSUPPORTED_SYSTEM]
+ assert not cfg.is_supported(check_system=True)
+ assert cfg.is_supported(check_system=False)
+ assert not cfg.is_supported(advice=AdviceCategory.OPERATORS, check_system=True)
+ assert cfg.is_supported(advice=AdviceCategory.OPERATORS, check_system=False)
+ assert not cfg.is_supported(advice=AdviceCategory.PERFORMANCE, check_system=False)
diff --git a/tests/test_backend_registry.py b/tests/test_backend_registry.py
new file mode 100644
index 0000000..31a20a0
--- /dev/null
+++ b/tests/test_backend_registry.py
@@ -0,0 +1,81 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the backend registry module."""
+from __future__ import annotations
+
+from functools import partial
+
+import pytest
+
+from mlia.backend.config import BackendType
+from mlia.backend.config import System
+from mlia.backend.registry import registry
+from mlia.core.common import AdviceCategory
+
+
+@pytest.mark.parametrize(
+ ("backend", "advices", "systems", "type_"),
+ (
+ (
+ "ArmNNTFLiteDelegate",
+ [AdviceCategory.OPERATORS],
+ None,
+ BackendType.BUILTIN,
+ ),
+ (
+ "Corstone-300",
+ [AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION],
+ [System.LINUX_AMD64],
+ BackendType.CUSTOM,
+ ),
+ (
+ "Corstone-310",
+ [AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION],
+ [System.LINUX_AMD64],
+ BackendType.CUSTOM,
+ ),
+ (
+ "TOSA-Checker",
+ [AdviceCategory.OPERATORS],
+ [System.LINUX_AMD64],
+ BackendType.WHEEL,
+ ),
+ (
+ "Vela",
+ [
+ AdviceCategory.OPERATORS,
+ AdviceCategory.PERFORMANCE,
+ AdviceCategory.OPTIMIZATION,
+ ],
+ [
+ System.LINUX_AMD64,
+ System.LINUX_AARCH64,
+ System.WINDOWS_AMD64,
+ System.WINDOWS_AARCH64,
+ ],
+ BackendType.BUILTIN,
+ ),
+ ),
+)
+def test_backend_registry(
+ backend: str,
+ advices: list[AdviceCategory],
+ systems: list[System] | None,
+ type_: BackendType,
+) -> None:
+ """Test the backend registry."""
+ sorted_by_name = partial(sorted, key=lambda x: x.name)
+
+ assert backend in registry.items
+ cfg = registry.items[backend]
+ assert sorted_by_name(advices) == sorted_by_name(
+ cfg.supported_advice
+ ), f"Advices differs: {advices} != {cfg.supported_advice}"
+ if systems is None:
+ assert cfg.supported_systems is None
+ else:
+ assert cfg.supported_systems is not None
+ assert sorted_by_name(systems) == sorted_by_name(
+ cfg.supported_systems
+ ), f"Supported systems differs: {advices} != {cfg.supported_advice}"
+ assert cfg.type == type_
diff --git a/tests/test_target_config.py b/tests/test_target_config.py
new file mode 100644
index 0000000..66ebed6
--- /dev/null
+++ b/tests/test_target_config.py
@@ -0,0 +1,53 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the backend config module."""
+from __future__ import annotations
+
+import pytest
+
+from mlia.backend.config import BackendConfiguration
+from mlia.backend.config import BackendType
+from mlia.backend.config import System
+from mlia.core.common import AdviceCategory
+from mlia.target.config import IPConfiguration
+from mlia.target.config import TargetInfo
+from mlia.utils.registry import Registry
+
+
+def test_ip_config() -> None:
+ """Test the class 'IPConfiguration'."""
+ cfg = IPConfiguration("AnyTarget")
+ assert cfg.target == "AnyTarget"
+
+
+@pytest.mark.parametrize(
+ ("advice", "check_system", "supported"),
+ (
+ (None, False, True),
+ (None, True, True),
+ (AdviceCategory.OPERATORS, True, True),
+ (AdviceCategory.OPTIMIZATION, True, False),
+ ),
+)
+def test_target_info(
+ monkeypatch: pytest.MonkeyPatch,
+ advice: AdviceCategory | None,
+ check_system: bool,
+ supported: bool,
+) -> None:
+ """Test the class 'TargetInfo'."""
+ info = TargetInfo(["backend"])
+
+ backend_registry = Registry[BackendConfiguration]()
+ backend_registry.register(
+ "backend",
+ BackendConfiguration(
+ [AdviceCategory.OPERATORS],
+ [System.CURRENT],
+ BackendType.BUILTIN,
+ ),
+ )
+ monkeypatch.setattr("mlia.target.config.backend_registry", backend_registry)
+
+ assert info.is_supported(advice, check_system) == supported
+ assert bool(info.filter_supported_backends(advice, check_system)) == supported
diff --git a/tests/test_target_cortex_a_advice_generation.py b/tests/test_target_cortex_a_advice_generation.py
index 02a2b14..6effe4c 100644
--- a/tests/test_target_cortex_a_advice_generation.py
+++ b/tests/test_target_cortex_a_advice_generation.py
@@ -5,6 +5,9 @@ from __future__ import annotations
import pytest
+from mlia.backend.armnn_tflite_delegate.compat import (
+ ARMNN_TFLITE_DELEGATE,
+)
from mlia.core.advice_generation import Advice
from mlia.core.common import AdviceCategory
from mlia.core.common import DataItem
@@ -16,7 +19,6 @@ from mlia.target.cortex_a.data_analysis import ModelIsCortexACompatible
from mlia.target.cortex_a.data_analysis import ModelIsNotCortexACompatible
from mlia.target.cortex_a.data_analysis import ModelIsNotTFLiteCompatible
from mlia.target.cortex_a.data_analysis import TFLiteCompatibilityCheckFailed
-from mlia.target.cortex_a.operator_compatibility import ARMNN_TFLITE_DELEGATE
BACKEND_INFO = (
f"{ARMNN_TFLITE_DELEGATE['metadata']['backend']} "
diff --git a/tests/test_target_cortex_a_data_analysis.py b/tests/test_target_cortex_a_data_analysis.py
index b223b01..e9fc8bc 100644
--- a/tests/test_target_cortex_a_data_analysis.py
+++ b/tests/test_target_cortex_a_data_analysis.py
@@ -5,6 +5,9 @@ from __future__ import annotations
import pytest
+from mlia.backend.armnn_tflite_delegate.compat import (
+ ARMNN_TFLITE_DELEGATE,
+)
from mlia.core.common import DataItem
from mlia.core.data_analysis import Fact
from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
@@ -18,7 +21,6 @@ from mlia.target.cortex_a.data_analysis import ModelIsCortexACompatible
from mlia.target.cortex_a.data_analysis import ModelIsNotCortexACompatible
from mlia.target.cortex_a.data_analysis import ModelIsNotTFLiteCompatible
from mlia.target.cortex_a.data_analysis import TFLiteCompatibilityCheckFailed
-from mlia.target.cortex_a.operator_compatibility import ARMNN_TFLITE_DELEGATE
from mlia.target.cortex_a.operators import CortexACompatibilityInfo
from mlia.target.cortex_a.operators import Operator
diff --git a/tests/test_target_cortex_a_operators.py b/tests/test_target_cortex_a_operators.py
index 94eb890..262ebc8 100644
--- a/tests/test_target_cortex_a_operators.py
+++ b/tests/test_target_cortex_a_operators.py
@@ -6,18 +6,18 @@ from pathlib import Path
import pytest
import tensorflow as tf
+from mlia.backend.armnn_tflite_delegate import compat
from mlia.nn.tensorflow.tflite_graph import TFL_OP
from mlia.nn.tensorflow.utils import convert_to_tflite
-from mlia.target.cortex_a import operator_compatibility as op_compat
from mlia.target.cortex_a.operators import CortexACompatibilityInfo
from mlia.target.cortex_a.operators import get_cortex_a_compatibility_info
from mlia.target.cortex_a.operators import Operator
-def test_op_compat_data() -> None:
+def test_compat_data() -> None:
"""Make sure all data contains the necessary items."""
builtin_tfl_ops = {op.name for op in TFL_OP}
- for data in [op_compat.ARMNN_TFLITE_DELEGATE]:
+ for data in [compat.ARMNN_TFLITE_DELEGATE]:
assert "metadata" in data
assert "backend" in data["metadata"]
assert "version" in data["metadata"]
diff --git a/tests/test_target_registry.py b/tests/test_target_registry.py
new file mode 100644
index 0000000..e6ee296
--- /dev/null
+++ b/tests/test_target_registry.py
@@ -0,0 +1,82 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the target registry module."""
+from __future__ import annotations
+
+import pytest
+
+from mlia.core.common import AdviceCategory
+from mlia.target.registry import registry
+from mlia.target.registry import supported_advice
+from mlia.target.registry import supported_backends
+from mlia.target.registry import supported_targets
+
+
+@pytest.mark.parametrize(
+ "expected_target", ("Cortex-A", "Ethos-U55", "Ethos-U65", "TOSA")
+)
+def test_target_registry(expected_target: str) -> None:
+ """Test the target registry."""
+ assert expected_target in registry.items, (
+ f"Expected target '{expected_target}' not contained in registered "
+ f"targets '{registry.items.keys()}'."
+ )
+
+
+@pytest.mark.parametrize(
+ ("target_name", "expected_advices"),
+ (
+ ("Cortex-A", [AdviceCategory.OPERATORS]),
+ (
+ "Ethos-U55",
+ [
+ AdviceCategory.OPERATORS,
+ AdviceCategory.OPTIMIZATION,
+ AdviceCategory.PERFORMANCE,
+ ],
+ ),
+ (
+ "Ethos-U65",
+ [
+ AdviceCategory.OPERATORS,
+ AdviceCategory.OPTIMIZATION,
+ AdviceCategory.PERFORMANCE,
+ ],
+ ),
+ ("TOSA", [AdviceCategory.OPERATORS]),
+ ),
+)
+def test_supported_advice(
+ target_name: str, expected_advices: list[AdviceCategory]
+) -> None:
+ """Test function supported_advice()."""
+ supported = supported_advice(target_name)
+ assert all(advice in expected_advices for advice in supported)
+ assert all(advice in supported for advice in expected_advices)
+
+
+@pytest.mark.parametrize(
+ ("target_name", "expected_backends"),
+ (
+ ("Cortex-A", ["ArmNNTFLiteDelegate"]),
+ ("Ethos-U55", ["Corstone-300", "Corstone-310", "Vela"]),
+ ("Ethos-U65", ["Corstone-300", "Corstone-310", "Vela"]),
+ ("TOSA", ["TOSA-Checker"]),
+ ),
+)
+def test_supported_backends(target_name: str, expected_backends: list[str]) -> None:
+ """Test function supported_backends()."""
+ assert sorted(expected_backends) == sorted(supported_backends(target_name))
+
+
+@pytest.mark.parametrize(
+ ("advice", "expected_targets"),
+ (
+ (AdviceCategory.OPERATORS, ["Cortex-A", "Ethos-U55", "Ethos-U65", "TOSA"]),
+ (AdviceCategory.OPTIMIZATION, ["Ethos-U55", "Ethos-U65"]),
+ (AdviceCategory.PERFORMANCE, ["Ethos-U55", "Ethos-U65"]),
+ ),
+)
+def test_supported_targets(advice: AdviceCategory, expected_targets: list[str]) -> None:
+ """Test function supported_targets()."""
+ assert sorted(expected_targets) == sorted(supported_targets(advice))
diff --git a/tests/test_utils_registry.py b/tests/test_utils_registry.py
new file mode 100644
index 0000000..95721fc
--- /dev/null
+++ b/tests/test_utils_registry.py
@@ -0,0 +1,19 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Test the Registry base class."""
+from mlia.utils.registry import Registry
+
+
+def test_registry() -> None:
+ """Test Registry class."""
+ reg = Registry[str]()
+ assert not str(reg)
+ assert reg.register("name", "value")
+ assert not reg.register("name", "value")
+ assert "name" in reg.items
+ assert reg.items["name"] == "value"
+ assert str(reg)
+ assert reg.register("other_name", "value_2")
+ assert len(reg.items) == 2
+ assert "other_name" in reg.items
+ assert reg.items["other_name"] == "value_2"