diff options
Diffstat (limited to 'src/mlia/target/cortex_a/operators.py')
-rw-r--r-- | src/mlia/target/cortex_a/operators.py | 148 |
1 files changed, 148 insertions, 0 deletions
diff --git a/src/mlia/target/cortex_a/operators.py b/src/mlia/target/cortex_a/operators.py new file mode 100644 index 0000000..91f1886 --- /dev/null +++ b/src/mlia/target/cortex_a/operators.py @@ -0,0 +1,148 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Cortex-A tools module.""" +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Any +from typing import ClassVar + +from mlia.nn.tensorflow.tflite_graph import Op +from mlia.nn.tensorflow.tflite_graph import parse_subgraphs +from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION +from mlia.target.cortex_a.operator_compatibility import ( + ARMNN_TFLITE_DELEGATE as TFLITE_DELEGATE_COMPAT, +) + + +@dataclass +class Operator: + """Cortex-A compatibility information of the operator.""" + + BUILTIN_COMPATIBILITY = TFLITE_DELEGATE_COMPAT["builtin_ops"] + CUSTOM_COMPATIBILITY = TFLITE_DELEGATE_COMPAT["custom_ops"] + + class SupportType(Enum): + """Type of operator support.""" + + COMPATIBLE = "Compatible" + OP_NOT_SUPPORTED = "Operator not supported" + ACTIVATION_NOT_SUPPORTED = "Activation not supported" + + name: str + location: str + support_type: SupportType + activation_func: TFL_ACTIVATION_FUNCTION + custom_name: str | None = None + + @property + def is_cortex_a_compatible(self) -> bool: + """Check if this operator is compatible.""" + return self.support_type == Operator.SupportType.COMPATIBLE + + @property + def full_name(self) -> str: + """Returun the full name including the custom name if applicable.""" + return self.name + (f" - '{self.custom_name}'" if self.custom_name else "") + + @property + def is_custom(self) -> bool: + """Check if this is a custom operator.""" + return bool(self.custom_name) + + @property + def compatibility_data(self) -> dict[str, dict[str, Any]]: + """Get the compatibility data (builtin or custom ops).""" + return ( + Operator.CUSTOM_COMPATIBILITY + if self.is_custom + else Operator.BUILTIN_COMPATIBILITY + ) + + @property + def supported_activation_functions(self) -> list[str]: + """Return a list of fused activation functions supported by this op.""" + op_name = self.custom_name if self.custom_name else self.name + return self.compatibility_data[op_name].get("supported_fused_activation", []) + + @classmethod + def from_tflite_op(cls, tfl_op: Op, location: str) -> Operator: + """Create a new instance from TensorFlow Lite operator and location.""" + support_type = cls._get_support_type(tfl_op) + activation_func = ( + tfl_op.builtin_options["fused_activation_function"] + if ( + tfl_op.builtin_options + and "fused_activation_function" in tfl_op.builtin_options + ) + else TFL_ACTIVATION_FUNCTION.NONE + ) + return Operator( + tfl_op.type, + location, + support_type, + activation_func=activation_func, + custom_name=(tfl_op.custom_type if tfl_op.is_custom else None), + ) + + @staticmethod + def _get_support_type(tfl_op: Op) -> Operator.SupportType: + """Get the support type from the TensorFlow Lite operator.""" + compat_data = ( + Operator.CUSTOM_COMPATIBILITY + if tfl_op.is_custom + else Operator.BUILTIN_COMPATIBILITY + ) + op_type = tfl_op.custom_type if tfl_op.is_custom else tfl_op.type + + if op_type not in compat_data: + return Operator.SupportType.OP_NOT_SUPPORTED + + compat_op = compat_data[op_type] + if "supported_fused_activation" in compat_op: + assert tfl_op.builtin_options + assert "fused_activation_function" in tfl_op.builtin_options + if ( + tfl_op.builtin_options["fused_activation_function"] + not in compat_op["supported_fused_activation"] + ): + return Operator.SupportType.ACTIVATION_NOT_SUPPORTED + + return Operator.SupportType.COMPATIBLE + + +@dataclass +class CortexACompatibilityInfo: + """Model's operators.""" + + cortex_a_compatible: bool + operators: list[Operator] + backend_info: ClassVar[str] = ( + f"{TFLITE_DELEGATE_COMPAT['metadata']['backend']} " + f"{TFLITE_DELEGATE_COMPAT['metadata']['version']}" + ) + + +def get_cortex_a_compatibility_info(model_path: Path) -> CortexACompatibilityInfo: + """Return list of model's operators.""" + model = parse_subgraphs(model_path) + + op_list = [ + Operator.from_tflite_op(oper, f"subgraph:{g_idx},oper:{op_idx}") + for g_idx, g in enumerate(model) + for op_idx, oper in enumerate(g) + ] + all_compatible = all(oper.is_cortex_a_compatible for oper in op_list) + compat_info = CortexACompatibilityInfo(all_compatible, op_list) + + return compat_info + + +def report() -> None: + """Generate supported operators report.""" + raise Exception( + "Generating a supported operators report is not " + "currently supported with Cortex-A target profile." + ) |