aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2022-10-25 18:12:34 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2022-11-10 16:47:22 +0000
commite40a7adadd254e29d71af38f69a0a20ff4871eef (patch)
tree9a57ddf406846785683673565359d9bd6ba3cf0b
parent720839a2dc6d4d75cd7aa77f83fcd49bcf114ba6 (diff)
downloadmlia-e40a7adadd254e29d71af38f69a0a20ff4871eef.tar.gz
MLIA-411 Report Cortex-A operator compatibility
Check input model for Arm NN TensorFlow Lite Delegate 22.08 support. Change-Id: I1253c4c0b294c5283e08f0a39561b922ef0f62e6
-rw-r--r--README.md39
-rw-r--r--src/mlia/cli/main.py3
-rw-r--r--src/mlia/devices/cortexa/advice_generation.py50
-rw-r--r--src/mlia/devices/cortexa/data_analysis.py57
-rw-r--r--src/mlia/devices/cortexa/operator_compatibility.py184
-rw-r--r--src/mlia/devices/cortexa/operators.py121
-rw-r--r--src/mlia/devices/cortexa/reporters.py12
-rw-r--r--src/mlia/nn/tensorflow/tflite_graph.py139
-rw-r--r--tests/test_cli_main.py15
-rw-r--r--tests/test_devices_cortexa_advice_generation.py76
-rw-r--r--tests/test_devices_cortexa_data_analysis.py81
-rw-r--r--tests/test_devices_cortexa_data_collection.py30
-rw-r--r--tests/test_devices_cortexa_operators.py73
-rw-r--r--tests/test_devices_cortexa_reporters.py52
-rw-r--r--tests/test_nn_tensorflow_tflite_graph.py81
15 files changed, 963 insertions, 50 deletions
diff --git a/README.md b/README.md
index ec4080d..c1c9ce6 100644
--- a/README.md
+++ b/README.md
@@ -14,12 +14,11 @@ If you find something that concerns you, email terms@arm.com.
## Introduction
This tool is used to help AI developers design and optimize neural network
-models for efficient inference on Arm® targets (e.g. Cortex®-M55 and
-Ethos™-U55/Ethos™-U65, Cortex®-M85 and Ethos™-U55) by enabling performance analysis
-and providing actionable advice early in the model development cycle. The final
-advice can cover the operator list, performance analysis and suggestions for
-model inference run on certain hardware before/after applying model optimization
-(e.g. pruning, clustering, etc.).
+models for efficient inference on Arm® targets (e.g. Cortex®-A or
+Ethos™-U55/Ethos™-U65 with Cortex®-M55/Cortex®-M85) by enabling performance
+analysis and providing actionable advice early in the model development cycle.
+The final advice can cover supported operators, performance analysis and
+suggestions for model optimization (e.g. pruning, clustering, etc.).
## Prerequisites and dependencies
@@ -84,17 +83,21 @@ Not all backends work on any platform. Please refer to the compatibility table
below:
```
-+---------------------------------------------------------------------------+
-| Backend | Linux | Windows | Python |
-+============================================================================
-| Corstone-300 | x86_64 | Not compatible | Python>=3.8 |
-+----------------------------------------------------------------------------
-| Corstone-310 | x86_64 | Not compatible | Python>=3.8 |
-+----------------------------------------------------------------------------
-| TOSA checker | x86_64 (manylinux2014) | Not compatible | 3.7<=Python<=3.9 |
-+----------------------------------------------------------------------------
-| Vela | x86_64 | Windows 10 | Python~=3.7 |
-+---------------------------------------------------------------------------+
++----------------------------------------------------------------------------+
+| Backend | Linux | Windows | Python |
++=============================================================================
+| Arm NN | | | |
+| TensorFlow | x86_64 | Windows 10 | Python>=3.8 |
+| Lite delegate | | | |
++-----------------------------------------------------------------------------
+| Corstone-300 | x86_64 | Not compatible | Python>=3.8 |
++-----------------------------------------------------------------------------
+| Corstone-310 | x86_64 | Not compatible | Python>=3.8 |
++-----------------------------------------------------------------------------
+| TOSA checker | x86_64 (manylinux2014) | Not compatible | 3.7<=Python<=3.9 |
++-----------------------------------------------------------------------------
+| Vela | x86_64 | Windows 10 | Python~=3.7 |
++----------------------------------------------------------------------------+
```
### Using Corstone™-300
@@ -207,6 +210,7 @@ mlia operators --target-profile ethos-u55-256 ~/models/mobilenet_v1_1.0_224_quan
target, MAC value, memory mode, etc ...
* default: ethos-u55-256
* options:
+ * cortex-a
* ethos-u55-256
* ethos-u55-128
* ethos-u65-512
@@ -378,6 +382,7 @@ mlia all_tests --output ./report.json ~/models/ds_cnn_l.h5
target, MAC value, memory mode, etc ...
* default: ethos-u55-256
* options:
+ * cortex-a
* ethos-u55-256
* ethos-u55-128
* ethos-u65-512
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py
index bafe434..d36d2d9 100644
--- a/src/mlia/cli/main.py
+++ b/src/mlia/cli/main.py
@@ -37,10 +37,11 @@ logger = logging.getLogger(__name__)
INFO_MESSAGE = f"""
ML Inference Advisor {__version__}
-Help the design and optimization of neural network models for efficient inference on a target CPU, GPU and NPU
+Help the design and optimization of neural network models for efficient inference on a target CPU and NPU
Supported targets:
+ - Cortex-A <op compatibility>
- Ethos-U55 <op compatibility, perf estimation, model opt>
- Ethos-U65 <op compatibility, perf estimation, model opt>
- TOSA <op compatibility>
diff --git a/src/mlia/devices/cortexa/advice_generation.py b/src/mlia/devices/cortexa/advice_generation.py
index 0f3553f..34c51f8 100644
--- a/src/mlia/devices/cortexa/advice_generation.py
+++ b/src/mlia/devices/cortexa/advice_generation.py
@@ -15,6 +15,13 @@ from mlia.devices.cortexa.data_analysis import ModelIsNotTFLiteCompatible
class CortexAAdviceProducer(FactBasedAdviceProducer):
"""Cortex-A advice producer."""
+ cortex_a_disclaimer = (
+ "Note that the provided compatibility information is general. "
+ "At runtime individual operators in the given model might fall back to "
+ "the TensorFlow Lite reference or might produce errors based on the "
+ "specific parameters."
+ )
+
@singledispatchmethod
def produce_advice(self, _data_item: DataItem) -> None:
"""Produce advice."""
@@ -22,21 +29,54 @@ class CortexAAdviceProducer(FactBasedAdviceProducer):
@produce_advice.register
@advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
def handle_model_is_cortex_a_compatible(
- self, _data_item: ModelIsCortexACompatible
+ self, data_item: ModelIsCortexACompatible
) -> None:
"""Advice for Cortex-A compatibility."""
- self.add_advice(["Model is fully compatible with Cortex-A."])
+ self.add_advice(
+ [
+ f"Model is fully compatible with {data_item.backend_info} for "
+ "Cortex-A.",
+ self.cortex_a_disclaimer,
+ ]
+ )
@produce_advice.register
@advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
def handle_model_is_not_cortex_a_compatible(
- self, _data_item: ModelIsNotCortexACompatible
+ self, data_item: ModelIsNotCortexACompatible
) -> None:
"""Advice for Cortex-A compatibility."""
+ if data_item.unsupported_ops:
+ self.add_advice(
+ [
+ "The following operators are not supported by "
+ f"{data_item.backend_info} and will fall back to the "
+ "TensorFlow Lite runtime:",
+ "\n".join(f" - {op}" for op in data_item.unsupported_ops),
+ ]
+ )
+
+ if data_item.activation_func_support:
+ self.add_advice(
+ [
+ "The fused activation functions of the following operators "
+ f"are not supported by {data_item.backend_info}. Please "
+ "consider using one of the supported activation functions "
+ "instead:",
+ "\n".join(
+ f" - {op}\n"
+ f" - Used unsupported: {act.used_unsupported}\n"
+ f" - Supported: {act.supported}"
+ for op, act in data_item.activation_func_support.items()
+ ),
+ ]
+ )
+
self.add_advice(
[
- "Some operators in the model are not compatible with Cortex-A. "
- "Please, refer to the operators table for more information."
+ "Please, refer to the full table of operators above for more "
+ "information.",
+ self.cortex_a_disclaimer,
]
)
diff --git a/src/mlia/devices/cortexa/data_analysis.py b/src/mlia/devices/cortexa/data_analysis.py
index d2b6f35..9f6d82b 100644
--- a/src/mlia/devices/cortexa/data_analysis.py
+++ b/src/mlia/devices/cortexa/data_analysis.py
@@ -3,13 +3,16 @@
"""Cortex-A data analysis module."""
from __future__ import annotations
+from collections import defaultdict
from dataclasses import dataclass
+from dataclasses import field
from functools import singledispatchmethod
from mlia.core.common import DataItem
from mlia.core.data_analysis import Fact
from mlia.core.data_analysis import FactExtractor
from mlia.devices.cortexa.operators import CortexACompatibilityInfo
+from mlia.devices.cortexa.operators import Operator
from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
from mlia.nn.tensorflow.tflite_compat import TFLiteConversionErrorCode
@@ -27,9 +30,38 @@ class CortexADataAnalyzer(FactExtractor):
) -> None:
"""Analyse operator compatibility information."""
if data_item.cortex_a_compatible:
- self.add_fact(ModelIsCortexACompatible())
+ self.add_fact(ModelIsCortexACompatible(data_item.backend_info))
else:
- self.add_fact(ModelIsNotCortexACompatible())
+ unsupported_ops = set()
+ activation_func_support: defaultdict[
+ str, ModelIsNotCortexACompatible.ActivationFunctionSupport
+ ] = defaultdict(ModelIsNotCortexACompatible.ActivationFunctionSupport)
+ for oper in data_item.operators:
+ if oper.support_type == Operator.SupportType.OP_NOT_SUPPORTED:
+ unsupported_ops.add(oper.full_name)
+
+ if oper.support_type == Operator.SupportType.ACTIVATION_NOT_SUPPORTED:
+ # Add used but unsupported actication functions
+ activation_func_support[oper.full_name].used_unsupported.add(
+ oper.activation_func.name
+ )
+ # Add supported activation functions
+ activation_func_support[oper.full_name].supported.update(
+ oper.supported_activation_functions
+ )
+
+ assert (
+ unsupported_ops or activation_func_support or not data_item.operators
+ ), (
+ "The model is marked as not compatible with Cortex-A but there "
+ "are no unsupported ops activation functions listed."
+ )
+
+ self.add_fact(
+ ModelIsNotCortexACompatible(
+ data_item.backend_info, unsupported_ops, activation_func_support
+ )
+ )
@analyze_data.register
def analyze_tflite_compatibility(self, data_item: TFLiteCompatibilityInfo) -> None:
@@ -52,14 +84,31 @@ class CortexADataAnalyzer(FactExtractor):
@dataclass
-class ModelIsCortexACompatible(Fact):
+class CortexACompatibility(Fact):
+ """Base class for Cortex-A compatibility providing backend info."""
+
+ backend_info: str
+
+
+@dataclass
+class ModelIsCortexACompatible(CortexACompatibility):
"""Model is completely compatible with Cortex-A."""
@dataclass
-class ModelIsNotCortexACompatible(Fact):
+class ModelIsNotCortexACompatible(CortexACompatibility):
"""Model is not compatible with Cortex-A."""
+ @dataclass
+ class ActivationFunctionSupport:
+ """Activation function support per operator."""
+
+ used_unsupported: set[str] = field(default_factory=set)
+ supported: set[str] = field(default_factory=set)
+
+ unsupported_ops: set[str]
+ activation_func_support: dict[str, ActivationFunctionSupport]
+
@dataclass
class ModelIsNotTFLiteCompatible(Fact):
diff --git a/src/mlia/devices/cortexa/operator_compatibility.py b/src/mlia/devices/cortexa/operator_compatibility.py
new file mode 100644
index 0000000..c474e75
--- /dev/null
+++ b/src/mlia/devices/cortexa/operator_compatibility.py
@@ -0,0 +1,184 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Collection of Cortex-A operator compatibility information."""
+from __future__ import annotations
+
+from typing import Any
+
+ARMNN_TFLITE_DELEGATE: dict[str, dict[str, Any]] = {
+ "metadata": {
+ "backend": "Arm NN TensorFlow Lite delegate",
+ "version": "22.08",
+ },
+ # BUILTIN OPERATORS
+ "builtin_ops": {
+ "ABS": {},
+ "ADD": {},
+ "ARG_MAX": {},
+ "ARG_MIN": {},
+ "AVERAGE_POOL_2D": {
+ "supported_fused_activation": [
+ "RELU",
+ "RELU6",
+ "RELU_N1_TO_1",
+ "SIGMOID",
+ "TANH",
+ "NONE",
+ ]
+ },
+ "BATCH_TO_SPACE_ND": {},
+ "CAST": {},
+ "CONCATENATION": {
+ "supported_fused_activation": [
+ "RELU",
+ "RELU6",
+ "RELU_N1_TO_1",
+ "SIGMOID",
+ "TANH",
+ "NONE",
+ ]
+ },
+ "CONV_2D": {
+ "supported_fused_activation": [
+ "RELU",
+ "RELU6",
+ "RELU_N1_TO_1",
+ "SIGMOID",
+ "TANH",
+ "NONE",
+ ]
+ },
+ "CONV_3D": {
+ "supported_fused_activation": [
+ "RELU",
+ "RELU6",
+ "RELU_N1_TO_1",
+ "SIGMOID",
+ "TANH",
+ "NONE",
+ ]
+ },
+ "DEPTH_TO_SPACE": {},
+ "DEPTHWISE_CONV_2D": {
+ "supported_fused_activation": [
+ "RELU",
+ "RELU6",
+ "RELU_N1_TO_1",
+ "SIGMOID",
+ "TANH",
+ "NONE",
+ ]
+ },
+ "DEQUANTIZE": {},
+ "DIV": {},
+ "EQUAL": {},
+ "ELU": {},
+ "EXP": {},
+ "EXPAND_DIMS": {},
+ "FILL": {},
+ "FLOOR": {},
+ "FLOOR_DIV": {},
+ "FULLY_CONNECTED": {
+ "supported_fused_activation": [
+ "RELU",
+ "RELU6",
+ "RELU_N1_TO_1",
+ "SIGMOID",
+ "TANH",
+ "NONE",
+ ]
+ },
+ "GATHER": {},
+ "GATHER_ND": {},
+ "GREATER": {},
+ "GREATER_EQUAL": {},
+ "HARD_SWISH": {},
+ "L2_NORMALIZATION": {},
+ "L2_POOL_2D": {},
+ "LESS": {},
+ "LESS_EQUAL": {},
+ "LOCAL_RESPONSE_NORMALIZATION": {},
+ "LOG": {},
+ "LOGICAL_AND": {},
+ "LOGICAL_NOT": {},
+ "LOGICAL_OR": {},
+ "LOGISTIC": {},
+ "LOG_SOFTMAX": {},
+ "LSTM": {},
+ "MAXIMUM": {},
+ "MAX_POOL_2D": {
+ "supported_fused_activation": [
+ "RELU",
+ "RELU6",
+ "RELU_N1_TO_1",
+ "SIGMOID",
+ "TANH",
+ "NONE",
+ ]
+ },
+ "MEAN": {},
+ "MINIMUM": {},
+ "MIRROR_PAD": {},
+ "MUL": {},
+ "NEG": {},
+ "NOT_EQUAL": {},
+ "PACK": {},
+ "PAD": {},
+ "PADV2": {},
+ "PRELU": {},
+ "QUANTIZE": {},
+ "RANK": {},
+ "REDUCE_MAX": {},
+ "REDUCE_MIN": {},
+ "REDUCE_PROD": {},
+ "RELU": {},
+ "RELU6": {},
+ "RELU_N1_TO_1": {},
+ "RESHAPE": {},
+ "RESIZE_BILINEAR": {},
+ "RESIZE_NEAREST_NEIGHBOR": {},
+ "RSQRT": {},
+ "SHAPE": {},
+ "SIN": {},
+ "SOFTMAX": {},
+ "SPACE_TO_BATCH_ND": {},
+ "SPACE_TO_DEPTH": {},
+ "SPLIT": {},
+ "SPLIT_V": {},
+ "SQRT": {},
+ "SQUEEZE": {},
+ "STRIDED_SLICE": {},
+ "SUB": {},
+ "SUM": {},
+ "TANH": {},
+ "TRANSPOSE": {},
+ "TRANSPOSE_CONV": {},
+ "UNIDIRECTIONAL_SEQUENCE_LSTM": {},
+ "UNPACK": {},
+ },
+ # CUSTOM OPERATORS
+ "custom_ops": {
+ "AveragePool3D": {
+ "supported_fused_activation": [
+ "RELU",
+ "RELU6",
+ "RELU_N1_TO_1",
+ "SIGMOID",
+ "SIGN_BIT",
+ "TANH",
+ "NONE",
+ ]
+ },
+ "MaxPool3D": {
+ "supported_fused_activation": [
+ "RELU",
+ "RELU6",
+ "RELU_N1_TO_1",
+ "SIGMOID",
+ "SIGN_BIT",
+ "TANH",
+ "NONE",
+ ]
+ },
+ },
+}
diff --git a/src/mlia/devices/cortexa/operators.py b/src/mlia/devices/cortexa/operators.py
index d46b107..3e84d64 100644
--- a/src/mlia/devices/cortexa/operators.py
+++ b/src/mlia/devices/cortexa/operators.py
@@ -4,16 +4,113 @@
from __future__ import annotations
from dataclasses import dataclass
+from enum import Enum
from pathlib import Path
+from typing import Any
+from typing import ClassVar
+
+from mlia.devices.cortexa.operator_compatibility import (
+ ARMNN_TFLITE_DELEGATE as TFLITE_DELEGATE_COMPAT,
+)
+from mlia.nn.tensorflow.tflite_graph import Op
+from mlia.nn.tensorflow.tflite_graph import parse_subgraphs
+from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION
@dataclass
class Operator:
"""Cortex-A compatibility information of the operator."""
+ BUILTIN_COMPATIBILITY = TFLITE_DELEGATE_COMPAT["builtin_ops"]
+ CUSTOM_COMPATIBILITY = TFLITE_DELEGATE_COMPAT["custom_ops"]
+
+ class SupportType(Enum):
+ """Type of operator support."""
+
+ COMPATIBLE = "Compatible"
+ OP_NOT_SUPPORTED = "Operator not supported"
+ ACTIVATION_NOT_SUPPORTED = "Activation not supported"
+
name: str
location: str
- is_cortex_a_compatible: bool
+ support_type: SupportType
+ activation_func: TFL_ACTIVATION_FUNCTION
+ custom_name: str | None = None
+
+ @property
+ def is_cortex_a_compatible(self) -> bool:
+ """Check if this operator is compatible."""
+ return self.support_type == Operator.SupportType.COMPATIBLE
+
+ @property
+ def full_name(self) -> str:
+ """Returun the full name including the custom name if applicable."""
+ return self.name + (f" - '{self.custom_name}'" if self.custom_name else "")
+
+ @property
+ def is_custom(self) -> bool:
+ """Check if this is a custom operator."""
+ return bool(self.custom_name)
+
+ @property
+ def compatibility_data(self) -> dict[str, dict[str, Any]]:
+ """Get the compatibility data (builtin or custom ops)."""
+ return (
+ Operator.CUSTOM_COMPATIBILITY
+ if self.is_custom
+ else Operator.BUILTIN_COMPATIBILITY
+ )
+
+ @property
+ def supported_activation_functions(self) -> list[str]:
+ """Return a list of fused activation functions supported by this op."""
+ op_name = self.custom_name if self.custom_name else self.name
+ return self.compatibility_data[op_name].get("supported_fused_activation", [])
+
+ @classmethod
+ def from_tflite_op(cls, tfl_op: Op, location: str) -> Operator:
+ """Create a new instance from TensorFlow Lite operator and location."""
+ support_type = cls._get_support_type(tfl_op)
+ activation_func = (
+ tfl_op.builtin_options["fused_activation_function"]
+ if (
+ tfl_op.builtin_options
+ and "fused_activation_function" in tfl_op.builtin_options
+ )
+ else TFL_ACTIVATION_FUNCTION.NONE
+ )
+ return Operator(
+ tfl_op.type,
+ location,
+ support_type,
+ activation_func=activation_func,
+ custom_name=(tfl_op.custom_type if tfl_op.is_custom else None),
+ )
+
+ @staticmethod
+ def _get_support_type(tfl_op: Op) -> Operator.SupportType:
+ """Get the support type from the TensorFlow Lite operator."""
+ compat_data = (
+ Operator.CUSTOM_COMPATIBILITY
+ if tfl_op.is_custom
+ else Operator.BUILTIN_COMPATIBILITY
+ )
+ op_type = tfl_op.custom_type if tfl_op.is_custom else tfl_op.type
+
+ if op_type not in compat_data:
+ return Operator.SupportType.OP_NOT_SUPPORTED
+
+ compat_op = compat_data[op_type]
+ if "supported_fused_activation" in compat_op:
+ assert tfl_op.builtin_options
+ assert "fused_activation_function" in tfl_op.builtin_options
+ if (
+ tfl_op.builtin_options["fused_activation_function"]
+ not in compat_op["supported_fused_activation"]
+ ):
+ return Operator.SupportType.ACTIVATION_NOT_SUPPORTED
+
+ return Operator.SupportType.COMPATIBLE
@dataclass
@@ -21,14 +118,26 @@ class CortexACompatibilityInfo:
"""Model's operators."""
cortex_a_compatible: bool
- operators: list[Operator] | None = None
+ operators: list[Operator]
+ backend_info: ClassVar[str] = (
+ f"{TFLITE_DELEGATE_COMPAT['metadata']['backend']} "
+ f"{TFLITE_DELEGATE_COMPAT['metadata']['version']}"
+ )
-def get_cortex_a_compatibility_info(
- _model_path: Path,
-) -> CortexACompatibilityInfo | None:
+def get_cortex_a_compatibility_info(model_path: Path) -> CortexACompatibilityInfo:
"""Return list of model's operators."""
- return None
+ model = parse_subgraphs(model_path)
+
+ op_list = [
+ Operator.from_tflite_op(oper, f"subgraph:{g_idx},oper:{op_idx}")
+ for g_idx, g in enumerate(model)
+ for op_idx, oper in enumerate(g)
+ ]
+ all_compatible = all(oper.is_cortex_a_compatible for oper in op_list)
+ compat_info = CortexACompatibilityInfo(all_compatible, op_list)
+
+ return compat_info
def report() -> None:
diff --git a/src/mlia/devices/cortexa/reporters.py b/src/mlia/devices/cortexa/reporters.py
index a55caba..84de10b 100644
--- a/src/mlia/devices/cortexa/reporters.py
+++ b/src/mlia/devices/cortexa/reporters.py
@@ -5,6 +5,7 @@ from __future__ import annotations
from typing import Any
from typing import Callable
+from typing import cast
from mlia.core.advice_generation import Advice
from mlia.core.reporters import report_advice
@@ -96,21 +97,22 @@ def report_cortex_a_operators(ops: list[Operator]) -> Report:
),
Column("Operator name", alias="operator_name", fmt=Format(wrap_width=20)),
Column(
- "Cortex-A compatibility",
+ "Arm NN TFLite Delegate compatibility",
alias="cortex_a_compatible",
- fmt=Format(wrap_width=25),
+ fmt=Format(wrap_width=40),
),
],
[
(
index + 1,
op.location,
- op.name,
+ op.full_name,
Cell(
- op.is_cortex_a_compatible,
+ op.support_type,
Format(
+ wrap_width=30,
style=style_improvement(op.is_cortex_a_compatible),
- str_fmt=lambda v: "Compatible" if v else "Not compatible",
+ str_fmt=lambda v: cast(str, v.value),
),
),
)
diff --git a/src/mlia/nn/tensorflow/tflite_graph.py b/src/mlia/nn/tensorflow/tflite_graph.py
new file mode 100644
index 0000000..4f5e85f
--- /dev/null
+++ b/src/mlia/nn/tensorflow/tflite_graph.py
@@ -0,0 +1,139 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Utilities for TensorFlow Lite graphs."""
+from __future__ import annotations
+
+import enum
+import json
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+from typing import cast
+
+from tensorflow.lite.python import schema_py_generated as schema_fb
+from tensorflow.lite.tools import visualize
+
+
+def _enum_from_class(cls: Any) -> Any:
+ """Create an enum from the public class variables."""
+ return enum.Enum(
+ cls.__name__,
+ {key: value for key, value in vars(cls).items() if not key.startswith("_")},
+ )
+
+
+TFL_TYPE = _enum_from_class(schema_fb.TensorType)
+TFL_OP = _enum_from_class(schema_fb.BuiltinOperator)
+TFL_ACTIVATION_FUNCTION = _enum_from_class(schema_fb.ActivationFunctionType)
+
+
+def _ascii_list_to_string(ascii_list: list[int]) -> str:
+ return "".join(chr(i) for i in ascii_list)
+
+
+@dataclass
+class TensorInfo:
+ """Collection of tensor information parsed from a TensorFlow Lite file."""
+
+ name: str
+ type: str
+ shape: tuple | list
+ is_variable: bool
+
+ def __str__(self) -> str:
+ """Create a text represenation of this TensorInfo instance."""
+ return f"{self.name}: {self.type}, {self.shape}, is_variable={self.is_variable}"
+
+ def __repr__(self) -> str:
+ """Convert this instance to JSON."""
+ return json.dumps(vars(self))
+
+ @classmethod
+ def from_dict(cls, tensor: dict[str, Any]) -> TensorInfo:
+ """
+ Create a new instance from a dictionary.
+
+ The expected dict is the one contained in the dict returned by
+ visualize.CreateDictFromFlatbuffer().
+ """
+ return TensorInfo(
+ _ascii_list_to_string(tensor["name"]),
+ TFL_TYPE(tensor["type"]).name,
+ tensor["shape"],
+ tensor["is_variable"],
+ )
+
+
+@dataclass
+class Op:
+ """
+ Representation of an operator from a TensorFlow Lite file.
+
+ E.g. collects the operator type, input/output tensors etc.
+ """
+
+ type: str
+ builtin_options: dict
+ inputs: list[TensorInfo]
+ outputs: list[TensorInfo]
+ custom_type: str | None = None
+
+ def __post_init__(self) -> None:
+ """Convert the builtin option 'fused_activation_function' to string."""
+ if "fused_activation_function" in self.builtin_options:
+ # Convert the fused activation function ID to a string
+ self.builtin_options["fused_activation_function"] = TFL_ACTIVATION_FUNCTION(
+ self.builtin_options["fused_activation_function"]
+ ).name
+
+ def __str__(self) -> str:
+ """Create a text represenation of this Op instance."""
+ return f"""{self.type}
+ builtin_options: {self.builtin_options}
+ inputs: {self.inputs}
+ outputs: {self.outputs}"""
+
+ @property
+ def is_custom(self) -> bool:
+ """Check if this Op is a custom operator."""
+ return self.type == cast(str, TFL_OP.CUSTOM.name)
+
+ @classmethod
+ def from_model_info(cls, oper: dict, graph: dict, model: dict) -> Op:
+ """Create a new Op from the model information."""
+ op_code_idx = oper["opcode_index"]
+ op_code_obj = model["operator_codes"][op_code_idx]
+ op_code = max(
+ op_code_obj["builtin_code"], op_code_obj["deprecated_builtin_code"]
+ )
+ custom_code = op_code_obj.get("custom_code")
+ return cls(
+ type=cast(str, TFL_OP(op_code).name),
+ builtin_options=oper["builtin_options"] if oper["builtin_options"] else {},
+ inputs=[
+ TensorInfo.from_dict(graph["tensors"][idx]) for idx in oper["inputs"]
+ ],
+ outputs=[
+ TensorInfo.from_dict(graph["tensors"][idx]) for idx in oper["outputs"]
+ ],
+ custom_type=_ascii_list_to_string(custom_code) if custom_code else None,
+ )
+
+
+def load_tflite(file: Path) -> bytes:
+ """Load a TensorFlow Lite file from disk."""
+ return file.read_bytes()
+
+
+def parse_subgraphs(tflite_file: Path) -> list[list[Op]]:
+ """Load the TensorFlow Lite file and parse the subgraphs."""
+ tflite_model = load_tflite(tflite_file)
+ model = cast(dict, visualize.CreateDictFromFlatbuffer(tflite_model))
+ assert isinstance(model, dict)
+
+ graphs = [
+ [Op.from_model_info(oper, g, model) for oper in g["operators"]]
+ for g in model["subgraphs"]
+ ]
+
+ return graphs
diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py
index 78adc53..4b16ac5 100644
--- a/tests/test_cli_main.py
+++ b/tests/test_cli_main.py
@@ -250,6 +250,21 @@ def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Non
evaluate_on=["some_backend"],
),
],
+ [
+ [
+ "operators",
+ "sample_model.h5",
+ "--target-profile",
+ "cortex-a",
+ ],
+ call(
+ ctx=ANY,
+ target_profile="cortex-a",
+ model="sample_model.h5",
+ output=None,
+ supported_ops_report=False,
+ ),
+ ],
],
)
def test_commands_execution(
diff --git a/tests/test_devices_cortexa_advice_generation.py b/tests/test_devices_cortexa_advice_generation.py
index ead8ae6..0446f38 100644
--- a/tests/test_devices_cortexa_advice_generation.py
+++ b/tests/test_devices_cortexa_advice_generation.py
@@ -13,27 +13,91 @@ from mlia.devices.cortexa.advice_generation import CortexAAdviceProducer
from mlia.devices.cortexa.data_analysis import ModelIsCortexACompatible
from mlia.devices.cortexa.data_analysis import ModelIsNotCortexACompatible
from mlia.devices.cortexa.data_analysis import ModelIsNotTFLiteCompatible
+from mlia.devices.cortexa.operator_compatibility import ARMNN_TFLITE_DELEGATE
+from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION
+
+BACKEND_INFO = (
+ f"{ARMNN_TFLITE_DELEGATE['metadata']['backend']} "
+ f"{ARMNN_TFLITE_DELEGATE['metadata']['version']}"
+)
@pytest.mark.parametrize(
"input_data, advice_category, expected_advice",
[
[
- ModelIsNotCortexACompatible(),
+ ModelIsNotCortexACompatible(BACKEND_INFO, {"UNSUPPORTED_OP"}, {}),
AdviceCategory.OPERATORS,
[
Advice(
[
- "Some operators in the model are not compatible with Cortex-A. "
- "Please, refer to the operators table for more information."
+ "The following operators are not supported by "
+ f"{BACKEND_INFO} and will fall back to the TensorFlow "
+ "Lite runtime:",
+ " - UNSUPPORTED_OP",
]
- )
+ ),
+ Advice(
+ [
+ "Please, refer to the full table of operators above "
+ "for more information.",
+ CortexAAdviceProducer.cortex_a_disclaimer,
+ ]
+ ),
+ ],
+ ],
+ [
+ ModelIsNotCortexACompatible(
+ BACKEND_INFO,
+ {"UNSUPPORTED_OP"},
+ {
+ "CONV_2D": ModelIsNotCortexACompatible.ActivationFunctionSupport(
+ used_unsupported={TFL_ACTIVATION_FUNCTION.SIGN_BIT.name},
+ supported={"RELU"},
+ )
+ },
+ ),
+ AdviceCategory.OPERATORS,
+ [
+ Advice(
+ [
+ "The following operators are not supported by "
+ f"{BACKEND_INFO} and will fall back to the TensorFlow "
+ "Lite runtime:",
+ " - UNSUPPORTED_OP",
+ ]
+ ),
+ Advice(
+ [
+ "The fused activation functions of the following "
+ f"operators are not supported by {BACKEND_INFO}. "
+ "Please consider using one of the supported activation "
+ "functions instead:",
+ " - CONV_2D\n"
+ " - Used unsupported: {'SIGN_BIT'}\n"
+ " - Supported: {'RELU'}",
+ ]
+ ),
+ Advice(
+ [
+ "Please, refer to the full table of operators above "
+ "for more information.",
+ CortexAAdviceProducer.cortex_a_disclaimer,
+ ]
+ ),
],
],
[
- ModelIsCortexACompatible(),
+ ModelIsCortexACompatible(BACKEND_INFO),
AdviceCategory.OPERATORS,
- [Advice(["Model is fully compatible with Cortex-A."])],
+ [
+ Advice(
+ [
+ f"Model is fully compatible with {BACKEND_INFO} for Cortex-A.",
+ CortexAAdviceProducer.cortex_a_disclaimer,
+ ]
+ )
+ ],
],
[
ModelIsNotTFLiteCompatible(
diff --git a/tests/test_devices_cortexa_data_analysis.py b/tests/test_devices_cortexa_data_analysis.py
index b491e52..4d98212 100644
--- a/tests/test_devices_cortexa_data_analysis.py
+++ b/tests/test_devices_cortexa_data_analysis.py
@@ -11,10 +11,18 @@ from mlia.devices.cortexa.data_analysis import CortexADataAnalyzer
from mlia.devices.cortexa.data_analysis import ModelIsCortexACompatible
from mlia.devices.cortexa.data_analysis import ModelIsNotCortexACompatible
from mlia.devices.cortexa.data_analysis import ModelIsNotTFLiteCompatible
+from mlia.devices.cortexa.operator_compatibility import ARMNN_TFLITE_DELEGATE
from mlia.devices.cortexa.operators import CortexACompatibilityInfo
+from mlia.devices.cortexa.operators import Operator
from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
from mlia.nn.tensorflow.tflite_compat import TFLiteConversionError
from mlia.nn.tensorflow.tflite_compat import TFLiteConversionErrorCode
+from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION
+
+BACKEND_INFO = (
+ f"{ARMNN_TFLITE_DELEGATE['metadata']['backend']} "
+ f"{ARMNN_TFLITE_DELEGATE['metadata']['version']}"
+)
@pytest.mark.parametrize(
@@ -22,11 +30,78 @@ from mlia.nn.tensorflow.tflite_compat import TFLiteConversionErrorCode
[
[
CortexACompatibilityInfo(True, []),
- [ModelIsCortexACompatible()],
+ [ModelIsCortexACompatible(BACKEND_INFO)],
+ ],
+ [
+ CortexACompatibilityInfo(
+ True,
+ [
+ Operator(
+ "CONV_2D",
+ "somewhere",
+ support_type=Operator.SupportType.COMPATIBLE,
+ activation_func=TFL_ACTIVATION_FUNCTION.NONE,
+ ),
+ Operator(
+ "CUSTOM",
+ "somewhere else",
+ support_type=Operator.SupportType.COMPATIBLE,
+ activation_func=TFL_ACTIVATION_FUNCTION.SIGN_BIT,
+ custom_name="MaxPool3D",
+ ),
+ ],
+ ),
+ [ModelIsCortexACompatible(BACKEND_INFO)],
],
[
- CortexACompatibilityInfo(False, []),
- [ModelIsNotCortexACompatible()],
+ # pylint: disable=line-too-long
+ CortexACompatibilityInfo(
+ False,
+ [
+ Operator(
+ "UNSUPPORTED_OP",
+ "somewhere",
+ support_type=Operator.SupportType.OP_NOT_SUPPORTED,
+ activation_func=TFL_ACTIVATION_FUNCTION.NONE,
+ ),
+ Operator(
+ "CUSTOM",
+ "somewhere",
+ support_type=Operator.SupportType.OP_NOT_SUPPORTED,
+ activation_func=TFL_ACTIVATION_FUNCTION.NONE,
+ custom_name="UNSUPPORTED_OP",
+ ),
+ Operator(
+ "CONV_2D",
+ "somewhere else",
+ support_type=Operator.SupportType.ACTIVATION_NOT_SUPPORTED,
+ activation_func=TFL_ACTIVATION_FUNCTION.SIGN_BIT,
+ ),
+ ],
+ ),
+ [
+ ModelIsNotCortexACompatible(
+ BACKEND_INFO,
+ {
+ "UNSUPPORTED_OP",
+ "CUSTOM - 'UNSUPPORTED_OP'",
+ },
+ {
+ "CONV_2D": ModelIsNotCortexACompatible.ActivationFunctionSupport(
+ used_unsupported={TFL_ACTIVATION_FUNCTION.SIGN_BIT.name},
+ supported={
+ "RELU",
+ "RELU6",
+ "RELU_N1_TO_1",
+ "SIGMOID",
+ "TANH",
+ "NONE",
+ },
+ )
+ },
+ )
+ ],
+ # pylint: enable=line-too-long
],
[
TFLiteCompatibilityInfo(compatible=True),
diff --git a/tests/test_devices_cortexa_data_collection.py b/tests/test_devices_cortexa_data_collection.py
index 7ea3e52..6d3b2ac 100644
--- a/tests/test_devices_cortexa_data_collection.py
+++ b/tests/test_devices_cortexa_data_collection.py
@@ -11,18 +11,42 @@ from mlia.devices.cortexa.data_collection import CortexAOperatorCompatibility
from mlia.devices.cortexa.operators import CortexACompatibilityInfo
-def test_cortex_a_data_collection(
- monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path, tmpdir: str
+def check_cortex_a_data_collection(
+ monkeypatch: pytest.MonkeyPatch, model: Path, tmpdir: str
) -> None:
"""Test Cortex-A data collection."""
+ assert CortexAOperatorCompatibility.name()
+
monkeypatch.setattr(
"mlia.devices.cortexa.data_collection.get_cortex_a_compatibility_info",
MagicMock(return_value=CortexACompatibilityInfo(True, [])),
)
+
context = ExecutionContext(working_dir=tmpdir)
- collector = CortexAOperatorCompatibility(test_tflite_model)
+ collector = CortexAOperatorCompatibility(model)
collector.set_context(context)
data_item = collector.collect_data()
assert isinstance(data_item, CortexACompatibilityInfo)
+
+
+def test_cortex_a_data_collection_tflite(
+ monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path, tmpdir: str
+) -> None:
+ """Test Cortex-A data collection with a TensorFlow Lite model."""
+ check_cortex_a_data_collection(monkeypatch, test_tflite_model, tmpdir)
+
+
+def test_cortex_a_data_collection_keras(
+ monkeypatch: pytest.MonkeyPatch, test_keras_model: Path, tmpdir: str
+) -> None:
+ """Test Cortex-A data collection with a Keras model."""
+ check_cortex_a_data_collection(monkeypatch, test_keras_model, tmpdir)
+
+
+def test_cortex_a_data_collection_tf(
+ monkeypatch: pytest.MonkeyPatch, test_tf_model: Path, tmpdir: str
+) -> None:
+ """Test Cortex-A data collection with a SavedModel."""
+ check_cortex_a_data_collection(monkeypatch, test_tf_model, tmpdir)
diff --git a/tests/test_devices_cortexa_operators.py b/tests/test_devices_cortexa_operators.py
new file mode 100644
index 0000000..23c4b0a
--- /dev/null
+++ b/tests/test_devices_cortexa_operators.py
@@ -0,0 +1,73 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for Cortex-A operator compatibility."""
+from pathlib import Path
+
+import pytest
+import tensorflow as tf
+
+from mlia.devices.cortexa import operator_compatibility as op_compat
+from mlia.devices.cortexa.operators import CortexACompatibilityInfo
+from mlia.devices.cortexa.operators import get_cortex_a_compatibility_info
+from mlia.devices.cortexa.operators import Operator
+from mlia.nn.tensorflow.tflite_graph import TFL_OP
+from mlia.nn.tensorflow.utils import convert_to_tflite
+
+
+def test_op_compat_data() -> None:
+ """Make sure all data contains the necessary items."""
+ builtin_tfl_ops = {op.name for op in TFL_OP}
+ for data in [op_compat.ARMNN_TFLITE_DELEGATE]:
+ assert "metadata" in data
+ assert "backend" in data["metadata"]
+ assert "version" in data["metadata"]
+ assert "builtin_ops" in data
+ for comp in data["builtin_ops"]:
+ assert comp in builtin_tfl_ops
+ assert "custom_ops" in data
+
+
+def check_get_cortex_a_compatibility_info(
+ model_path: Path,
+ expected_success: bool,
+) -> None:
+ """Check the function 'get_cortex_a_compatibility_info'."""
+ compat_info = get_cortex_a_compatibility_info(model_path)
+ assert isinstance(compat_info, CortexACompatibilityInfo)
+ assert expected_success == compat_info.cortex_a_compatible
+ assert compat_info.operators
+ for oper in compat_info.operators:
+ assert oper.name
+ assert oper.location
+ assert oper.support_type in Operator.SupportType
+
+
+def test_get_cortex_a_compatibility_info_compatible(
+ test_tflite_model: Path,
+) -> None:
+ """Test a fully compatible TensorFlow Lite model."""
+ check_get_cortex_a_compatibility_info(test_tflite_model, expected_success=True)
+
+
+def test_get_cortex_a_compatibility_info_not_compatible(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Construct and test a NOT fully compatible TensorFlow Lite model."""
+ keras_model = tf.keras.Sequential(
+ [
+ tf.keras.Input(shape=(28, 28, 1), batch_size=1, name="input"),
+ tf.keras.layers.Conv2D(
+ filters=12, kernel_size=(3, 3), activation="softmax", name="conv1"
+ ),
+ tf.keras.layers.LeakyReLU(),
+ ]
+ )
+ keras_model.compile(optimizer="sgd", loss="mean_squared_error")
+ tflite_model = convert_to_tflite(keras_model, quantized=False)
+
+ monkeypatch.setattr(
+ "mlia.nn.tensorflow.tflite_graph.load_tflite", lambda _p: tflite_model
+ )
+ check_get_cortex_a_compatibility_info(
+ Path("NOT_USED_BECAUSE_OF_MOCKING"), expected_success=False
+ )
diff --git a/tests/test_devices_cortexa_reporters.py b/tests/test_devices_cortexa_reporters.py
new file mode 100644
index 0000000..4177b55
--- /dev/null
+++ b/tests/test_devices_cortexa_reporters.py
@@ -0,0 +1,52 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for Cortex-A reporters."""
+from typing import Any
+
+import pytest
+
+from mlia.core.advice_generation import Advice
+from mlia.core.reporting import Report
+from mlia.devices.cortexa.config import CortexAConfiguration
+from mlia.devices.cortexa.operators import Operator
+from mlia.devices.cortexa.reporters import cortex_a_formatters
+from mlia.devices.cortexa.reporters import report_device
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION
+
+
+def test_report_device() -> None:
+ """Test function report_device()."""
+ report = report_device(CortexAConfiguration("cortex-a"))
+ assert report.to_plain_text()
+
+
+@pytest.mark.parametrize(
+ "data",
+ (
+ [Advice(["Sample", "Advice"])],
+ TFLiteCompatibilityInfo(compatible=True),
+ [
+ Operator(
+ name="Test",
+ location="loc",
+ support_type=Operator.SupportType.OP_NOT_SUPPORTED,
+ activation_func=TFL_ACTIVATION_FUNCTION.NONE,
+ )
+ ],
+ ),
+)
+def test_cortex_a_formatters(data: Any) -> None:
+ """Test function cortex_a_formatters() with valid input."""
+ formatter = cortex_a_formatters(data)
+ report = formatter(data)
+ assert isinstance(report, Report)
+
+
+def test_cortex_a_formatters_invalid_data() -> None:
+ """Test cortex_a_formatters() with invalid input."""
+ with pytest.raises(
+ Exception,
+ match=r"^Unable to find appropriate formatter for .*",
+ ):
+ cortex_a_formatters(12)
diff --git a/tests/test_nn_tensorflow_tflite_graph.py b/tests/test_nn_tensorflow_tflite_graph.py
new file mode 100644
index 0000000..cd1fad6
--- /dev/null
+++ b/tests/test_nn_tensorflow_tflite_graph.py
@@ -0,0 +1,81 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the tflite_graph module."""
+import json
+from pathlib import Path
+
+from mlia.nn.tensorflow.tflite_graph import Op
+from mlia.nn.tensorflow.tflite_graph import parse_subgraphs
+from mlia.nn.tensorflow.tflite_graph import TensorInfo
+from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION
+from mlia.nn.tensorflow.tflite_graph import TFL_OP
+from mlia.nn.tensorflow.tflite_graph import TFL_TYPE
+
+
+def test_tensor_info() -> None:
+ """Test class 'TensorInfo'."""
+ expected = {
+ "name": "Test",
+ "type": TFL_TYPE.INT8.name,
+ "shape": (1,),
+ "is_variable": False,
+ }
+ info = TensorInfo(**expected)
+ assert vars(info) == expected
+
+ expected = {
+ "name": "Test2",
+ "type": TFL_TYPE.FLOAT32.name,
+ "shape": [2, 3],
+ "is_variable": True,
+ }
+ tensor_dict = {
+ "name": [ord(c) for c in expected["name"]],
+ "type": TFL_TYPE[expected["type"]],
+ "shape": expected["shape"],
+ "is_variable": expected["is_variable"],
+ }
+ info = TensorInfo.from_dict(tensor_dict)
+ assert vars(info) == expected
+
+ json_repr = json.loads(repr(info))
+ assert vars(info) == json_repr
+
+ assert str(info)
+
+
+def test_op() -> None:
+ """Test class 'Op'."""
+ expected = {
+ "type": TFL_OP.CONV_2D.name,
+ "builtin_options": {},
+ "inputs": [],
+ "outputs": [],
+ "custom_type": None,
+ }
+ oper = Op(**expected)
+ assert vars(oper) == expected
+
+ expected["builtin_options"] = {"some_random_option": 3.14}
+ oper = Op(**expected)
+ assert vars(oper) == expected
+
+ activation_func = TFL_ACTIVATION_FUNCTION.RELU
+ expected["builtin_options"] = {"fused_activation_function": activation_func.value}
+ oper = Op(**expected)
+ assert oper.builtin_options
+ assert oper.builtin_options["fused_activation_function"] == activation_func.name
+
+ assert str(oper)
+ assert repr(oper)
+
+
+def test_parse_subgraphs(test_tflite_model: Path) -> None:
+ """Test function 'parse_subgraphs'."""
+ model = parse_subgraphs(test_tflite_model)
+ assert len(model) == 1
+ assert len(model[0]) == 5
+ for oper in model[0]:
+ assert TFL_OP[oper.type] in TFL_OP
+ assert len(oper.inputs) > 0
+ assert len(oper.outputs) > 0