aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/target
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2022-12-14 11:20:11 +0000
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-01-04 10:11:33 +0000
commitdcd0bd31985c27e1d07333351b26cf8ad12ad1fd (patch)
treea3388ff5f91e7cdc7ec41271a1a76cdbfae38ece /src/mlia/target
parent4b4cf29cb1e7d917ae001e258ff01f7846c34778 (diff)
downloadmlia-dcd0bd31985c27e1d07333351b26cf8ad12ad1fd.tar.gz
MLIA-589 Create an API to get target information
Change-Id: Ieeaa9188ea1e29e2ccaad7475d457bce71e3140d
Diffstat (limited to 'src/mlia/target')
-rw-r--r--src/mlia/target/__init__.py6
-rw-r--r--src/mlia/target/config.py36
-rw-r--r--src/mlia/target/cortex_a/__init__.py4
-rw-r--r--src/mlia/target/cortex_a/operator_compatibility.py184
-rw-r--r--src/mlia/target/cortex_a/operators.py6
-rw-r--r--src/mlia/target/ethos_u/__init__.py5
-rw-r--r--src/mlia/target/registry.py34
-rw-r--r--src/mlia/target/tosa/__init__.py4
8 files changed, 92 insertions, 187 deletions
diff --git a/src/mlia/target/__init__.py b/src/mlia/target/__init__.py
index 2370221..a9979c6 100644
--- a/src/mlia/target/__init__.py
+++ b/src/mlia/target/__init__.py
@@ -1,3 +1,9 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Target module."""
+# Make sure all targets are registered with the registry by importing the
+# sub-modules
+# flake8: noqa
+from mlia.target import cortex_a
+from mlia.target import ethos_u
+from mlia.target import tosa
diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py
index 7ab6b43..f257784 100644
--- a/src/mlia/target/config.py
+++ b/src/mlia/target/config.py
@@ -1,6 +1,12 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""IP configuration module."""
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from mlia.backend.registry import registry as backend_registry
+from mlia.core.common import AdviceCategory
class IPConfiguration: # pylint: disable=too-few-public-methods
@@ -9,3 +15,33 @@ class IPConfiguration: # pylint: disable=too-few-public-methods
def __init__(self, target: str) -> None:
"""Init IP configuration instance."""
self.target = target
+
+
+@dataclass
+class TargetInfo:
+ """Collect information about supported targets."""
+
+ supported_backends: list[str]
+
+ def __str__(self) -> str:
+ """List supported backends."""
+ return ", ".join(sorted(self.supported_backends))
+
+ def is_supported(
+ self, advice: AdviceCategory | None = None, check_system: bool = False
+ ) -> bool:
+ """Check if any of the supported backends support this kind of advice."""
+ return any(
+ backend_registry.items[name].is_supported(advice, check_system)
+ for name in self.supported_backends
+ )
+
+ def filter_supported_backends(
+ self, advice: AdviceCategory | None = None, check_system: bool = False
+ ) -> list[str]:
+ """Get the list of supported backends filtered by the given arguments."""
+ return [
+ name
+ for name in self.supported_backends
+ if backend_registry.items[name].is_supported(advice, check_system)
+ ]
diff --git a/src/mlia/target/cortex_a/__init__.py b/src/mlia/target/cortex_a/__init__.py
index fe01835..9b0e611 100644
--- a/src/mlia/target/cortex_a/__init__.py
+++ b/src/mlia/target/cortex_a/__init__.py
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Cortex-A target module."""
+from mlia.target.registry import registry
+from mlia.target.registry import TargetInfo
+
+registry.register("Cortex-A", TargetInfo(["ArmNNTFLiteDelegate"]))
diff --git a/src/mlia/target/cortex_a/operator_compatibility.py b/src/mlia/target/cortex_a/operator_compatibility.py
deleted file mode 100644
index c474e75..0000000
--- a/src/mlia/target/cortex_a/operator_compatibility.py
+++ /dev/null
@@ -1,184 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Collection of Cortex-A operator compatibility information."""
-from __future__ import annotations
-
-from typing import Any
-
-ARMNN_TFLITE_DELEGATE: dict[str, dict[str, Any]] = {
- "metadata": {
- "backend": "Arm NN TensorFlow Lite delegate",
- "version": "22.08",
- },
- # BUILTIN OPERATORS
- "builtin_ops": {
- "ABS": {},
- "ADD": {},
- "ARG_MAX": {},
- "ARG_MIN": {},
- "AVERAGE_POOL_2D": {
- "supported_fused_activation": [
- "RELU",
- "RELU6",
- "RELU_N1_TO_1",
- "SIGMOID",
- "TANH",
- "NONE",
- ]
- },
- "BATCH_TO_SPACE_ND": {},
- "CAST": {},
- "CONCATENATION": {
- "supported_fused_activation": [
- "RELU",
- "RELU6",
- "RELU_N1_TO_1",
- "SIGMOID",
- "TANH",
- "NONE",
- ]
- },
- "CONV_2D": {
- "supported_fused_activation": [
- "RELU",
- "RELU6",
- "RELU_N1_TO_1",
- "SIGMOID",
- "TANH",
- "NONE",
- ]
- },
- "CONV_3D": {
- "supported_fused_activation": [
- "RELU",
- "RELU6",
- "RELU_N1_TO_1",
- "SIGMOID",
- "TANH",
- "NONE",
- ]
- },
- "DEPTH_TO_SPACE": {},
- "DEPTHWISE_CONV_2D": {
- "supported_fused_activation": [
- "RELU",
- "RELU6",
- "RELU_N1_TO_1",
- "SIGMOID",
- "TANH",
- "NONE",
- ]
- },
- "DEQUANTIZE": {},
- "DIV": {},
- "EQUAL": {},
- "ELU": {},
- "EXP": {},
- "EXPAND_DIMS": {},
- "FILL": {},
- "FLOOR": {},
- "FLOOR_DIV": {},
- "FULLY_CONNECTED": {
- "supported_fused_activation": [
- "RELU",
- "RELU6",
- "RELU_N1_TO_1",
- "SIGMOID",
- "TANH",
- "NONE",
- ]
- },
- "GATHER": {},
- "GATHER_ND": {},
- "GREATER": {},
- "GREATER_EQUAL": {},
- "HARD_SWISH": {},
- "L2_NORMALIZATION": {},
- "L2_POOL_2D": {},
- "LESS": {},
- "LESS_EQUAL": {},
- "LOCAL_RESPONSE_NORMALIZATION": {},
- "LOG": {},
- "LOGICAL_AND": {},
- "LOGICAL_NOT": {},
- "LOGICAL_OR": {},
- "LOGISTIC": {},
- "LOG_SOFTMAX": {},
- "LSTM": {},
- "MAXIMUM": {},
- "MAX_POOL_2D": {
- "supported_fused_activation": [
- "RELU",
- "RELU6",
- "RELU_N1_TO_1",
- "SIGMOID",
- "TANH",
- "NONE",
- ]
- },
- "MEAN": {},
- "MINIMUM": {},
- "MIRROR_PAD": {},
- "MUL": {},
- "NEG": {},
- "NOT_EQUAL": {},
- "PACK": {},
- "PAD": {},
- "PADV2": {},
- "PRELU": {},
- "QUANTIZE": {},
- "RANK": {},
- "REDUCE_MAX": {},
- "REDUCE_MIN": {},
- "REDUCE_PROD": {},
- "RELU": {},
- "RELU6": {},
- "RELU_N1_TO_1": {},
- "RESHAPE": {},
- "RESIZE_BILINEAR": {},
- "RESIZE_NEAREST_NEIGHBOR": {},
- "RSQRT": {},
- "SHAPE": {},
- "SIN": {},
- "SOFTMAX": {},
- "SPACE_TO_BATCH_ND": {},
- "SPACE_TO_DEPTH": {},
- "SPLIT": {},
- "SPLIT_V": {},
- "SQRT": {},
- "SQUEEZE": {},
- "STRIDED_SLICE": {},
- "SUB": {},
- "SUM": {},
- "TANH": {},
- "TRANSPOSE": {},
- "TRANSPOSE_CONV": {},
- "UNIDIRECTIONAL_SEQUENCE_LSTM": {},
- "UNPACK": {},
- },
- # CUSTOM OPERATORS
- "custom_ops": {
- "AveragePool3D": {
- "supported_fused_activation": [
- "RELU",
- "RELU6",
- "RELU_N1_TO_1",
- "SIGMOID",
- "SIGN_BIT",
- "TANH",
- "NONE",
- ]
- },
- "MaxPool3D": {
- "supported_fused_activation": [
- "RELU",
- "RELU6",
- "RELU_N1_TO_1",
- "SIGMOID",
- "SIGN_BIT",
- "TANH",
- "NONE",
- ]
- },
- },
-}
diff --git a/src/mlia/target/cortex_a/operators.py b/src/mlia/target/cortex_a/operators.py
index 91f1886..ae611e5 100644
--- a/src/mlia/target/cortex_a/operators.py
+++ b/src/mlia/target/cortex_a/operators.py
@@ -9,12 +9,12 @@ from pathlib import Path
from typing import Any
from typing import ClassVar
+from mlia.backend.armnn_tflite_delegate.compat import (
+ ARMNN_TFLITE_DELEGATE as TFLITE_DELEGATE_COMPAT,
+)
from mlia.nn.tensorflow.tflite_graph import Op
from mlia.nn.tensorflow.tflite_graph import parse_subgraphs
from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION
-from mlia.target.cortex_a.operator_compatibility import (
- ARMNN_TFLITE_DELEGATE as TFLITE_DELEGATE_COMPAT,
-)
@dataclass
diff --git a/src/mlia/target/ethos_u/__init__.py b/src/mlia/target/ethos_u/__init__.py
index 503919d..3c92ae5 100644
--- a/src/mlia/target/ethos_u/__init__.py
+++ b/src/mlia/target/ethos_u/__init__.py
@@ -1,3 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Ethos-U target module."""
+from mlia.target.registry import registry
+from mlia.target.registry import TargetInfo
+
+registry.register("Ethos-U55", TargetInfo(["Vela", "Corstone-300", "Corstone-310"]))
+registry.register("Ethos-U65", TargetInfo(["Vela", "Corstone-300", "Corstone-310"]))
diff --git a/src/mlia/target/registry.py b/src/mlia/target/registry.py
new file mode 100644
index 0000000..6b33084
--- /dev/null
+++ b/src/mlia/target/registry.py
@@ -0,0 +1,34 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Target module."""
+from __future__ import annotations
+
+from mlia.backend.registry import registry as backend_registry
+from mlia.core.common import AdviceCategory
+from mlia.target.config import TargetInfo
+from mlia.utils.registry import Registry
+
+# All supported targets are required to be registered here.
+registry = Registry[TargetInfo]()
+
+
+def supported_advice(target: str) -> list[AdviceCategory]:
+ """Get a list of supported advice for the given target."""
+ advice: set[AdviceCategory] = set()
+ for supported_backend in registry.items[target].supported_backends:
+ advice.update(backend_registry.items[supported_backend].supported_advice)
+ return list(advice)
+
+
+def supported_backends(target: str) -> list[str]:
+ """Get a list of backends supported by the given target."""
+ return registry.items[target].filter_supported_backends(check_system=False)
+
+
+def supported_targets(advice: AdviceCategory) -> list[str]:
+ """Get a list of all targets supporting the given advice category."""
+ return [
+ name
+ for name, info in registry.items.items()
+ if info.is_supported(advice, check_system=False)
+ ]
diff --git a/src/mlia/target/tosa/__init__.py b/src/mlia/target/tosa/__init__.py
index 762c831..33c9cf2 100644
--- a/src/mlia/target/tosa/__init__.py
+++ b/src/mlia/target/tosa/__init__.py
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""TOSA target module."""
+from mlia.target.registry import registry
+from mlia.target.registry import TargetInfo
+
+registry.register("TOSA", TargetInfo(["TOSA-Checker"]))