diff options
Diffstat (limited to 'src/mlia/target/cortex_a/operators.py')
-rw-r--r-- | src/mlia/target/cortex_a/operators.py | 135 |
1 files changed, 68 insertions, 67 deletions
diff --git a/src/mlia/target/cortex_a/operators.py b/src/mlia/target/cortex_a/operators.py index ae611e5..cd92f31 100644 --- a/src/mlia/target/cortex_a/operators.py +++ b/src/mlia/target/cortex_a/operators.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Cortex-A tools module.""" from __future__ import annotations @@ -7,7 +7,7 @@ from dataclasses import dataclass from enum import Enum from pathlib import Path from typing import Any -from typing import ClassVar +from typing import cast from mlia.backend.armnn_tflite_delegate.compat import ( ARMNN_TFLITE_DELEGATE as TFLITE_DELEGATE_COMPAT, @@ -15,34 +15,19 @@ from mlia.backend.armnn_tflite_delegate.compat import ( from mlia.nn.tensorflow.tflite_graph import Op from mlia.nn.tensorflow.tflite_graph import parse_subgraphs from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION +from mlia.target.cortex_a.config import CortexAConfiguration @dataclass class Operator: """Cortex-A compatibility information of the operator.""" - BUILTIN_COMPATIBILITY = TFLITE_DELEGATE_COMPAT["builtin_ops"] - CUSTOM_COMPATIBILITY = TFLITE_DELEGATE_COMPAT["custom_ops"] - - class SupportType(Enum): - """Type of operator support.""" - - COMPATIBLE = "Compatible" - OP_NOT_SUPPORTED = "Operator not supported" - ACTIVATION_NOT_SUPPORTED = "Activation not supported" - name: str location: str - support_type: SupportType activation_func: TFL_ACTIVATION_FUNCTION custom_name: str | None = None @property - def is_cortex_a_compatible(self) -> bool: - """Check if this operator is compatible.""" - return self.support_type == Operator.SupportType.COMPATIBLE - - @property def full_name(self) -> str: """Returun the full name including the custom name if applicable.""" return self.name + (f" - '{self.custom_name}'" if self.custom_name else "") @@ -52,27 +37,11 @@ class Operator: """Check if this is a custom operator.""" return bool(self.custom_name) - @property - def compatibility_data(self) -> dict[str, dict[str, Any]]: - """Get the compatibility data (builtin or custom ops).""" - return ( - Operator.CUSTOM_COMPATIBILITY - if self.is_custom - else Operator.BUILTIN_COMPATIBILITY - ) - - @property - def supported_activation_functions(self) -> list[str]: - """Return a list of fused activation functions supported by this op.""" - op_name = self.custom_name if self.custom_name else self.name - return self.compatibility_data[op_name].get("supported_fused_activation", []) - @classmethod def from_tflite_op(cls, tfl_op: Op, location: str) -> Operator: """Create a new instance from TensorFlow Lite operator and location.""" - support_type = cls._get_support_type(tfl_op) activation_func = ( - tfl_op.builtin_options["fused_activation_function"] + TFL_ACTIVATION_FUNCTION[tfl_op.builtin_options["fused_activation_function"]] if ( tfl_op.builtin_options and "fused_activation_function" in tfl_op.builtin_options @@ -82,50 +51,81 @@ class Operator: return Operator( tfl_op.type, location, - support_type, activation_func=activation_func, custom_name=(tfl_op.custom_type if tfl_op.is_custom else None), ) - @staticmethod - def _get_support_type(tfl_op: Op) -> Operator.SupportType: - """Get the support type from the TensorFlow Lite operator.""" - compat_data = ( - Operator.CUSTOM_COMPATIBILITY - if tfl_op.is_custom - else Operator.BUILTIN_COMPATIBILITY + +class CortexACompatibilityInfo: + """Model's operators.""" + + class SupportType(Enum): + """Type of operator support.""" + + COMPATIBLE = "Compatible" + OP_NOT_SUPPORTED = "Operator not supported" + ACTIVATION_NOT_SUPPORTED = "Activation not supported" + + def __init__(self, ops: list[Operator], armnn_tfl_delegate_version: str) -> None: + """Create a new collection of op compatibility information.""" + compat_data = TFLITE_DELEGATE_COMPAT["ops"][armnn_tfl_delegate_version] + self._builtin_compatibility = compat_data["builtin_ops"] + self._custom_compatibility = compat_data["custom_ops"] + + self.backend_info = ( + f"{TFLITE_DELEGATE_COMPAT['backend']} {armnn_tfl_delegate_version}" ) - op_type = tfl_op.custom_type if tfl_op.is_custom else tfl_op.type - if op_type not in compat_data: - return Operator.SupportType.OP_NOT_SUPPORTED + self.operators = ops + + @property + def is_cortex_a_compatible(self) -> bool: + """Check if all operators are compatible.""" + return all(self.is_op_compatible(oper) for oper in self.operators) + + def is_op_compatible(self, operator: Operator) -> bool: + """Check if the given operator is compatible.""" + return self.get_support_type(operator) == self.SupportType.COMPATIBLE - compat_op = compat_data[op_type] + def compatibility_data(self, operator: Operator) -> dict[str, dict[str, Any]]: + """Get the compatibility data (builtin or custom ops).""" + return ( + cast(dict, self._custom_compatibility) + if operator.is_custom + else cast(dict, self._builtin_compatibility) + ) + + def supported_activation_functions(self, operator: Operator) -> list[str]: + """Return a list of fused activation functions supported by this op.""" + op_name = operator.custom_name if operator.custom_name else operator.name + return self.compatibility_data(operator)[op_name].get( + "supported_fused_activation", [] + ) + + def get_support_type( + self, operator: Operator + ) -> CortexACompatibilityInfo.SupportType: + """Get the support type from the TensorFlow Lite operator.""" + compat_data = self.compatibility_data(operator) + op_name = operator.custom_name if operator.is_custom else operator.name + + if op_name not in compat_data: + return CortexACompatibilityInfo.SupportType.OP_NOT_SUPPORTED + + compat_op = compat_data[op_name] if "supported_fused_activation" in compat_op: - assert tfl_op.builtin_options - assert "fused_activation_function" in tfl_op.builtin_options if ( - tfl_op.builtin_options["fused_activation_function"] + operator.activation_func.name not in compat_op["supported_fused_activation"] ): - return Operator.SupportType.ACTIVATION_NOT_SUPPORTED + return CortexACompatibilityInfo.SupportType.ACTIVATION_NOT_SUPPORTED - return Operator.SupportType.COMPATIBLE + return CortexACompatibilityInfo.SupportType.COMPATIBLE -@dataclass -class CortexACompatibilityInfo: - """Model's operators.""" - - cortex_a_compatible: bool - operators: list[Operator] - backend_info: ClassVar[str] = ( - f"{TFLITE_DELEGATE_COMPAT['metadata']['backend']} " - f"{TFLITE_DELEGATE_COMPAT['metadata']['version']}" - ) - - -def get_cortex_a_compatibility_info(model_path: Path) -> CortexACompatibilityInfo: +def get_cortex_a_compatibility_info( + model_path: Path, target_config: CortexAConfiguration +) -> CortexACompatibilityInfo: """Return list of model's operators.""" model = parse_subgraphs(model_path) @@ -134,8 +134,9 @@ def get_cortex_a_compatibility_info(model_path: Path) -> CortexACompatibilityInf for g_idx, g in enumerate(model) for op_idx, oper in enumerate(g) ] - all_compatible = all(oper.is_cortex_a_compatible for oper in op_list) - compat_info = CortexACompatibilityInfo(all_compatible, op_list) + compat_info = CortexACompatibilityInfo( + op_list, target_config.armnn_tflite_delegate_version + ) return compat_info |