aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2022-12-14 11:20:11 +0000
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-01-04 10:11:33 +0000
commitdcd0bd31985c27e1d07333351b26cf8ad12ad1fd (patch)
treea3388ff5f91e7cdc7ec41271a1a76cdbfae38ece
parent4b4cf29cb1e7d917ae001e258ff01f7846c34778 (diff)
downloadmlia-dcd0bd31985c27e1d07333351b26cf8ad12ad1fd.tar.gz
MLIA-589 Create an API to get target information
Change-Id: Ieeaa9188ea1e29e2ccaad7475d457bce71e3140d
-rw-r--r--src/mlia/backend/__init__.py9
-rw-r--r--src/mlia/backend/armnn_tflite_delegate/__init__.py16
-rw-r--r--src/mlia/backend/armnn_tflite_delegate/compat.py (renamed from src/mlia/target/cortex_a/operator_compatibility.py)0
-rw-r--r--src/mlia/backend/config.py82
-rw-r--r--src/mlia/backend/corstone/__init__.py22
-rw-r--r--src/mlia/backend/registry.py8
-rw-r--r--src/mlia/backend/tosa_checker/__init__.py14
-rw-r--r--src/mlia/backend/vela/__init__.py23
-rw-r--r--src/mlia/cli/main.py9
-rw-r--r--src/mlia/cli/options.py2
-rw-r--r--src/mlia/target/__init__.py6
-rw-r--r--src/mlia/target/config.py36
-rw-r--r--src/mlia/target/cortex_a/__init__.py4
-rw-r--r--src/mlia/target/cortex_a/operators.py6
-rw-r--r--src/mlia/target/ethos_u/__init__.py5
-rw-r--r--src/mlia/target/registry.py34
-rw-r--r--src/mlia/target/tosa/__init__.py4
-rw-r--r--src/mlia/utils/registry.py31
-rw-r--r--tests/test_backend_config.py42
-rw-r--r--tests/test_backend_registry.py81
-rw-r--r--tests/test_target_config.py53
-rw-r--r--tests/test_target_cortex_a_advice_generation.py4
-rw-r--r--tests/test_target_cortex_a_data_analysis.py4
-rw-r--r--tests/test_target_cortex_a_operators.py6
-rw-r--r--tests/test_target_registry.py82
-rw-r--r--tests/test_utils_registry.py19
26 files changed, 588 insertions, 14 deletions
diff --git a/src/mlia/backend/__init__.py b/src/mlia/backend/__init__.py
index 745aa1b..2d1da70 100644
--- a/src/mlia/backend/__init__.py
+++ b/src/mlia/backend/__init__.py
@@ -1,3 +1,10 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
-"""Backends module."""
+"""Backend module."""
+# Make sure all targets are registered with the registry by importing the
+# sub-modules
+# flake8: noqa
+from mlia.backend import armnn_tflite_delegate
+from mlia.backend import corstone
+from mlia.backend import tosa_checker
+from mlia.backend import vela
diff --git a/src/mlia/backend/armnn_tflite_delegate/__init__.py b/src/mlia/backend/armnn_tflite_delegate/__init__.py
new file mode 100644
index 0000000..6d5af42
--- /dev/null
+++ b/src/mlia/backend/armnn_tflite_delegate/__init__.py
@@ -0,0 +1,16 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Arm NN TensorFlow Lite delegate backend module."""
+from mlia.backend.config import BackendConfiguration
+from mlia.backend.config import BackendType
+from mlia.backend.registry import registry
+from mlia.core.common import AdviceCategory
+
+registry.register(
+ "ArmNNTFLiteDelegate",
+ BackendConfiguration(
+ supported_advice=[AdviceCategory.OPERATORS],
+ supported_systems=None,
+ backend_type=BackendType.BUILTIN,
+ ),
+)
diff --git a/src/mlia/target/cortex_a/operator_compatibility.py b/src/mlia/backend/armnn_tflite_delegate/compat.py
index c474e75..c474e75 100644
--- a/src/mlia/target/cortex_a/operator_compatibility.py
+++ b/src/mlia/backend/armnn_tflite_delegate/compat.py
diff --git a/src/mlia/backend/config.py b/src/mlia/backend/config.py
new file mode 100644
index 0000000..8d14b28
--- /dev/null
+++ b/src/mlia/backend/config.py
@@ -0,0 +1,82 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Backend config module."""
+from __future__ import annotations
+
+import platform
+from enum import auto
+from enum import Enum
+from typing import cast
+
+from mlia.core.common import AdviceCategory
+
+
+class System(Enum):
+ """Enum of system configurations (e.g. OS and architecture)."""
+
+ LINUX_AMD64 = ("Linux", "x86_64")
+ LINUX_AARCH64 = ("Linux", "aarch64")
+ WINDOWS_AMD64 = ("Windows", "AMD64")
+ WINDOWS_AARCH64 = ("Windows", "ARM64")
+ DARWIN_AARCH64 = ("Darwin", "arm64")
+ CURRENT = (platform.system(), platform.machine())
+
+ def __init__(
+ self, system: str = platform.system(), machine: str = platform.machine()
+ ) -> None:
+ """Set the system parameters (defaults to the current system)."""
+ self.system = system.lower()
+ self.machine = machine.lower()
+
+ def __eq__(self, other: object) -> bool:
+ """
+ Compare two System instances for equality.
+
+ Raises a TypeError if the input is not a System.
+ """
+ if isinstance(other, self.__class__):
+ return self.system == other.system and self.machine == other.machine
+ return False
+
+ def is_compatible(self) -> bool:
+ """Check if this system is compatible with the current system."""
+ return self == self.CURRENT
+
+
+class BackendType(Enum):
+ """Define the type of the backend (builtin, wheel file etc)."""
+
+ BUILTIN = auto()
+ WHEEL = auto()
+ CUSTOM = auto()
+
+
+class BackendConfiguration:
+ """Base class for backend configurations."""
+
+ def __init__(
+ self,
+ supported_advice: list[AdviceCategory],
+ supported_systems: list[System] | None,
+ backend_type: BackendType,
+ ) -> None:
+ """Set up basic information about the backend."""
+ self.supported_advice = supported_advice
+ self.supported_systems = supported_systems
+ self.type = backend_type
+
+ def __str__(self) -> str:
+ """List supported advice."""
+ return ", ".join(cast(str, adv.name).lower() for adv in self.supported_advice)
+
+ def is_supported(
+ self, advice: AdviceCategory | None = None, check_system: bool = False
+ ) -> bool:
+ """Check backend supports the current system and advice."""
+ is_system_supported = (
+ not self.supported_systems
+ or not check_system
+ or any(sys.is_compatible() for sys in self.supported_systems)
+ )
+ is_advice_supported = advice is None or advice in self.supported_advice
+ return is_system_supported and is_advice_supported
diff --git a/src/mlia/backend/corstone/__init__.py b/src/mlia/backend/corstone/__init__.py
index a1eac14..f89da63 100644
--- a/src/mlia/backend/corstone/__init__.py
+++ b/src/mlia/backend/corstone/__init__.py
@@ -1,3 +1,25 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Corstone backend module."""
+from mlia.backend.config import BackendConfiguration
+from mlia.backend.config import BackendType
+from mlia.backend.config import System
+from mlia.backend.registry import registry
+from mlia.core.common import AdviceCategory
+
+registry.register(
+ "Corstone-300",
+ BackendConfiguration(
+ supported_advice=[AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION],
+ supported_systems=[System.LINUX_AMD64],
+ backend_type=BackendType.CUSTOM,
+ ),
+)
+registry.register(
+ "Corstone-310",
+ BackendConfiguration(
+ supported_advice=[AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION],
+ supported_systems=[System.LINUX_AMD64],
+ backend_type=BackendType.CUSTOM,
+ ),
+)
diff --git a/src/mlia/backend/registry.py b/src/mlia/backend/registry.py
new file mode 100644
index 0000000..6a0da74
--- /dev/null
+++ b/src/mlia/backend/registry.py
@@ -0,0 +1,8 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Backend module."""
+from mlia.backend.config import BackendConfiguration
+from mlia.utils.registry import Registry
+
+# All supported targets are required to be registered here.
+registry = Registry[BackendConfiguration]()
diff --git a/src/mlia/backend/tosa_checker/__init__.py b/src/mlia/backend/tosa_checker/__init__.py
index cec210d..19fc8be 100644
--- a/src/mlia/backend/tosa_checker/__init__.py
+++ b/src/mlia/backend/tosa_checker/__init__.py
@@ -1,3 +1,17 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""TOSA checker backend module."""
+from mlia.backend.config import BackendConfiguration
+from mlia.backend.config import BackendType
+from mlia.backend.config import System
+from mlia.backend.registry import registry
+from mlia.core.common import AdviceCategory
+
+registry.register(
+ "TOSA-Checker",
+ BackendConfiguration(
+ supported_advice=[AdviceCategory.OPERATORS],
+ supported_systems=[System.LINUX_AMD64],
+ backend_type=BackendType.WHEEL,
+ ),
+)
diff --git a/src/mlia/backend/vela/__init__.py b/src/mlia/backend/vela/__init__.py
index 6ea0c21..38a623e 100644
--- a/src/mlia/backend/vela/__init__.py
+++ b/src/mlia/backend/vela/__init__.py
@@ -1,3 +1,26 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Vela backend module."""
+from mlia.backend.config import BackendConfiguration
+from mlia.backend.config import BackendType
+from mlia.backend.config import System
+from mlia.backend.registry import registry
+from mlia.core.common import AdviceCategory
+
+registry.register(
+ "Vela",
+ BackendConfiguration(
+ supported_advice=[
+ AdviceCategory.OPERATORS,
+ AdviceCategory.PERFORMANCE,
+ AdviceCategory.OPTIMIZATION,
+ ],
+ supported_systems=[
+ System.LINUX_AMD64,
+ System.LINUX_AARCH64,
+ System.WINDOWS_AMD64,
+ System.WINDOWS_AARCH64,
+ ],
+ backend_type=BackendType.BUILTIN,
+ ),
+)
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py
index 98fdb63..ac60308 100644
--- a/src/mlia/cli/main.py
+++ b/src/mlia/cli/main.py
@@ -12,6 +12,7 @@ from pathlib import Path
from mlia import __version__
from mlia.backend.errors import BackendUnavailableError
+from mlia.backend.registry import registry as backend_registry
from mlia.cli.commands import all_tests
from mlia.cli.commands import backend_install
from mlia.cli.commands import backend_list
@@ -36,6 +37,7 @@ from mlia.cli.options import add_tflite_model_options
from mlia.core.context import ExecutionContext
from mlia.core.errors import ConfigurationError
from mlia.core.errors import InternalError
+from mlia.target.registry import registry as target_registry
logger = logging.getLogger(__name__)
@@ -46,11 +48,10 @@ ML Inference Advisor {__version__}
Help the design and optimization of neural network models for efficient inference on a target CPU and NPU
Supported targets:
+{target_registry}
- - Cortex-A <op compatibility>
- - Ethos-U55 <op compatibility, perf estimation, model opt>
- - Ethos-U65 <op compatibility, perf estimation, model opt>
- - TOSA <op compatibility>
+Supported backends:
+{backend_registry}
""".strip()
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py
index 5eab9aa..8ea4250 100644
--- a/src/mlia/cli/options.py
+++ b/src/mlia/cli/options.py
@@ -38,7 +38,7 @@ def add_target_options(
help="Target profile that will set the target options "
"such as target, mac value, memory mode, etc. "
f"For the values associated with each target profile "
- f" please refer to the documenation {default_help}.",
+ f" please refer to the documentation {default_help}.",
)
diff --git a/src/mlia/target/__init__.py b/src/mlia/target/__init__.py
index 2370221..a9979c6 100644
--- a/src/mlia/target/__init__.py
+++ b/src/mlia/target/__init__.py
@@ -1,3 +1,9 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Target module."""
+# Make sure all targets are registered with the registry by importing the
+# sub-modules
+# flake8: noqa
+from mlia.target import cortex_a
+from mlia.target import ethos_u
+from mlia.target import tosa
diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py
index 7ab6b43..f257784 100644
--- a/src/mlia/target/config.py
+++ b/src/mlia/target/config.py
@@ -1,6 +1,12 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""IP configuration module."""
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from mlia.backend.registry import registry as backend_registry
+from mlia.core.common import AdviceCategory
class IPConfiguration: # pylint: disable=too-few-public-methods
@@ -9,3 +15,33 @@ class IPConfiguration: # pylint: disable=too-few-public-methods
def __init__(self, target: str) -> None:
"""Init IP configuration instance."""
self.target = target
+
+
+@dataclass
+class TargetInfo:
+ """Collect information about supported targets."""
+
+ supported_backends: list[str]
+
+ def __str__(self) -> str:
+ """List supported backends."""
+ return ", ".join(sorted(self.supported_backends))
+
+ def is_supported(
+ self, advice: AdviceCategory | None = None, check_system: bool = False
+ ) -> bool:
+ """Check if any of the supported backends support this kind of advice."""
+ return any(
+ backend_registry.items[name].is_supported(advice, check_system)
+ for name in self.supported_backends
+ )
+
+ def filter_supported_backends(
+ self, advice: AdviceCategory | None = None, check_system: bool = False
+ ) -> list[str]:
+ """Get the list of supported backends filtered by the given arguments."""
+ return [
+ name
+ for name in self.supported_backends
+ if backend_registry.items[name].is_supported(advice, check_system)
+ ]
diff --git a/src/mlia/target/cortex_a/__init__.py b/src/mlia/target/cortex_a/__init__.py
index fe01835..9b0e611 100644
--- a/src/mlia/target/cortex_a/__init__.py
+++ b/src/mlia/target/cortex_a/__init__.py
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Cortex-A target module."""
+from mlia.target.registry import registry
+from mlia.target.registry import TargetInfo
+
+registry.register("Cortex-A", TargetInfo(["ArmNNTFLiteDelegate"]))
diff --git a/src/mlia/target/cortex_a/operators.py b/src/mlia/target/cortex_a/operators.py
index 91f1886..ae611e5 100644
--- a/src/mlia/target/cortex_a/operators.py
+++ b/src/mlia/target/cortex_a/operators.py
@@ -9,12 +9,12 @@ from pathlib import Path
from typing import Any
from typing import ClassVar
+from mlia.backend.armnn_tflite_delegate.compat import (
+ ARMNN_TFLITE_DELEGATE as TFLITE_DELEGATE_COMPAT,
+)
from mlia.nn.tensorflow.tflite_graph import Op
from mlia.nn.tensorflow.tflite_graph import parse_subgraphs
from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION
-from mlia.target.cortex_a.operator_compatibility import (
- ARMNN_TFLITE_DELEGATE as TFLITE_DELEGATE_COMPAT,
-)
@dataclass
diff --git a/src/mlia/target/ethos_u/__init__.py b/src/mlia/target/ethos_u/__init__.py
index 503919d..3c92ae5 100644
--- a/src/mlia/target/ethos_u/__init__.py
+++ b/src/mlia/target/ethos_u/__init__.py
@@ -1,3 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Ethos-U target module."""
+from mlia.target.registry import registry
+from mlia.target.registry import TargetInfo
+
+registry.register("Ethos-U55", TargetInfo(["Vela", "Corstone-300", "Corstone-310"]))
+registry.register("Ethos-U65", TargetInfo(["Vela", "Corstone-300", "Corstone-310"]))
diff --git a/src/mlia/target/registry.py b/src/mlia/target/registry.py
new file mode 100644
index 0000000..6b33084
--- /dev/null
+++ b/src/mlia/target/registry.py
@@ -0,0 +1,34 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Target module."""
+from __future__ import annotations
+
+from mlia.backend.registry import registry as backend_registry
+from mlia.core.common import AdviceCategory
+from mlia.target.config import TargetInfo
+from mlia.utils.registry import Registry
+
+# All supported targets are required to be registered here.
+registry = Registry[TargetInfo]()
+
+
+def supported_advice(target: str) -> list[AdviceCategory]:
+ """Get a list of supported advice for the given target."""
+ advice: set[AdviceCategory] = set()
+ for supported_backend in registry.items[target].supported_backends:
+ advice.update(backend_registry.items[supported_backend].supported_advice)
+ return list(advice)
+
+
+def supported_backends(target: str) -> list[str]:
+ """Get a list of backends supported by the given target."""
+ return registry.items[target].filter_supported_backends(check_system=False)
+
+
+def supported_targets(advice: AdviceCategory) -> list[str]:
+ """Get a list of all targets supporting the given advice category."""
+ return [
+ name
+ for name, info in registry.items.items()
+ if info.is_supported(advice, check_system=False)
+ ]
diff --git a/src/mlia/target/tosa/__init__.py b/src/mlia/target/tosa/__init__.py
index 762c831..33c9cf2 100644
--- a/src/mlia/target/tosa/__init__.py
+++ b/src/mlia/target/tosa/__init__.py
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""TOSA target module."""
+from mlia.target.registry import registry
+from mlia.target.registry import TargetInfo
+
+registry.register("TOSA", TargetInfo(["TOSA-Checker"]))
diff --git a/src/mlia/utils/registry.py b/src/mlia/utils/registry.py
new file mode 100644
index 0000000..9b25a81
--- /dev/null
+++ b/src/mlia/utils/registry.py
@@ -0,0 +1,31 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Generic registry class."""
+from __future__ import annotations
+
+from typing import Generic
+from typing import TypeVar
+
+T = TypeVar("T")
+
+
+class Registry(Generic[T]):
+ """Generic registry for name-config pairs."""
+
+ def __init__(self) -> None:
+ """Create an empty registry."""
+ self.items: dict[str, T] = {}
+
+ def __str__(self) -> str:
+ """List all registered items."""
+ return "\n".join(
+ f"- {name}: {item}"
+ for name, item in sorted(self.items.items(), key=lambda v: v[0])
+ )
+
+ def register(self, name: str, item: T) -> bool:
+ """Register an item: returns `False` if already registered."""
+ if name in self.items:
+ return False # already registered
+ self.items[name] = item
+ return True
diff --git a/tests/test_backend_config.py b/tests/test_backend_config.py
new file mode 100644
index 0000000..bd50945
--- /dev/null
+++ b/tests/test_backend_config.py
@@ -0,0 +1,42 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the backend config module."""
+from mlia.backend.config import BackendConfiguration
+from mlia.backend.config import BackendType
+from mlia.backend.config import System
+from mlia.core.common import AdviceCategory
+
+UNSUPPORTED_SYSTEM = next(sys for sys in System if not sys.is_compatible())
+
+
+def test_system() -> None:
+ """Test the class 'System'."""
+ assert System.CURRENT.is_compatible()
+ assert not UNSUPPORTED_SYSTEM.is_compatible()
+ assert UNSUPPORTED_SYSTEM != System.CURRENT
+ assert System.LINUX_AMD64 != System.LINUX_AARCH64
+
+
+def test_backend_config() -> None:
+ """Test the class 'BackendConfiguration'."""
+ cfg = BackendConfiguration(
+ [AdviceCategory.OPERATORS], [System.CURRENT], BackendType.CUSTOM
+ )
+ assert cfg.supported_advice == [AdviceCategory.OPERATORS]
+ assert cfg.supported_systems == [System.CURRENT]
+ assert cfg.type == BackendType.CUSTOM
+ assert str(cfg)
+ assert cfg.is_supported()
+ assert cfg.is_supported(advice=AdviceCategory.OPERATORS)
+ assert not cfg.is_supported(advice=AdviceCategory.PERFORMANCE)
+ assert cfg.is_supported(check_system=True)
+ assert cfg.is_supported(check_system=False)
+ cfg.supported_systems = None
+ assert cfg.is_supported(check_system=True)
+ assert cfg.is_supported(check_system=False)
+ cfg.supported_systems = [UNSUPPORTED_SYSTEM]
+ assert not cfg.is_supported(check_system=True)
+ assert cfg.is_supported(check_system=False)
+ assert not cfg.is_supported(advice=AdviceCategory.OPERATORS, check_system=True)
+ assert cfg.is_supported(advice=AdviceCategory.OPERATORS, check_system=False)
+ assert not cfg.is_supported(advice=AdviceCategory.PERFORMANCE, check_system=False)
diff --git a/tests/test_backend_registry.py b/tests/test_backend_registry.py
new file mode 100644
index 0000000..31a20a0
--- /dev/null
+++ b/tests/test_backend_registry.py
@@ -0,0 +1,81 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the backend registry module."""
+from __future__ import annotations
+
+from functools import partial
+
+import pytest
+
+from mlia.backend.config import BackendType
+from mlia.backend.config import System
+from mlia.backend.registry import registry
+from mlia.core.common import AdviceCategory
+
+
+@pytest.mark.parametrize(
+ ("backend", "advices", "systems", "type_"),
+ (
+ (
+ "ArmNNTFLiteDelegate",
+ [AdviceCategory.OPERATORS],
+ None,
+ BackendType.BUILTIN,
+ ),
+ (
+ "Corstone-300",
+ [AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION],
+ [System.LINUX_AMD64],
+ BackendType.CUSTOM,
+ ),
+ (
+ "Corstone-310",
+ [AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION],
+ [System.LINUX_AMD64],
+ BackendType.CUSTOM,
+ ),
+ (
+ "TOSA-Checker",
+ [AdviceCategory.OPERATORS],
+ [System.LINUX_AMD64],
+ BackendType.WHEEL,
+ ),
+ (
+ "Vela",
+ [
+ AdviceCategory.OPERATORS,
+ AdviceCategory.PERFORMANCE,
+ AdviceCategory.OPTIMIZATION,
+ ],
+ [
+ System.LINUX_AMD64,
+ System.LINUX_AARCH64,
+ System.WINDOWS_AMD64,
+ System.WINDOWS_AARCH64,
+ ],
+ BackendType.BUILTIN,
+ ),
+ ),
+)
+def test_backend_registry(
+ backend: str,
+ advices: list[AdviceCategory],
+ systems: list[System] | None,
+ type_: BackendType,
+) -> None:
+ """Test the backend registry."""
+ sorted_by_name = partial(sorted, key=lambda x: x.name)
+
+ assert backend in registry.items
+ cfg = registry.items[backend]
+ assert sorted_by_name(advices) == sorted_by_name(
+ cfg.supported_advice
+ ), f"Advices differs: {advices} != {cfg.supported_advice}"
+ if systems is None:
+ assert cfg.supported_systems is None
+ else:
+ assert cfg.supported_systems is not None
+ assert sorted_by_name(systems) == sorted_by_name(
+ cfg.supported_systems
+ ), f"Supported systems differs: {advices} != {cfg.supported_advice}"
+ assert cfg.type == type_
diff --git a/tests/test_target_config.py b/tests/test_target_config.py
new file mode 100644
index 0000000..66ebed6
--- /dev/null
+++ b/tests/test_target_config.py
@@ -0,0 +1,53 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the backend config module."""
+from __future__ import annotations
+
+import pytest
+
+from mlia.backend.config import BackendConfiguration
+from mlia.backend.config import BackendType
+from mlia.backend.config import System
+from mlia.core.common import AdviceCategory
+from mlia.target.config import IPConfiguration
+from mlia.target.config import TargetInfo
+from mlia.utils.registry import Registry
+
+
+def test_ip_config() -> None:
+ """Test the class 'IPConfiguration'."""
+ cfg = IPConfiguration("AnyTarget")
+ assert cfg.target == "AnyTarget"
+
+
+@pytest.mark.parametrize(
+ ("advice", "check_system", "supported"),
+ (
+ (None, False, True),
+ (None, True, True),
+ (AdviceCategory.OPERATORS, True, True),
+ (AdviceCategory.OPTIMIZATION, True, False),
+ ),
+)
+def test_target_info(
+ monkeypatch: pytest.MonkeyPatch,
+ advice: AdviceCategory | None,
+ check_system: bool,
+ supported: bool,
+) -> None:
+ """Test the class 'TargetInfo'."""
+ info = TargetInfo(["backend"])
+
+ backend_registry = Registry[BackendConfiguration]()
+ backend_registry.register(
+ "backend",
+ BackendConfiguration(
+ [AdviceCategory.OPERATORS],
+ [System.CURRENT],
+ BackendType.BUILTIN,
+ ),
+ )
+ monkeypatch.setattr("mlia.target.config.backend_registry", backend_registry)
+
+ assert info.is_supported(advice, check_system) == supported
+ assert bool(info.filter_supported_backends(advice, check_system)) == supported
diff --git a/tests/test_target_cortex_a_advice_generation.py b/tests/test_target_cortex_a_advice_generation.py
index 02a2b14..6effe4c 100644
--- a/tests/test_target_cortex_a_advice_generation.py
+++ b/tests/test_target_cortex_a_advice_generation.py
@@ -5,6 +5,9 @@ from __future__ import annotations
import pytest
+from mlia.backend.armnn_tflite_delegate.compat import (
+ ARMNN_TFLITE_DELEGATE,
+)
from mlia.core.advice_generation import Advice
from mlia.core.common import AdviceCategory
from mlia.core.common import DataItem
@@ -16,7 +19,6 @@ from mlia.target.cortex_a.data_analysis import ModelIsCortexACompatible
from mlia.target.cortex_a.data_analysis import ModelIsNotCortexACompatible
from mlia.target.cortex_a.data_analysis import ModelIsNotTFLiteCompatible
from mlia.target.cortex_a.data_analysis import TFLiteCompatibilityCheckFailed
-from mlia.target.cortex_a.operator_compatibility import ARMNN_TFLITE_DELEGATE
BACKEND_INFO = (
f"{ARMNN_TFLITE_DELEGATE['metadata']['backend']} "
diff --git a/tests/test_target_cortex_a_data_analysis.py b/tests/test_target_cortex_a_data_analysis.py
index b223b01..e9fc8bc 100644
--- a/tests/test_target_cortex_a_data_analysis.py
+++ b/tests/test_target_cortex_a_data_analysis.py
@@ -5,6 +5,9 @@ from __future__ import annotations
import pytest
+from mlia.backend.armnn_tflite_delegate.compat import (
+ ARMNN_TFLITE_DELEGATE,
+)
from mlia.core.common import DataItem
from mlia.core.data_analysis import Fact
from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
@@ -18,7 +21,6 @@ from mlia.target.cortex_a.data_analysis import ModelIsCortexACompatible
from mlia.target.cortex_a.data_analysis import ModelIsNotCortexACompatible
from mlia.target.cortex_a.data_analysis import ModelIsNotTFLiteCompatible
from mlia.target.cortex_a.data_analysis import TFLiteCompatibilityCheckFailed
-from mlia.target.cortex_a.operator_compatibility import ARMNN_TFLITE_DELEGATE
from mlia.target.cortex_a.operators import CortexACompatibilityInfo
from mlia.target.cortex_a.operators import Operator
diff --git a/tests/test_target_cortex_a_operators.py b/tests/test_target_cortex_a_operators.py
index 94eb890..262ebc8 100644
--- a/tests/test_target_cortex_a_operators.py
+++ b/tests/test_target_cortex_a_operators.py
@@ -6,18 +6,18 @@ from pathlib import Path
import pytest
import tensorflow as tf
+from mlia.backend.armnn_tflite_delegate import compat
from mlia.nn.tensorflow.tflite_graph import TFL_OP
from mlia.nn.tensorflow.utils import convert_to_tflite
-from mlia.target.cortex_a import operator_compatibility as op_compat
from mlia.target.cortex_a.operators import CortexACompatibilityInfo
from mlia.target.cortex_a.operators import get_cortex_a_compatibility_info
from mlia.target.cortex_a.operators import Operator
-def test_op_compat_data() -> None:
+def test_compat_data() -> None:
"""Make sure all data contains the necessary items."""
builtin_tfl_ops = {op.name for op in TFL_OP}
- for data in [op_compat.ARMNN_TFLITE_DELEGATE]:
+ for data in [compat.ARMNN_TFLITE_DELEGATE]:
assert "metadata" in data
assert "backend" in data["metadata"]
assert "version" in data["metadata"]
diff --git a/tests/test_target_registry.py b/tests/test_target_registry.py
new file mode 100644
index 0000000..e6ee296
--- /dev/null
+++ b/tests/test_target_registry.py
@@ -0,0 +1,82 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the target registry module."""
+from __future__ import annotations
+
+import pytest
+
+from mlia.core.common import AdviceCategory
+from mlia.target.registry import registry
+from mlia.target.registry import supported_advice
+from mlia.target.registry import supported_backends
+from mlia.target.registry import supported_targets
+
+
+@pytest.mark.parametrize(
+ "expected_target", ("Cortex-A", "Ethos-U55", "Ethos-U65", "TOSA")
+)
+def test_target_registry(expected_target: str) -> None:
+ """Test the target registry."""
+ assert expected_target in registry.items, (
+ f"Expected target '{expected_target}' not contained in registered "
+ f"targets '{registry.items.keys()}'."
+ )
+
+
+@pytest.mark.parametrize(
+ ("target_name", "expected_advices"),
+ (
+ ("Cortex-A", [AdviceCategory.OPERATORS]),
+ (
+ "Ethos-U55",
+ [
+ AdviceCategory.OPERATORS,
+ AdviceCategory.OPTIMIZATION,
+ AdviceCategory.PERFORMANCE,
+ ],
+ ),
+ (
+ "Ethos-U65",
+ [
+ AdviceCategory.OPERATORS,
+ AdviceCategory.OPTIMIZATION,
+ AdviceCategory.PERFORMANCE,
+ ],
+ ),
+ ("TOSA", [AdviceCategory.OPERATORS]),
+ ),
+)
+def test_supported_advice(
+ target_name: str, expected_advices: list[AdviceCategory]
+) -> None:
+ """Test function supported_advice()."""
+ supported = supported_advice(target_name)
+ assert all(advice in expected_advices for advice in supported)
+ assert all(advice in supported for advice in expected_advices)
+
+
+@pytest.mark.parametrize(
+ ("target_name", "expected_backends"),
+ (
+ ("Cortex-A", ["ArmNNTFLiteDelegate"]),
+ ("Ethos-U55", ["Corstone-300", "Corstone-310", "Vela"]),
+ ("Ethos-U65", ["Corstone-300", "Corstone-310", "Vela"]),
+ ("TOSA", ["TOSA-Checker"]),
+ ),
+)
+def test_supported_backends(target_name: str, expected_backends: list[str]) -> None:
+ """Test function supported_backends()."""
+ assert sorted(expected_backends) == sorted(supported_backends(target_name))
+
+
+@pytest.mark.parametrize(
+ ("advice", "expected_targets"),
+ (
+ (AdviceCategory.OPERATORS, ["Cortex-A", "Ethos-U55", "Ethos-U65", "TOSA"]),
+ (AdviceCategory.OPTIMIZATION, ["Ethos-U55", "Ethos-U65"]),
+ (AdviceCategory.PERFORMANCE, ["Ethos-U55", "Ethos-U65"]),
+ ),
+)
+def test_supported_targets(advice: AdviceCategory, expected_targets: list[str]) -> None:
+ """Test function supported_targets()."""
+ assert sorted(expected_targets) == sorted(supported_targets(advice))
diff --git a/tests/test_utils_registry.py b/tests/test_utils_registry.py
new file mode 100644
index 0000000..95721fc
--- /dev/null
+++ b/tests/test_utils_registry.py
@@ -0,0 +1,19 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Test the Registry base class."""
+from mlia.utils.registry import Registry
+
+
+def test_registry() -> None:
+ """Test Registry class."""
+ reg = Registry[str]()
+ assert not str(reg)
+ assert reg.register("name", "value")
+ assert not reg.register("name", "value")
+ assert "name" in reg.items
+ assert reg.items["name"] == "value"
+ assert str(reg)
+ assert reg.register("other_name", "value_2")
+ assert len(reg.items) == 2
+ assert "other_name" in reg.items
+ assert reg.items["other_name"] == "value_2"