src/mlia/target/cortex_a/operators.py

# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Cortex-A tools module."""
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any
from typing import cast

from mlia.backend.armnn_tflite_delegate.compat import (
    ARMNN_TFLITE_DELEGATE as TFLITE_DELEGATE_COMPAT,
)
from mlia.nn.tensorflow.tflite_graph import Op
from mlia.nn.tensorflow.tflite_graph import parse_subgraphs
from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION
from mlia.target.cortex_a.config import CortexAConfiguration


@dataclass
class Operator:
    """Cortex-A compatibility information of the operator."""

    name: str
    location: str
    activation_func: TFL_ACTIVATION_FUNCTION
    custom_name: str | None = None

    @property
    def full_name(self) -> str:
        """Returun the full name including the custom name if applicable."""
        return self.name + (f" - '{self.custom_name}'" if self.custom_name else "")

    @property
    def is_custom(self) -> bool:
        """Check if this is a custom operator."""
        return bool(self.custom_name)

    @classmethod
    def from_tflite_op(cls, tfl_op: Op, location: str) -> Operator:
        """Create a new instance from TensorFlow Lite operator and location."""
        activation_func = (
            TFL_ACTIVATION_FUNCTION[tfl_op.builtin_options["fused_activation_function"]]
            if (
                tfl_op.builtin_options
                and "fused_activation_function" in tfl_op.builtin_options
            )
            else TFL_ACTIVATION_FUNCTION.NONE
        )
        return cls(
            tfl_op.type,
            location,
            activation_func=activation_func,
            custom_name=(tfl_op.custom_type if tfl_op.is_custom else None),
        )


class CortexACompatibilityInfo:
    """Model's operators."""

    class SupportType(Enum):
        """Type of operator support."""

        COMPATIBLE = "Compatible"
        OP_NOT_SUPPORTED = "Operator not supported"
        ACTIVATION_NOT_SUPPORTED = "Activation not supported"

    def __init__(self, ops: list[Operator], armnn_tfl_delegate_version: str) -> None:
        """Create a new collection of op compatibility information."""
        compat_data = TFLITE_DELEGATE_COMPAT["ops"][armnn_tfl_delegate_version]
        self._builtin_compatibility = compat_data["builtin_ops"]
        self._custom_compatibility = compat_data["custom_ops"]

        self.backend_info = (
            f"{TFLITE_DELEGATE_COMPAT['backend']} {armnn_tfl_delegate_version}"
        )

        self.operators = ops

    @property
    def is_cortex_a_compatible(self) -> bool:
        """Check if all operators are compatible."""
        return all(self.is_op_compatible(oper) for oper in self.operators)

    def is_op_compatible(self, operator: Operator) -> bool:
        """Check if the given operator is compatible."""
        return self.get_support_type(operator) == self.SupportType.COMPATIBLE

    def compatibility_data(self, operator: Operator) -> dict[str, dict[str, Any]]:
        """Get the compatibility data (builtin or custom ops)."""
        return (
            cast(dict, self._custom_compatibility)
            if operator.is_custom
            else cast(dict, self._builtin_compatibility)
        )

    def supported_activation_functions(self, operator: Operator) -> list[str]:
        """Return a list of fused activation functions supported by this op."""
        op_name = operator.custom_name if operator.custom_name else operator.name
        return self.compatibility_data(operator)[op_name].get(
            "supported_fused_activation", []
        )

    def get_support_type(
        self, operator: Operator
    ) -> CortexACompatibilityInfo.SupportType:
        """Get the support type from the TensorFlow Lite operator."""
        compat_data = self.compatibility_data(operator)
        op_name = operator.custom_name if operator.is_custom else operator.name

        if op_name not in compat_data:
            return CortexACompatibilityInfo.SupportType.OP_NOT_SUPPORTED

        compat_op = compat_data[op_name]
        if "supported_fused_activation" in compat_op:
            if (
                operator.activation_func.name
                not in compat_op["supported_fused_activation"]
            ):
                return CortexACompatibilityInfo.SupportType.ACTIVATION_NOT_SUPPORTED

        return CortexACompatibilityInfo.SupportType.COMPATIBLE


def get_cortex_a_compatibility_info(
    model_path: Path, target_config: CortexAConfiguration
) -> CortexACompatibilityInfo:
    """Return list of model's operators."""
    model = parse_subgraphs(model_path)

    op_list = [
        Operator.from_tflite_op(oper, f"subgraph:{g_idx},oper:{op_idx}")
        for g_idx, g in enumerate(model)
        for op_idx, oper in enumerate(g)
    ]
    compat_info = CortexACompatibilityInfo(
        op_list, target_config.armnn_tflite_delegate_version
    )

    return compat_info


def report() -> None:
    """Generate supported operators report."""
    raise NotImplementedError(
        "Generating a supported operators report is not "
        "currently supported with Cortex-A target profile."
    )
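

# Illustrative usage sketch (a minimal example, kept as comments so it does not
# run on import). It exercises only the API defined in this module; the model
# path "model.tflite" and the way the CortexAConfiguration instance is built
# are assumptions made for the sake of the example.
#
#     from pathlib import Path
#
#     from mlia.target.cortex_a.config import CortexAConfiguration
#     from mlia.target.cortex_a.operators import get_cortex_a_compatibility_info
#
#     # Construction details of the target configuration are assumed here.
#     target_config = CortexAConfiguration(...)
#
#     compat_info = get_cortex_a_compatibility_info(
#         Path("model.tflite"), target_config
#     )
#
#     print(compat_info.backend_info)
#     for oper in compat_info.operators:
#         print(
#             oper.location,
#             oper.full_name,
#             compat_info.get_support_type(oper).value,
#         )
#     print("Fully compatible:", compat_info.is_cortex_a_compatible)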