aboutsummaryrefslogtreecommitdiff
path: root/tests/test_target_registry.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_target_registry.py')
-rw-r--r--tests/test_target_registry.py82
1 files changed, 82 insertions, 0 deletions
diff --git a/tests/test_target_registry.py b/tests/test_target_registry.py
new file mode 100644
index 0000000..e6ee296
--- /dev/null
+++ b/tests/test_target_registry.py
@@ -0,0 +1,82 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the target registry module."""
+from __future__ import annotations
+
+import pytest
+
+from mlia.core.common import AdviceCategory
+from mlia.target.registry import registry
+from mlia.target.registry import supported_advice
+from mlia.target.registry import supported_backends
+from mlia.target.registry import supported_targets
+
+
+@pytest.mark.parametrize(
+ "expected_target", ("Cortex-A", "Ethos-U55", "Ethos-U65", "TOSA")
+)
+def test_target_registry(expected_target: str) -> None:
+ """Test the target registry."""
+ assert expected_target in registry.items, (
+ f"Expected target '{expected_target}' not contained in registered "
+ f"targets '{registry.items.keys()}'."
+ )
+
+
+@pytest.mark.parametrize(
+ ("target_name", "expected_advices"),
+ (
+ ("Cortex-A", [AdviceCategory.OPERATORS]),
+ (
+ "Ethos-U55",
+ [
+ AdviceCategory.OPERATORS,
+ AdviceCategory.OPTIMIZATION,
+ AdviceCategory.PERFORMANCE,
+ ],
+ ),
+ (
+ "Ethos-U65",
+ [
+ AdviceCategory.OPERATORS,
+ AdviceCategory.OPTIMIZATION,
+ AdviceCategory.PERFORMANCE,
+ ],
+ ),
+ ("TOSA", [AdviceCategory.OPERATORS]),
+ ),
+)
+def test_supported_advice(
+ target_name: str, expected_advices: list[AdviceCategory]
+) -> None:
+ """Test function supported_advice()."""
+ supported = supported_advice(target_name)
+ assert all(advice in expected_advices for advice in supported)
+ assert all(advice in supported for advice in expected_advices)
+
+
+@pytest.mark.parametrize(
+ ("target_name", "expected_backends"),
+ (
+ ("Cortex-A", ["ArmNNTFLiteDelegate"]),
+ ("Ethos-U55", ["Corstone-300", "Corstone-310", "Vela"]),
+ ("Ethos-U65", ["Corstone-300", "Corstone-310", "Vela"]),
+ ("TOSA", ["TOSA-Checker"]),
+ ),
+)
+def test_supported_backends(target_name: str, expected_backends: list[str]) -> None:
+ """Test function supported_backends()."""
+ assert sorted(expected_backends) == sorted(supported_backends(target_name))
+
+
+@pytest.mark.parametrize(
+ ("advice", "expected_targets"),
+ (
+ (AdviceCategory.OPERATORS, ["Cortex-A", "Ethos-U55", "Ethos-U65", "TOSA"]),
+ (AdviceCategory.OPTIMIZATION, ["Ethos-U55", "Ethos-U65"]),
+ (AdviceCategory.PERFORMANCE, ["Ethos-U55", "Ethos-U65"]),
+ ),
+)
+def test_supported_targets(advice: AdviceCategory, expected_targets: list[str]) -> None:
+ """Test function supported_targets()."""
+ assert sorted(expected_targets) == sorted(supported_targets(advice))