aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/target/cortex_a/operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/target/cortex_a/operators.py')
-rw-r--r--src/mlia/target/cortex_a/operators.py135
1 files changed, 68 insertions, 67 deletions
diff --git a/src/mlia/target/cortex_a/operators.py b/src/mlia/target/cortex_a/operators.py
index ae611e5..cd92f31 100644
--- a/src/mlia/target/cortex_a/operators.py
+++ b/src/mlia/target/cortex_a/operators.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Cortex-A tools module."""
from __future__ import annotations
@@ -7,7 +7,7 @@ from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any
-from typing import ClassVar
+from typing import cast
from mlia.backend.armnn_tflite_delegate.compat import (
ARMNN_TFLITE_DELEGATE as TFLITE_DELEGATE_COMPAT,
@@ -15,34 +15,19 @@ from mlia.backend.armnn_tflite_delegate.compat import (
from mlia.nn.tensorflow.tflite_graph import Op
from mlia.nn.tensorflow.tflite_graph import parse_subgraphs
from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION
+from mlia.target.cortex_a.config import CortexAConfiguration
@dataclass
class Operator:
"""Cortex-A compatibility information of the operator."""
- BUILTIN_COMPATIBILITY = TFLITE_DELEGATE_COMPAT["builtin_ops"]
- CUSTOM_COMPATIBILITY = TFLITE_DELEGATE_COMPAT["custom_ops"]
-
- class SupportType(Enum):
- """Type of operator support."""
-
- COMPATIBLE = "Compatible"
- OP_NOT_SUPPORTED = "Operator not supported"
- ACTIVATION_NOT_SUPPORTED = "Activation not supported"
-
name: str
location: str
- support_type: SupportType
activation_func: TFL_ACTIVATION_FUNCTION
custom_name: str | None = None
@property
- def is_cortex_a_compatible(self) -> bool:
- """Check if this operator is compatible."""
- return self.support_type == Operator.SupportType.COMPATIBLE
-
- @property
def full_name(self) -> str:
"""Returun the full name including the custom name if applicable."""
return self.name + (f" - '{self.custom_name}'" if self.custom_name else "")
@@ -52,27 +37,11 @@ class Operator:
"""Check if this is a custom operator."""
return bool(self.custom_name)
- @property
- def compatibility_data(self) -> dict[str, dict[str, Any]]:
- """Get the compatibility data (builtin or custom ops)."""
- return (
- Operator.CUSTOM_COMPATIBILITY
- if self.is_custom
- else Operator.BUILTIN_COMPATIBILITY
- )
-
- @property
- def supported_activation_functions(self) -> list[str]:
- """Return a list of fused activation functions supported by this op."""
- op_name = self.custom_name if self.custom_name else self.name
- return self.compatibility_data[op_name].get("supported_fused_activation", [])
-
@classmethod
def from_tflite_op(cls, tfl_op: Op, location: str) -> Operator:
"""Create a new instance from TensorFlow Lite operator and location."""
- support_type = cls._get_support_type(tfl_op)
activation_func = (
- tfl_op.builtin_options["fused_activation_function"]
+ TFL_ACTIVATION_FUNCTION[tfl_op.builtin_options["fused_activation_function"]]
if (
tfl_op.builtin_options
and "fused_activation_function" in tfl_op.builtin_options
@@ -82,50 +51,81 @@ class Operator:
return Operator(
tfl_op.type,
location,
- support_type,
activation_func=activation_func,
custom_name=(tfl_op.custom_type if tfl_op.is_custom else None),
)
- @staticmethod
- def _get_support_type(tfl_op: Op) -> Operator.SupportType:
- """Get the support type from the TensorFlow Lite operator."""
- compat_data = (
- Operator.CUSTOM_COMPATIBILITY
- if tfl_op.is_custom
- else Operator.BUILTIN_COMPATIBILITY
+
+class CortexACompatibilityInfo:
+ """Model's operators."""
+
+ class SupportType(Enum):
+ """Type of operator support."""
+
+ COMPATIBLE = "Compatible"
+ OP_NOT_SUPPORTED = "Operator not supported"
+ ACTIVATION_NOT_SUPPORTED = "Activation not supported"
+
+ def __init__(self, ops: list[Operator], armnn_tfl_delegate_version: str) -> None:
+ """Create a new collection of op compatibility information."""
+ compat_data = TFLITE_DELEGATE_COMPAT["ops"][armnn_tfl_delegate_version]
+ self._builtin_compatibility = compat_data["builtin_ops"]
+ self._custom_compatibility = compat_data["custom_ops"]
+
+ self.backend_info = (
+ f"{TFLITE_DELEGATE_COMPAT['backend']} {armnn_tfl_delegate_version}"
)
- op_type = tfl_op.custom_type if tfl_op.is_custom else tfl_op.type
- if op_type not in compat_data:
- return Operator.SupportType.OP_NOT_SUPPORTED
+ self.operators = ops
+
+ @property
+ def is_cortex_a_compatible(self) -> bool:
+ """Check if all operators are compatible."""
+ return all(self.is_op_compatible(oper) for oper in self.operators)
+
+ def is_op_compatible(self, operator: Operator) -> bool:
+ """Check if the given operator is compatible."""
+ return self.get_support_type(operator) == self.SupportType.COMPATIBLE
- compat_op = compat_data[op_type]
+ def compatibility_data(self, operator: Operator) -> dict[str, dict[str, Any]]:
+ """Get the compatibility data (builtin or custom ops)."""
+ return (
+ cast(dict, self._custom_compatibility)
+ if operator.is_custom
+ else cast(dict, self._builtin_compatibility)
+ )
+
+ def supported_activation_functions(self, operator: Operator) -> list[str]:
+ """Return a list of fused activation functions supported by this op."""
+ op_name = operator.custom_name if operator.custom_name else operator.name
+ return self.compatibility_data(operator)[op_name].get(
+ "supported_fused_activation", []
+ )
+
+ def get_support_type(
+ self, operator: Operator
+ ) -> CortexACompatibilityInfo.SupportType:
+ """Get the support type from the TensorFlow Lite operator."""
+ compat_data = self.compatibility_data(operator)
+ op_name = operator.custom_name if operator.is_custom else operator.name
+
+ if op_name not in compat_data:
+ return CortexACompatibilityInfo.SupportType.OP_NOT_SUPPORTED
+
+ compat_op = compat_data[op_name]
if "supported_fused_activation" in compat_op:
- assert tfl_op.builtin_options
- assert "fused_activation_function" in tfl_op.builtin_options
if (
- tfl_op.builtin_options["fused_activation_function"]
+ operator.activation_func.name
not in compat_op["supported_fused_activation"]
):
- return Operator.SupportType.ACTIVATION_NOT_SUPPORTED
+ return CortexACompatibilityInfo.SupportType.ACTIVATION_NOT_SUPPORTED
- return Operator.SupportType.COMPATIBLE
+ return CortexACompatibilityInfo.SupportType.COMPATIBLE
-@dataclass
-class CortexACompatibilityInfo:
- """Model's operators."""
-
- cortex_a_compatible: bool
- operators: list[Operator]
- backend_info: ClassVar[str] = (
- f"{TFLITE_DELEGATE_COMPAT['metadata']['backend']} "
- f"{TFLITE_DELEGATE_COMPAT['metadata']['version']}"
- )
-
-
-def get_cortex_a_compatibility_info(model_path: Path) -> CortexACompatibilityInfo:
+def get_cortex_a_compatibility_info(
+ model_path: Path, target_config: CortexAConfiguration
+) -> CortexACompatibilityInfo:
"""Return list of model's operators."""
model = parse_subgraphs(model_path)
@@ -134,8 +134,9 @@ def get_cortex_a_compatibility_info(model_path: Path) -> CortexACompatibilityInf
for g_idx, g in enumerate(model)
for op_idx, oper in enumerate(g)
]
- all_compatible = all(oper.is_cortex_a_compatible for oper in op_list)
- compat_info = CortexACompatibilityInfo(all_compatible, op_list)
+ compat_info = CortexACompatibilityInfo(
+ op_list, target_config.armnn_tflite_delegate_version
+ )
return compat_info