aboutsummaryrefslogtreecommitdiff
path: root/src/mlia
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2022-12-14 11:20:11 +0000
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-01-04 10:11:33 +0000
commitdcd0bd31985c27e1d07333351b26cf8ad12ad1fd (patch)
treea3388ff5f91e7cdc7ec41271a1a76cdbfae38ece /src/mlia
parent4b4cf29cb1e7d917ae001e258ff01f7846c34778 (diff)
downloadmlia-dcd0bd31985c27e1d07333351b26cf8ad12ad1fd.tar.gz
MLIA-589 Create an API to get target information
Change-Id: Ieeaa9188ea1e29e2ccaad7475d457bce71e3140d
Diffstat (limited to 'src/mlia')
-rw-r--r--src/mlia/backend/__init__.py9
-rw-r--r--src/mlia/backend/armnn_tflite_delegate/__init__.py16
-rw-r--r--src/mlia/backend/armnn_tflite_delegate/compat.py (renamed from src/mlia/target/cortex_a/operator_compatibility.py)0
-rw-r--r--src/mlia/backend/config.py82
-rw-r--r--src/mlia/backend/corstone/__init__.py22
-rw-r--r--src/mlia/backend/registry.py8
-rw-r--r--src/mlia/backend/tosa_checker/__init__.py14
-rw-r--r--src/mlia/backend/vela/__init__.py23
-rw-r--r--src/mlia/cli/main.py9
-rw-r--r--src/mlia/cli/options.py2
-rw-r--r--src/mlia/target/__init__.py6
-rw-r--r--src/mlia/target/config.py36
-rw-r--r--src/mlia/target/cortex_a/__init__.py4
-rw-r--r--src/mlia/target/cortex_a/operators.py6
-rw-r--r--src/mlia/target/ethos_u/__init__.py5
-rw-r--r--src/mlia/target/registry.py34
-rw-r--r--src/mlia/target/tosa/__init__.py4
-rw-r--r--src/mlia/utils/registry.py31
18 files changed, 302 insertions, 9 deletions
diff --git a/src/mlia/backend/__init__.py b/src/mlia/backend/__init__.py
index 745aa1b..2d1da70 100644
--- a/src/mlia/backend/__init__.py
+++ b/src/mlia/backend/__init__.py
@@ -1,3 +1,10 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
-"""Backends module."""
+"""Backend module."""
+# Make sure all targets are registered with the registry by importing the
+# sub-modules
+# flake8: noqa
+from mlia.backend import armnn_tflite_delegate
+from mlia.backend import corstone
+from mlia.backend import tosa_checker
+from mlia.backend import vela
diff --git a/src/mlia/backend/armnn_tflite_delegate/__init__.py b/src/mlia/backend/armnn_tflite_delegate/__init__.py
new file mode 100644
index 0000000..6d5af42
--- /dev/null
+++ b/src/mlia/backend/armnn_tflite_delegate/__init__.py
@@ -0,0 +1,16 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Arm NN TensorFlow Lite delegate backend module."""
+from mlia.backend.config import BackendConfiguration
+from mlia.backend.config import BackendType
+from mlia.backend.registry import registry
+from mlia.core.common import AdviceCategory
+
+registry.register(
+ "ArmNNTFLiteDelegate",
+ BackendConfiguration(
+ supported_advice=[AdviceCategory.OPERATORS],
+ supported_systems=None,
+ backend_type=BackendType.BUILTIN,
+ ),
+)
diff --git a/src/mlia/target/cortex_a/operator_compatibility.py b/src/mlia/backend/armnn_tflite_delegate/compat.py
index c474e75..c474e75 100644
--- a/src/mlia/target/cortex_a/operator_compatibility.py
+++ b/src/mlia/backend/armnn_tflite_delegate/compat.py
diff --git a/src/mlia/backend/config.py b/src/mlia/backend/config.py
new file mode 100644
index 0000000..8d14b28
--- /dev/null
+++ b/src/mlia/backend/config.py
@@ -0,0 +1,82 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Backend config module."""
+from __future__ import annotations
+
+import platform
+from enum import auto
+from enum import Enum
+from typing import cast
+
+from mlia.core.common import AdviceCategory
+
+
+class System(Enum):
+ """Enum of system configurations (e.g. OS and architecture)."""
+
+ LINUX_AMD64 = ("Linux", "x86_64")
+ LINUX_AARCH64 = ("Linux", "aarch64")
+ WINDOWS_AMD64 = ("Windows", "AMD64")
+ WINDOWS_AARCH64 = ("Windows", "ARM64")
+ DARWIN_AARCH64 = ("Darwin", "arm64")
+ CURRENT = (platform.system(), platform.machine())
+
+ def __init__(
+ self, system: str = platform.system(), machine: str = platform.machine()
+ ) -> None:
+ """Set the system parameters (defaults to the current system)."""
+ self.system = system.lower()
+ self.machine = machine.lower()
+
+ def __eq__(self, other: object) -> bool:
+ """
+ Compare two System instances for equality.
+
+ Raises a TypeError if the input is not a System.
+ """
+ if isinstance(other, self.__class__):
+ return self.system == other.system and self.machine == other.machine
+ return False
+
+ def is_compatible(self) -> bool:
+ """Check if this system is compatible with the current system."""
+ return self == self.CURRENT
+
+
+class BackendType(Enum):
+ """Define the type of the backend (builtin, wheel file etc)."""
+
+ BUILTIN = auto()
+ WHEEL = auto()
+ CUSTOM = auto()
+
+
+class BackendConfiguration:
+ """Base class for backend configurations."""
+
+ def __init__(
+ self,
+ supported_advice: list[AdviceCategory],
+ supported_systems: list[System] | None,
+ backend_type: BackendType,
+ ) -> None:
+ """Set up basic information about the backend."""
+ self.supported_advice = supported_advice
+ self.supported_systems = supported_systems
+ self.type = backend_type
+
+ def __str__(self) -> str:
+ """List supported advice."""
+ return ", ".join(cast(str, adv.name).lower() for adv in self.supported_advice)
+
+ def is_supported(
+ self, advice: AdviceCategory | None = None, check_system: bool = False
+ ) -> bool:
+ """Check backend supports the current system and advice."""
+ is_system_supported = (
+ not self.supported_systems
+ or not check_system
+ or any(sys.is_compatible() for sys in self.supported_systems)
+ )
+ is_advice_supported = advice is None or advice in self.supported_advice
+ return is_system_supported and is_advice_supported
diff --git a/src/mlia/backend/corstone/__init__.py b/src/mlia/backend/corstone/__init__.py
index a1eac14..f89da63 100644
--- a/src/mlia/backend/corstone/__init__.py
+++ b/src/mlia/backend/corstone/__init__.py
@@ -1,3 +1,25 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Corstone backend module."""
+from mlia.backend.config import BackendConfiguration
+from mlia.backend.config import BackendType
+from mlia.backend.config import System
+from mlia.backend.registry import registry
+from mlia.core.common import AdviceCategory
+
+registry.register(
+ "Corstone-300",
+ BackendConfiguration(
+ supported_advice=[AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION],
+ supported_systems=[System.LINUX_AMD64],
+ backend_type=BackendType.CUSTOM,
+ ),
+)
+registry.register(
+ "Corstone-310",
+ BackendConfiguration(
+ supported_advice=[AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION],
+ supported_systems=[System.LINUX_AMD64],
+ backend_type=BackendType.CUSTOM,
+ ),
+)
diff --git a/src/mlia/backend/registry.py b/src/mlia/backend/registry.py
new file mode 100644
index 0000000..6a0da74
--- /dev/null
+++ b/src/mlia/backend/registry.py
@@ -0,0 +1,8 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Backend module."""
+from mlia.backend.config import BackendConfiguration
+from mlia.utils.registry import Registry
+
+# All supported targets are required to be registered here.
+registry = Registry[BackendConfiguration]()
diff --git a/src/mlia/backend/tosa_checker/__init__.py b/src/mlia/backend/tosa_checker/__init__.py
index cec210d..19fc8be 100644
--- a/src/mlia/backend/tosa_checker/__init__.py
+++ b/src/mlia/backend/tosa_checker/__init__.py
@@ -1,3 +1,17 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""TOSA checker backend module."""
+from mlia.backend.config import BackendConfiguration
+from mlia.backend.config import BackendType
+from mlia.backend.config import System
+from mlia.backend.registry import registry
+from mlia.core.common import AdviceCategory
+
+registry.register(
+ "TOSA-Checker",
+ BackendConfiguration(
+ supported_advice=[AdviceCategory.OPERATORS],
+ supported_systems=[System.LINUX_AMD64],
+ backend_type=BackendType.WHEEL,
+ ),
+)
diff --git a/src/mlia/backend/vela/__init__.py b/src/mlia/backend/vela/__init__.py
index 6ea0c21..38a623e 100644
--- a/src/mlia/backend/vela/__init__.py
+++ b/src/mlia/backend/vela/__init__.py
@@ -1,3 +1,26 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Vela backend module."""
+from mlia.backend.config import BackendConfiguration
+from mlia.backend.config import BackendType
+from mlia.backend.config import System
+from mlia.backend.registry import registry
+from mlia.core.common import AdviceCategory
+
+registry.register(
+ "Vela",
+ BackendConfiguration(
+ supported_advice=[
+ AdviceCategory.OPERATORS,
+ AdviceCategory.PERFORMANCE,
+ AdviceCategory.OPTIMIZATION,
+ ],
+ supported_systems=[
+ System.LINUX_AMD64,
+ System.LINUX_AARCH64,
+ System.WINDOWS_AMD64,
+ System.WINDOWS_AARCH64,
+ ],
+ backend_type=BackendType.BUILTIN,
+ ),
+)
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py
index 98fdb63..ac60308 100644
--- a/src/mlia/cli/main.py
+++ b/src/mlia/cli/main.py
@@ -12,6 +12,7 @@ from pathlib import Path
from mlia import __version__
from mlia.backend.errors import BackendUnavailableError
+from mlia.backend.registry import registry as backend_registry
from mlia.cli.commands import all_tests
from mlia.cli.commands import backend_install
from mlia.cli.commands import backend_list
@@ -36,6 +37,7 @@ from mlia.cli.options import add_tflite_model_options
from mlia.core.context import ExecutionContext
from mlia.core.errors import ConfigurationError
from mlia.core.errors import InternalError
+from mlia.target.registry import registry as target_registry
logger = logging.getLogger(__name__)
@@ -46,11 +48,10 @@ ML Inference Advisor {__version__}
Help the design and optimization of neural network models for efficient inference on a target CPU and NPU
Supported targets:
+{target_registry}
- - Cortex-A <op compatibility>
- - Ethos-U55 <op compatibility, perf estimation, model opt>
- - Ethos-U65 <op compatibility, perf estimation, model opt>
- - TOSA <op compatibility>
+Supported backends:
+{backend_registry}
""".strip()
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py
index 5eab9aa..8ea4250 100644
--- a/src/mlia/cli/options.py
+++ b/src/mlia/cli/options.py
@@ -38,7 +38,7 @@ def add_target_options(
help="Target profile that will set the target options "
"such as target, mac value, memory mode, etc. "
f"For the values associated with each target profile "
- f" please refer to the documenation {default_help}.",
+ f" please refer to the documentation {default_help}.",
)
diff --git a/src/mlia/target/__init__.py b/src/mlia/target/__init__.py
index 2370221..a9979c6 100644
--- a/src/mlia/target/__init__.py
+++ b/src/mlia/target/__init__.py
@@ -1,3 +1,9 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Target module."""
+# Make sure all targets are registered with the registry by importing the
+# sub-modules
+# flake8: noqa
+from mlia.target import cortex_a
+from mlia.target import ethos_u
+from mlia.target import tosa
diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py
index 7ab6b43..f257784 100644
--- a/src/mlia/target/config.py
+++ b/src/mlia/target/config.py
@@ -1,6 +1,12 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""IP configuration module."""
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from mlia.backend.registry import registry as backend_registry
+from mlia.core.common import AdviceCategory
class IPConfiguration: # pylint: disable=too-few-public-methods
@@ -9,3 +15,33 @@ class IPConfiguration: # pylint: disable=too-few-public-methods
def __init__(self, target: str) -> None:
"""Init IP configuration instance."""
self.target = target
+
+
+@dataclass
+class TargetInfo:
+ """Collect information about supported targets."""
+
+ supported_backends: list[str]
+
+ def __str__(self) -> str:
+ """List supported backends."""
+ return ", ".join(sorted(self.supported_backends))
+
+ def is_supported(
+ self, advice: AdviceCategory | None = None, check_system: bool = False
+ ) -> bool:
+ """Check if any of the supported backends support this kind of advice."""
+ return any(
+ backend_registry.items[name].is_supported(advice, check_system)
+ for name in self.supported_backends
+ )
+
+ def filter_supported_backends(
+ self, advice: AdviceCategory | None = None, check_system: bool = False
+ ) -> list[str]:
+ """Get the list of supported backends filtered by the given arguments."""
+ return [
+ name
+ for name in self.supported_backends
+ if backend_registry.items[name].is_supported(advice, check_system)
+ ]
diff --git a/src/mlia/target/cortex_a/__init__.py b/src/mlia/target/cortex_a/__init__.py
index fe01835..9b0e611 100644
--- a/src/mlia/target/cortex_a/__init__.py
+++ b/src/mlia/target/cortex_a/__init__.py
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Cortex-A target module."""
+from mlia.target.registry import registry
+from mlia.target.registry import TargetInfo
+
+registry.register("Cortex-A", TargetInfo(["ArmNNTFLiteDelegate"]))
diff --git a/src/mlia/target/cortex_a/operators.py b/src/mlia/target/cortex_a/operators.py
index 91f1886..ae611e5 100644
--- a/src/mlia/target/cortex_a/operators.py
+++ b/src/mlia/target/cortex_a/operators.py
@@ -9,12 +9,12 @@ from pathlib import Path
from typing import Any
from typing import ClassVar
+from mlia.backend.armnn_tflite_delegate.compat import (
+ ARMNN_TFLITE_DELEGATE as TFLITE_DELEGATE_COMPAT,
+)
from mlia.nn.tensorflow.tflite_graph import Op
from mlia.nn.tensorflow.tflite_graph import parse_subgraphs
from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION
-from mlia.target.cortex_a.operator_compatibility import (
- ARMNN_TFLITE_DELEGATE as TFLITE_DELEGATE_COMPAT,
-)
@dataclass
diff --git a/src/mlia/target/ethos_u/__init__.py b/src/mlia/target/ethos_u/__init__.py
index 503919d..3c92ae5 100644
--- a/src/mlia/target/ethos_u/__init__.py
+++ b/src/mlia/target/ethos_u/__init__.py
@@ -1,3 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Ethos-U target module."""
+from mlia.target.registry import registry
+from mlia.target.registry import TargetInfo
+
+registry.register("Ethos-U55", TargetInfo(["Vela", "Corstone-300", "Corstone-310"]))
+registry.register("Ethos-U65", TargetInfo(["Vela", "Corstone-300", "Corstone-310"]))
diff --git a/src/mlia/target/registry.py b/src/mlia/target/registry.py
new file mode 100644
index 0000000..6b33084
--- /dev/null
+++ b/src/mlia/target/registry.py
@@ -0,0 +1,34 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Target module."""
+from __future__ import annotations
+
+from mlia.backend.registry import registry as backend_registry
+from mlia.core.common import AdviceCategory
+from mlia.target.config import TargetInfo
+from mlia.utils.registry import Registry
+
+# All supported targets are required to be registered here.
+registry = Registry[TargetInfo]()
+
+
+def supported_advice(target: str) -> list[AdviceCategory]:
+ """Get a list of supported advice for the given target."""
+ advice: set[AdviceCategory] = set()
+ for supported_backend in registry.items[target].supported_backends:
+ advice.update(backend_registry.items[supported_backend].supported_advice)
+ return list(advice)
+
+
+def supported_backends(target: str) -> list[str]:
+ """Get a list of backends supported by the given target."""
+ return registry.items[target].filter_supported_backends(check_system=False)
+
+
+def supported_targets(advice: AdviceCategory) -> list[str]:
+ """Get a list of all targets supporting the given advice category."""
+ return [
+ name
+ for name, info in registry.items.items()
+ if info.is_supported(advice, check_system=False)
+ ]
diff --git a/src/mlia/target/tosa/__init__.py b/src/mlia/target/tosa/__init__.py
index 762c831..33c9cf2 100644
--- a/src/mlia/target/tosa/__init__.py
+++ b/src/mlia/target/tosa/__init__.py
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""TOSA target module."""
+from mlia.target.registry import registry
+from mlia.target.registry import TargetInfo
+
+registry.register("TOSA", TargetInfo(["TOSA-Checker"]))
diff --git a/src/mlia/utils/registry.py b/src/mlia/utils/registry.py
new file mode 100644
index 0000000..9b25a81
--- /dev/null
+++ b/src/mlia/utils/registry.py
@@ -0,0 +1,31 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Generic registry class."""
+from __future__ import annotations
+
+from typing import Generic
+from typing import TypeVar
+
+T = TypeVar("T")
+
+
+class Registry(Generic[T]):
+ """Generic registry for name-config pairs."""
+
+ def __init__(self) -> None:
+ """Create an empty registry."""
+ self.items: dict[str, T] = {}
+
+ def __str__(self) -> str:
+ """List all registered items."""
+ return "\n".join(
+ f"- {name}: {item}"
+ for name, item in sorted(self.items.items(), key=lambda v: v[0])
+ )
+
+ def register(self, name: str, item: T) -> bool:
+ """Register an item: returns `False` if already registered."""
+ if name in self.items:
+ return False # already registered
+ self.items[name] = item
+ return True