aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/target/cortex_a/operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/target/cortex_a/operators.py')
-rw-r--r--src/mlia/target/cortex_a/operators.py148
1 files changed, 148 insertions, 0 deletions
diff --git a/src/mlia/target/cortex_a/operators.py b/src/mlia/target/cortex_a/operators.py
new file mode 100644
index 0000000..91f1886
--- /dev/null
+++ b/src/mlia/target/cortex_a/operators.py
@@ -0,0 +1,148 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Cortex-A tools module."""
+from __future__ import annotations
+
+from dataclasses import dataclass
+from enum import Enum
+from pathlib import Path
+from typing import Any
+from typing import ClassVar
+
+from mlia.nn.tensorflow.tflite_graph import Op
+from mlia.nn.tensorflow.tflite_graph import parse_subgraphs
+from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION
+from mlia.target.cortex_a.operator_compatibility import (
+ ARMNN_TFLITE_DELEGATE as TFLITE_DELEGATE_COMPAT,
+)
+
+
+@dataclass
+class Operator:
+ """Cortex-A compatibility information of the operator."""
+
+ BUILTIN_COMPATIBILITY = TFLITE_DELEGATE_COMPAT["builtin_ops"]
+ CUSTOM_COMPATIBILITY = TFLITE_DELEGATE_COMPAT["custom_ops"]
+
+ class SupportType(Enum):
+ """Type of operator support."""
+
+ COMPATIBLE = "Compatible"
+ OP_NOT_SUPPORTED = "Operator not supported"
+ ACTIVATION_NOT_SUPPORTED = "Activation not supported"
+
+ name: str
+ location: str
+ support_type: SupportType
+ activation_func: TFL_ACTIVATION_FUNCTION
+ custom_name: str | None = None
+
+ @property
+ def is_cortex_a_compatible(self) -> bool:
+ """Check if this operator is compatible."""
+ return self.support_type == Operator.SupportType.COMPATIBLE
+
+ @property
+ def full_name(self) -> str:
+ """Returun the full name including the custom name if applicable."""
+ return self.name + (f" - '{self.custom_name}'" if self.custom_name else "")
+
+ @property
+ def is_custom(self) -> bool:
+ """Check if this is a custom operator."""
+ return bool(self.custom_name)
+
+ @property
+ def compatibility_data(self) -> dict[str, dict[str, Any]]:
+ """Get the compatibility data (builtin or custom ops)."""
+ return (
+ Operator.CUSTOM_COMPATIBILITY
+ if self.is_custom
+ else Operator.BUILTIN_COMPATIBILITY
+ )
+
+ @property
+ def supported_activation_functions(self) -> list[str]:
+ """Return a list of fused activation functions supported by this op."""
+ op_name = self.custom_name if self.custom_name else self.name
+ return self.compatibility_data[op_name].get("supported_fused_activation", [])
+
+ @classmethod
+ def from_tflite_op(cls, tfl_op: Op, location: str) -> Operator:
+ """Create a new instance from TensorFlow Lite operator and location."""
+ support_type = cls._get_support_type(tfl_op)
+ activation_func = (
+ tfl_op.builtin_options["fused_activation_function"]
+ if (
+ tfl_op.builtin_options
+ and "fused_activation_function" in tfl_op.builtin_options
+ )
+ else TFL_ACTIVATION_FUNCTION.NONE
+ )
+ return Operator(
+ tfl_op.type,
+ location,
+ support_type,
+ activation_func=activation_func,
+ custom_name=(tfl_op.custom_type if tfl_op.is_custom else None),
+ )
+
+ @staticmethod
+ def _get_support_type(tfl_op: Op) -> Operator.SupportType:
+ """Get the support type from the TensorFlow Lite operator."""
+ compat_data = (
+ Operator.CUSTOM_COMPATIBILITY
+ if tfl_op.is_custom
+ else Operator.BUILTIN_COMPATIBILITY
+ )
+ op_type = tfl_op.custom_type if tfl_op.is_custom else tfl_op.type
+
+ if op_type not in compat_data:
+ return Operator.SupportType.OP_NOT_SUPPORTED
+
+ compat_op = compat_data[op_type]
+ if "supported_fused_activation" in compat_op:
+ assert tfl_op.builtin_options
+ assert "fused_activation_function" in tfl_op.builtin_options
+ if (
+ tfl_op.builtin_options["fused_activation_function"]
+ not in compat_op["supported_fused_activation"]
+ ):
+ return Operator.SupportType.ACTIVATION_NOT_SUPPORTED
+
+ return Operator.SupportType.COMPATIBLE
+
+
+@dataclass
+class CortexACompatibilityInfo:
+ """Model's operators."""
+
+ cortex_a_compatible: bool
+ operators: list[Operator]
+ backend_info: ClassVar[str] = (
+ f"{TFLITE_DELEGATE_COMPAT['metadata']['backend']} "
+ f"{TFLITE_DELEGATE_COMPAT['metadata']['version']}"
+ )
+
+
+def get_cortex_a_compatibility_info(model_path: Path) -> CortexACompatibilityInfo:
+ """Return list of model's operators."""
+ model = parse_subgraphs(model_path)
+
+ op_list = [
+ Operator.from_tflite_op(oper, f"subgraph:{g_idx},oper:{op_idx}")
+ for g_idx, g in enumerate(model)
+ for op_idx, oper in enumerate(g)
+ ]
+ all_compatible = all(oper.is_cortex_a_compatible for oper in op_list)
+ compat_info = CortexACompatibilityInfo(all_compatible, op_list)
+
+ return compat_info
+
+
+def report() -> None:
+ """Generate supported operators report."""
+ raise Exception(
+ "Generating a supported operators report is not "
+ "currently supported with Cortex-A target profile."
+ )