From e40a7adadd254e29d71af38f69a0a20ff4871eef Mon Sep 17 00:00:00 2001 From: Benjamin Klimczak Date: Tue, 25 Oct 2022 18:12:34 +0100 Subject: MLIA-411 Report Cortex-A operator compatibility Check input model for Arm NN TensorFlow Lite Delegate 22.08 support. Change-Id: I1253c4c0b294c5283e08f0a39561b922ef0f62e6 --- src/mlia/devices/cortexa/advice_generation.py | 50 +++++- src/mlia/devices/cortexa/data_analysis.py | 57 ++++++- src/mlia/devices/cortexa/operator_compatibility.py | 184 +++++++++++++++++++++ src/mlia/devices/cortexa/operators.py | 121 +++++++++++++- src/mlia/devices/cortexa/reporters.py | 12 +- 5 files changed, 404 insertions(+), 20 deletions(-) create mode 100644 src/mlia/devices/cortexa/operator_compatibility.py (limited to 'src/mlia/devices/cortexa') diff --git a/src/mlia/devices/cortexa/advice_generation.py b/src/mlia/devices/cortexa/advice_generation.py index 0f3553f..34c51f8 100644 --- a/src/mlia/devices/cortexa/advice_generation.py +++ b/src/mlia/devices/cortexa/advice_generation.py @@ -15,6 +15,13 @@ from mlia.devices.cortexa.data_analysis import ModelIsNotTFLiteCompatible class CortexAAdviceProducer(FactBasedAdviceProducer): """Cortex-A advice producer.""" + cortex_a_disclaimer = ( + "Note that the provided compatibility information is general. " + "At runtime individual operators in the given model might fall back to " + "the TensorFlow Lite reference or might produce errors based on the " + "specific parameters." + ) + @singledispatchmethod def produce_advice(self, _data_item: DataItem) -> None: """Produce advice.""" @@ -22,21 +29,54 @@ class CortexAAdviceProducer(FactBasedAdviceProducer): @produce_advice.register @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS) def handle_model_is_cortex_a_compatible( - self, _data_item: ModelIsCortexACompatible + self, data_item: ModelIsCortexACompatible ) -> None: """Advice for Cortex-A compatibility.""" - self.add_advice(["Model is fully compatible with Cortex-A."]) + self.add_advice( + [ + f"Model is fully compatible with {data_item.backend_info} for " + "Cortex-A.", + self.cortex_a_disclaimer, + ] + ) @produce_advice.register @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS) def handle_model_is_not_cortex_a_compatible( - self, _data_item: ModelIsNotCortexACompatible + self, data_item: ModelIsNotCortexACompatible ) -> None: """Advice for Cortex-A compatibility.""" + if data_item.unsupported_ops: + self.add_advice( + [ + "The following operators are not supported by " + f"{data_item.backend_info} and will fall back to the " + "TensorFlow Lite runtime:", + "\n".join(f" - {op}" for op in data_item.unsupported_ops), + ] + ) + + if data_item.activation_func_support: + self.add_advice( + [ + "The fused activation functions of the following operators " + f"are not supported by {data_item.backend_info}. Please " + "consider using one of the supported activation functions " + "instead:", + "\n".join( + f" - {op}\n" + f" - Used unsupported: {act.used_unsupported}\n" + f" - Supported: {act.supported}" + for op, act in data_item.activation_func_support.items() + ), + ] + ) + self.add_advice( [ - "Some operators in the model are not compatible with Cortex-A. " - "Please, refer to the operators table for more information." + "Please, refer to the full table of operators above for more " + "information.", + self.cortex_a_disclaimer, ] ) diff --git a/src/mlia/devices/cortexa/data_analysis.py b/src/mlia/devices/cortexa/data_analysis.py index d2b6f35..9f6d82b 100644 --- a/src/mlia/devices/cortexa/data_analysis.py +++ b/src/mlia/devices/cortexa/data_analysis.py @@ -3,13 +3,16 @@ """Cortex-A data analysis module.""" from __future__ import annotations +from collections import defaultdict from dataclasses import dataclass +from dataclasses import field from functools import singledispatchmethod from mlia.core.common import DataItem from mlia.core.data_analysis import Fact from mlia.core.data_analysis import FactExtractor from mlia.devices.cortexa.operators import CortexACompatibilityInfo +from mlia.devices.cortexa.operators import Operator from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo from mlia.nn.tensorflow.tflite_compat import TFLiteConversionErrorCode @@ -27,9 +30,38 @@ class CortexADataAnalyzer(FactExtractor): ) -> None: """Analyse operator compatibility information.""" if data_item.cortex_a_compatible: - self.add_fact(ModelIsCortexACompatible()) + self.add_fact(ModelIsCortexACompatible(data_item.backend_info)) else: - self.add_fact(ModelIsNotCortexACompatible()) + unsupported_ops = set() + activation_func_support: defaultdict[ + str, ModelIsNotCortexACompatible.ActivationFunctionSupport + ] = defaultdict(ModelIsNotCortexACompatible.ActivationFunctionSupport) + for oper in data_item.operators: + if oper.support_type == Operator.SupportType.OP_NOT_SUPPORTED: + unsupported_ops.add(oper.full_name) + + if oper.support_type == Operator.SupportType.ACTIVATION_NOT_SUPPORTED: + # Add used but unsupported actication functions + activation_func_support[oper.full_name].used_unsupported.add( + oper.activation_func.name + ) + # Add supported activation functions + activation_func_support[oper.full_name].supported.update( + oper.supported_activation_functions + ) + + assert ( + unsupported_ops or activation_func_support or not data_item.operators + ), ( + "The model is marked as not compatible with Cortex-A but there " + "are no unsupported ops activation functions listed." + ) + + self.add_fact( + ModelIsNotCortexACompatible( + data_item.backend_info, unsupported_ops, activation_func_support + ) + ) @analyze_data.register def analyze_tflite_compatibility(self, data_item: TFLiteCompatibilityInfo) -> None: @@ -52,14 +84,31 @@ class CortexADataAnalyzer(FactExtractor): @dataclass -class ModelIsCortexACompatible(Fact): +class CortexACompatibility(Fact): + """Base class for Cortex-A compatibility providing backend info.""" + + backend_info: str + + +@dataclass +class ModelIsCortexACompatible(CortexACompatibility): """Model is completely compatible with Cortex-A.""" @dataclass -class ModelIsNotCortexACompatible(Fact): +class ModelIsNotCortexACompatible(CortexACompatibility): """Model is not compatible with Cortex-A.""" + @dataclass + class ActivationFunctionSupport: + """Activation function support per operator.""" + + used_unsupported: set[str] = field(default_factory=set) + supported: set[str] = field(default_factory=set) + + unsupported_ops: set[str] + activation_func_support: dict[str, ActivationFunctionSupport] + @dataclass class ModelIsNotTFLiteCompatible(Fact): diff --git a/src/mlia/devices/cortexa/operator_compatibility.py b/src/mlia/devices/cortexa/operator_compatibility.py new file mode 100644 index 0000000..c474e75 --- /dev/null +++ b/src/mlia/devices/cortexa/operator_compatibility.py @@ -0,0 +1,184 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Collection of Cortex-A operator compatibility information.""" +from __future__ import annotations + +from typing import Any + +ARMNN_TFLITE_DELEGATE: dict[str, dict[str, Any]] = { + "metadata": { + "backend": "Arm NN TensorFlow Lite delegate", + "version": "22.08", + }, + # BUILTIN OPERATORS + "builtin_ops": { + "ABS": {}, + "ADD": {}, + "ARG_MAX": {}, + "ARG_MIN": {}, + "AVERAGE_POOL_2D": { + "supported_fused_activation": [ + "RELU", + "RELU6", + "RELU_N1_TO_1", + "SIGMOID", + "TANH", + "NONE", + ] + }, + "BATCH_TO_SPACE_ND": {}, + "CAST": {}, + "CONCATENATION": { + "supported_fused_activation": [ + "RELU", + "RELU6", + "RELU_N1_TO_1", + "SIGMOID", + "TANH", + "NONE", + ] + }, + "CONV_2D": { + "supported_fused_activation": [ + "RELU", + "RELU6", + "RELU_N1_TO_1", + "SIGMOID", + "TANH", + "NONE", + ] + }, + "CONV_3D": { + "supported_fused_activation": [ + "RELU", + "RELU6", + "RELU_N1_TO_1", + "SIGMOID", + "TANH", + "NONE", + ] + }, + "DEPTH_TO_SPACE": {}, + "DEPTHWISE_CONV_2D": { + "supported_fused_activation": [ + "RELU", + "RELU6", + "RELU_N1_TO_1", + "SIGMOID", + "TANH", + "NONE", + ] + }, + "DEQUANTIZE": {}, + "DIV": {}, + "EQUAL": {}, + "ELU": {}, + "EXP": {}, + "EXPAND_DIMS": {}, + "FILL": {}, + "FLOOR": {}, + "FLOOR_DIV": {}, + "FULLY_CONNECTED": { + "supported_fused_activation": [ + "RELU", + "RELU6", + "RELU_N1_TO_1", + "SIGMOID", + "TANH", + "NONE", + ] + }, + "GATHER": {}, + "GATHER_ND": {}, + "GREATER": {}, + "GREATER_EQUAL": {}, + "HARD_SWISH": {}, + "L2_NORMALIZATION": {}, + "L2_POOL_2D": {}, + "LESS": {}, + "LESS_EQUAL": {}, + "LOCAL_RESPONSE_NORMALIZATION": {}, + "LOG": {}, + "LOGICAL_AND": {}, + "LOGICAL_NOT": {}, + "LOGICAL_OR": {}, + "LOGISTIC": {}, + "LOG_SOFTMAX": {}, + "LSTM": {}, + "MAXIMUM": {}, + "MAX_POOL_2D": { + "supported_fused_activation": [ + "RELU", + "RELU6", + "RELU_N1_TO_1", + "SIGMOID", + "TANH", + "NONE", + ] + }, + "MEAN": {}, + "MINIMUM": {}, + "MIRROR_PAD": {}, + "MUL": {}, + "NEG": {}, + "NOT_EQUAL": {}, + "PACK": {}, + "PAD": {}, + "PADV2": {}, + "PRELU": {}, + "QUANTIZE": {}, + "RANK": {}, + "REDUCE_MAX": {}, + "REDUCE_MIN": {}, + "REDUCE_PROD": {}, + "RELU": {}, + "RELU6": {}, + "RELU_N1_TO_1": {}, + "RESHAPE": {}, + "RESIZE_BILINEAR": {}, + "RESIZE_NEAREST_NEIGHBOR": {}, + "RSQRT": {}, + "SHAPE": {}, + "SIN": {}, + "SOFTMAX": {}, + "SPACE_TO_BATCH_ND": {}, + "SPACE_TO_DEPTH": {}, + "SPLIT": {}, + "SPLIT_V": {}, + "SQRT": {}, + "SQUEEZE": {}, + "STRIDED_SLICE": {}, + "SUB": {}, + "SUM": {}, + "TANH": {}, + "TRANSPOSE": {}, + "TRANSPOSE_CONV": {}, + "UNIDIRECTIONAL_SEQUENCE_LSTM": {}, + "UNPACK": {}, + }, + # CUSTOM OPERATORS + "custom_ops": { + "AveragePool3D": { + "supported_fused_activation": [ + "RELU", + "RELU6", + "RELU_N1_TO_1", + "SIGMOID", + "SIGN_BIT", + "TANH", + "NONE", + ] + }, + "MaxPool3D": { + "supported_fused_activation": [ + "RELU", + "RELU6", + "RELU_N1_TO_1", + "SIGMOID", + "SIGN_BIT", + "TANH", + "NONE", + ] + }, + }, +} diff --git a/src/mlia/devices/cortexa/operators.py b/src/mlia/devices/cortexa/operators.py index d46b107..3e84d64 100644 --- a/src/mlia/devices/cortexa/operators.py +++ b/src/mlia/devices/cortexa/operators.py @@ -4,16 +4,113 @@ from __future__ import annotations from dataclasses import dataclass +from enum import Enum from pathlib import Path +from typing import Any +from typing import ClassVar + +from mlia.devices.cortexa.operator_compatibility import ( + ARMNN_TFLITE_DELEGATE as TFLITE_DELEGATE_COMPAT, +) +from mlia.nn.tensorflow.tflite_graph import Op +from mlia.nn.tensorflow.tflite_graph import parse_subgraphs +from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION @dataclass class Operator: """Cortex-A compatibility information of the operator.""" + BUILTIN_COMPATIBILITY = TFLITE_DELEGATE_COMPAT["builtin_ops"] + CUSTOM_COMPATIBILITY = TFLITE_DELEGATE_COMPAT["custom_ops"] + + class SupportType(Enum): + """Type of operator support.""" + + COMPATIBLE = "Compatible" + OP_NOT_SUPPORTED = "Operator not supported" + ACTIVATION_NOT_SUPPORTED = "Activation not supported" + name: str location: str - is_cortex_a_compatible: bool + support_type: SupportType + activation_func: TFL_ACTIVATION_FUNCTION + custom_name: str | None = None + + @property + def is_cortex_a_compatible(self) -> bool: + """Check if this operator is compatible.""" + return self.support_type == Operator.SupportType.COMPATIBLE + + @property + def full_name(self) -> str: + """Returun the full name including the custom name if applicable.""" + return self.name + (f" - '{self.custom_name}'" if self.custom_name else "") + + @property + def is_custom(self) -> bool: + """Check if this is a custom operator.""" + return bool(self.custom_name) + + @property + def compatibility_data(self) -> dict[str, dict[str, Any]]: + """Get the compatibility data (builtin or custom ops).""" + return ( + Operator.CUSTOM_COMPATIBILITY + if self.is_custom + else Operator.BUILTIN_COMPATIBILITY + ) + + @property + def supported_activation_functions(self) -> list[str]: + """Return a list of fused activation functions supported by this op.""" + op_name = self.custom_name if self.custom_name else self.name + return self.compatibility_data[op_name].get("supported_fused_activation", []) + + @classmethod + def from_tflite_op(cls, tfl_op: Op, location: str) -> Operator: + """Create a new instance from TensorFlow Lite operator and location.""" + support_type = cls._get_support_type(tfl_op) + activation_func = ( + tfl_op.builtin_options["fused_activation_function"] + if ( + tfl_op.builtin_options + and "fused_activation_function" in tfl_op.builtin_options + ) + else TFL_ACTIVATION_FUNCTION.NONE + ) + return Operator( + tfl_op.type, + location, + support_type, + activation_func=activation_func, + custom_name=(tfl_op.custom_type if tfl_op.is_custom else None), + ) + + @staticmethod + def _get_support_type(tfl_op: Op) -> Operator.SupportType: + """Get the support type from the TensorFlow Lite operator.""" + compat_data = ( + Operator.CUSTOM_COMPATIBILITY + if tfl_op.is_custom + else Operator.BUILTIN_COMPATIBILITY + ) + op_type = tfl_op.custom_type if tfl_op.is_custom else tfl_op.type + + if op_type not in compat_data: + return Operator.SupportType.OP_NOT_SUPPORTED + + compat_op = compat_data[op_type] + if "supported_fused_activation" in compat_op: + assert tfl_op.builtin_options + assert "fused_activation_function" in tfl_op.builtin_options + if ( + tfl_op.builtin_options["fused_activation_function"] + not in compat_op["supported_fused_activation"] + ): + return Operator.SupportType.ACTIVATION_NOT_SUPPORTED + + return Operator.SupportType.COMPATIBLE @dataclass @@ -21,14 +118,26 @@ class CortexACompatibilityInfo: """Model's operators.""" cortex_a_compatible: bool - operators: list[Operator] | None = None + operators: list[Operator] + backend_info: ClassVar[str] = ( + f"{TFLITE_DELEGATE_COMPAT['metadata']['backend']} " + f"{TFLITE_DELEGATE_COMPAT['metadata']['version']}" + ) -def get_cortex_a_compatibility_info( - _model_path: Path, -) -> CortexACompatibilityInfo | None: +def get_cortex_a_compatibility_info(model_path: Path) -> CortexACompatibilityInfo: """Return list of model's operators.""" - return None + model = parse_subgraphs(model_path) + + op_list = [ + Operator.from_tflite_op(oper, f"subgraph:{g_idx},oper:{op_idx}") + for g_idx, g in enumerate(model) + for op_idx, oper in enumerate(g) + ] + all_compatible = all(oper.is_cortex_a_compatible for oper in op_list) + compat_info = CortexACompatibilityInfo(all_compatible, op_list) + + return compat_info def report() -> None: diff --git a/src/mlia/devices/cortexa/reporters.py b/src/mlia/devices/cortexa/reporters.py index a55caba..84de10b 100644 --- a/src/mlia/devices/cortexa/reporters.py +++ b/src/mlia/devices/cortexa/reporters.py @@ -5,6 +5,7 @@ from __future__ import annotations from typing import Any from typing import Callable +from typing import cast from mlia.core.advice_generation import Advice from mlia.core.reporters import report_advice @@ -96,21 +97,22 @@ def report_cortex_a_operators(ops: list[Operator]) -> Report: ), Column("Operator name", alias="operator_name", fmt=Format(wrap_width=20)), Column( - "Cortex-A compatibility", + "Arm NN TFLite Delegate compatibility", alias="cortex_a_compatible", - fmt=Format(wrap_width=25), + fmt=Format(wrap_width=40), ), ], [ ( index + 1, op.location, - op.name, + op.full_name, Cell( - op.is_cortex_a_compatible, + op.support_type, Format( + wrap_width=30, style=style_improvement(op.is_cortex_a_compatible), - str_fmt=lambda v: "Compatible" if v else "Not compatible", + str_fmt=lambda v: cast(str, v.value), ), ), ) -- cgit v1.2.1