aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2022-10-25 18:12:34 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2022-11-10 16:47:22 +0000
commite40a7adadd254e29d71af38f69a0a20ff4871eef (patch)
tree9a57ddf406846785683673565359d9bd6ba3cf0b /src
parent720839a2dc6d4d75cd7aa77f83fcd49bcf114ba6 (diff)
downloadmlia-e40a7adadd254e29d71af38f69a0a20ff4871eef.tar.gz
MLIA-411 Report Cortex-A operator compatibility
Check input model for Arm NN TensorFlow Lite Delegate 22.08 support. Change-Id: I1253c4c0b294c5283e08f0a39561b922ef0f62e6
Diffstat (limited to 'src')
-rw-r--r--src/mlia/cli/main.py3
-rw-r--r--src/mlia/devices/cortexa/advice_generation.py50
-rw-r--r--src/mlia/devices/cortexa/data_analysis.py57
-rw-r--r--src/mlia/devices/cortexa/operator_compatibility.py184
-rw-r--r--src/mlia/devices/cortexa/operators.py121
-rw-r--r--src/mlia/devices/cortexa/reporters.py12
-rw-r--r--src/mlia/nn/tensorflow/tflite_graph.py139
7 files changed, 545 insertions, 21 deletions
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py
index bafe434..d36d2d9 100644
--- a/src/mlia/cli/main.py
+++ b/src/mlia/cli/main.py
@@ -37,10 +37,11 @@ logger = logging.getLogger(__name__)
INFO_MESSAGE = f"""
ML Inference Advisor {__version__}
-Help the design and optimization of neural network models for efficient inference on a target CPU, GPU and NPU
+Help the design and optimization of neural network models for efficient inference on a target CPU and NPU
Supported targets:
+ - Cortex-A <op compatibility>
- Ethos-U55 <op compatibility, perf estimation, model opt>
- Ethos-U65 <op compatibility, perf estimation, model opt>
- TOSA <op compatibility>
diff --git a/src/mlia/devices/cortexa/advice_generation.py b/src/mlia/devices/cortexa/advice_generation.py
index 0f3553f..34c51f8 100644
--- a/src/mlia/devices/cortexa/advice_generation.py
+++ b/src/mlia/devices/cortexa/advice_generation.py
@@ -15,6 +15,13 @@ from mlia.devices.cortexa.data_analysis import ModelIsNotTFLiteCompatible
class CortexAAdviceProducer(FactBasedAdviceProducer):
"""Cortex-A advice producer."""
+ cortex_a_disclaimer = (
+ "Note that the provided compatibility information is general. "
+ "At runtime individual operators in the given model might fall back to "
+ "the TensorFlow Lite reference or might produce errors based on the "
+ "specific parameters."
+ )
+
@singledispatchmethod
def produce_advice(self, _data_item: DataItem) -> None:
"""Produce advice."""
@@ -22,21 +29,54 @@ class CortexAAdviceProducer(FactBasedAdviceProducer):
@produce_advice.register
@advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
def handle_model_is_cortex_a_compatible(
- self, _data_item: ModelIsCortexACompatible
+ self, data_item: ModelIsCortexACompatible
) -> None:
"""Advice for Cortex-A compatibility."""
- self.add_advice(["Model is fully compatible with Cortex-A."])
+ self.add_advice(
+ [
+ f"Model is fully compatible with {data_item.backend_info} for "
+ "Cortex-A.",
+ self.cortex_a_disclaimer,
+ ]
+ )
@produce_advice.register
@advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
def handle_model_is_not_cortex_a_compatible(
- self, _data_item: ModelIsNotCortexACompatible
+ self, data_item: ModelIsNotCortexACompatible
) -> None:
"""Advice for Cortex-A compatibility."""
+ if data_item.unsupported_ops:
+ self.add_advice(
+ [
+ "The following operators are not supported by "
+ f"{data_item.backend_info} and will fall back to the "
+ "TensorFlow Lite runtime:",
+ "\n".join(f" - {op}" for op in data_item.unsupported_ops),
+ ]
+ )
+
+ if data_item.activation_func_support:
+ self.add_advice(
+ [
+ "The fused activation functions of the following operators "
+ f"are not supported by {data_item.backend_info}. Please "
+ "consider using one of the supported activation functions "
+ "instead:",
+ "\n".join(
+ f" - {op}\n"
+ f" - Used unsupported: {act.used_unsupported}\n"
+ f" - Supported: {act.supported}"
+ for op, act in data_item.activation_func_support.items()
+ ),
+ ]
+ )
+
self.add_advice(
[
- "Some operators in the model are not compatible with Cortex-A. "
- "Please, refer to the operators table for more information."
+ "Please, refer to the full table of operators above for more "
+ "information.",
+ self.cortex_a_disclaimer,
]
)
diff --git a/src/mlia/devices/cortexa/data_analysis.py b/src/mlia/devices/cortexa/data_analysis.py
index d2b6f35..9f6d82b 100644
--- a/src/mlia/devices/cortexa/data_analysis.py
+++ b/src/mlia/devices/cortexa/data_analysis.py
@@ -3,13 +3,16 @@
"""Cortex-A data analysis module."""
from __future__ import annotations
+from collections import defaultdict
from dataclasses import dataclass
+from dataclasses import field
from functools import singledispatchmethod
from mlia.core.common import DataItem
from mlia.core.data_analysis import Fact
from mlia.core.data_analysis import FactExtractor
from mlia.devices.cortexa.operators import CortexACompatibilityInfo
+from mlia.devices.cortexa.operators import Operator
from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
from mlia.nn.tensorflow.tflite_compat import TFLiteConversionErrorCode
@@ -27,9 +30,38 @@ class CortexADataAnalyzer(FactExtractor):
) -> None:
"""Analyse operator compatibility information."""
if data_item.cortex_a_compatible:
- self.add_fact(ModelIsCortexACompatible())
+ self.add_fact(ModelIsCortexACompatible(data_item.backend_info))
else:
- self.add_fact(ModelIsNotCortexACompatible())
+ unsupported_ops = set()
+ activation_func_support: defaultdict[
+ str, ModelIsNotCortexACompatible.ActivationFunctionSupport
+ ] = defaultdict(ModelIsNotCortexACompatible.ActivationFunctionSupport)
+ for oper in data_item.operators:
+ if oper.support_type == Operator.SupportType.OP_NOT_SUPPORTED:
+ unsupported_ops.add(oper.full_name)
+
+ if oper.support_type == Operator.SupportType.ACTIVATION_NOT_SUPPORTED:
+ # Add used but unsupported actication functions
+ activation_func_support[oper.full_name].used_unsupported.add(
+ oper.activation_func.name
+ )
+ # Add supported activation functions
+ activation_func_support[oper.full_name].supported.update(
+ oper.supported_activation_functions
+ )
+
+ assert (
+ unsupported_ops or activation_func_support or not data_item.operators
+ ), (
+ "The model is marked as not compatible with Cortex-A but there "
+ "are no unsupported ops activation functions listed."
+ )
+
+ self.add_fact(
+ ModelIsNotCortexACompatible(
+ data_item.backend_info, unsupported_ops, activation_func_support
+ )
+ )
@analyze_data.register
def analyze_tflite_compatibility(self, data_item: TFLiteCompatibilityInfo) -> None:
@@ -52,14 +84,31 @@ class CortexADataAnalyzer(FactExtractor):
@dataclass
-class ModelIsCortexACompatible(Fact):
+class CortexACompatibility(Fact):
+ """Base class for Cortex-A compatibility providing backend info."""
+
+ backend_info: str
+
+
+@dataclass
+class ModelIsCortexACompatible(CortexACompatibility):
"""Model is completely compatible with Cortex-A."""
@dataclass
-class ModelIsNotCortexACompatible(Fact):
+class ModelIsNotCortexACompatible(CortexACompatibility):
"""Model is not compatible with Cortex-A."""
+ @dataclass
+ class ActivationFunctionSupport:
+ """Activation function support per operator."""
+
+ used_unsupported: set[str] = field(default_factory=set)
+ supported: set[str] = field(default_factory=set)
+
+ unsupported_ops: set[str]
+ activation_func_support: dict[str, ActivationFunctionSupport]
+
@dataclass
class ModelIsNotTFLiteCompatible(Fact):
diff --git a/src/mlia/devices/cortexa/operator_compatibility.py b/src/mlia/devices/cortexa/operator_compatibility.py
new file mode 100644
index 0000000..c474e75
--- /dev/null
+++ b/src/mlia/devices/cortexa/operator_compatibility.py
@@ -0,0 +1,184 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Collection of Cortex-A operator compatibility information."""
+from __future__ import annotations
+
+from typing import Any
+
+ARMNN_TFLITE_DELEGATE: dict[str, dict[str, Any]] = {
+ "metadata": {
+ "backend": "Arm NN TensorFlow Lite delegate",
+ "version": "22.08",
+ },
+ # BUILTIN OPERATORS
+ "builtin_ops": {
+ "ABS": {},
+ "ADD": {},
+ "ARG_MAX": {},
+ "ARG_MIN": {},
+ "AVERAGE_POOL_2D": {
+ "supported_fused_activation": [
+ "RELU",
+ "RELU6",
+ "RELU_N1_TO_1",
+ "SIGMOID",
+ "TANH",
+ "NONE",
+ ]
+ },
+ "BATCH_TO_SPACE_ND": {},
+ "CAST": {},
+ "CONCATENATION": {
+ "supported_fused_activation": [
+ "RELU",
+ "RELU6",
+ "RELU_N1_TO_1",
+ "SIGMOID",
+ "TANH",
+ "NONE",
+ ]
+ },
+ "CONV_2D": {
+ "supported_fused_activation": [
+ "RELU",
+ "RELU6",
+ "RELU_N1_TO_1",
+ "SIGMOID",
+ "TANH",
+ "NONE",
+ ]
+ },
+ "CONV_3D": {
+ "supported_fused_activation": [
+ "RELU",
+ "RELU6",
+ "RELU_N1_TO_1",
+ "SIGMOID",
+ "TANH",
+ "NONE",
+ ]
+ },
+ "DEPTH_TO_SPACE": {},
+ "DEPTHWISE_CONV_2D": {
+ "supported_fused_activation": [
+ "RELU",
+ "RELU6",
+ "RELU_N1_TO_1",
+ "SIGMOID",
+ "TANH",
+ "NONE",
+ ]
+ },
+ "DEQUANTIZE": {},
+ "DIV": {},
+ "EQUAL": {},
+ "ELU": {},
+ "EXP": {},
+ "EXPAND_DIMS": {},
+ "FILL": {},
+ "FLOOR": {},
+ "FLOOR_DIV": {},
+ "FULLY_CONNECTED": {
+ "supported_fused_activation": [
+ "RELU",
+ "RELU6",
+ "RELU_N1_TO_1",
+ "SIGMOID",
+ "TANH",
+ "NONE",
+ ]
+ },
+ "GATHER": {},
+ "GATHER_ND": {},
+ "GREATER": {},
+ "GREATER_EQUAL": {},
+ "HARD_SWISH": {},
+ "L2_NORMALIZATION": {},
+ "L2_POOL_2D": {},
+ "LESS": {},
+ "LESS_EQUAL": {},
+ "LOCAL_RESPONSE_NORMALIZATION": {},
+ "LOG": {},
+ "LOGICAL_AND": {},
+ "LOGICAL_NOT": {},
+ "LOGICAL_OR": {},
+ "LOGISTIC": {},
+ "LOG_SOFTMAX": {},
+ "LSTM": {},
+ "MAXIMUM": {},
+ "MAX_POOL_2D": {
+ "supported_fused_activation": [
+ "RELU",
+ "RELU6",
+ "RELU_N1_TO_1",
+ "SIGMOID",
+ "TANH",
+ "NONE",
+ ]
+ },
+ "MEAN": {},
+ "MINIMUM": {},
+ "MIRROR_PAD": {},
+ "MUL": {},
+ "NEG": {},
+ "NOT_EQUAL": {},
+ "PACK": {},
+ "PAD": {},
+ "PADV2": {},
+ "PRELU": {},
+ "QUANTIZE": {},
+ "RANK": {},
+ "REDUCE_MAX": {},
+ "REDUCE_MIN": {},
+ "REDUCE_PROD": {},
+ "RELU": {},
+ "RELU6": {},
+ "RELU_N1_TO_1": {},
+ "RESHAPE": {},
+ "RESIZE_BILINEAR": {},
+ "RESIZE_NEAREST_NEIGHBOR": {},
+ "RSQRT": {},
+ "SHAPE": {},
+ "SIN": {},
+ "SOFTMAX": {},
+ "SPACE_TO_BATCH_ND": {},
+ "SPACE_TO_DEPTH": {},
+ "SPLIT": {},
+ "SPLIT_V": {},
+ "SQRT": {},
+ "SQUEEZE": {},
+ "STRIDED_SLICE": {},
+ "SUB": {},
+ "SUM": {},
+ "TANH": {},
+ "TRANSPOSE": {},
+ "TRANSPOSE_CONV": {},
+ "UNIDIRECTIONAL_SEQUENCE_LSTM": {},
+ "UNPACK": {},
+ },
+ # CUSTOM OPERATORS
+ "custom_ops": {
+ "AveragePool3D": {
+ "supported_fused_activation": [
+ "RELU",
+ "RELU6",
+ "RELU_N1_TO_1",
+ "SIGMOID",
+ "SIGN_BIT",
+ "TANH",
+ "NONE",
+ ]
+ },
+ "MaxPool3D": {
+ "supported_fused_activation": [
+ "RELU",
+ "RELU6",
+ "RELU_N1_TO_1",
+ "SIGMOID",
+ "SIGN_BIT",
+ "TANH",
+ "NONE",
+ ]
+ },
+ },
+}
diff --git a/src/mlia/devices/cortexa/operators.py b/src/mlia/devices/cortexa/operators.py
index d46b107..3e84d64 100644
--- a/src/mlia/devices/cortexa/operators.py
+++ b/src/mlia/devices/cortexa/operators.py
@@ -4,16 +4,113 @@
from __future__ import annotations
from dataclasses import dataclass
+from enum import Enum
from pathlib import Path
+from typing import Any
+from typing import ClassVar
+
+from mlia.devices.cortexa.operator_compatibility import (
+ ARMNN_TFLITE_DELEGATE as TFLITE_DELEGATE_COMPAT,
+)
+from mlia.nn.tensorflow.tflite_graph import Op
+from mlia.nn.tensorflow.tflite_graph import parse_subgraphs
+from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION
@dataclass
class Operator:
"""Cortex-A compatibility information of the operator."""
+ BUILTIN_COMPATIBILITY = TFLITE_DELEGATE_COMPAT["builtin_ops"]
+ CUSTOM_COMPATIBILITY = TFLITE_DELEGATE_COMPAT["custom_ops"]
+
+ class SupportType(Enum):
+ """Type of operator support."""
+
+ COMPATIBLE = "Compatible"
+ OP_NOT_SUPPORTED = "Operator not supported"
+ ACTIVATION_NOT_SUPPORTED = "Activation not supported"
+
name: str
location: str
- is_cortex_a_compatible: bool
+ support_type: SupportType
+ activation_func: TFL_ACTIVATION_FUNCTION
+ custom_name: str | None = None
+
+ @property
+ def is_cortex_a_compatible(self) -> bool:
+ """Check if this operator is compatible."""
+ return self.support_type == Operator.SupportType.COMPATIBLE
+
+ @property
+ def full_name(self) -> str:
+ """Returun the full name including the custom name if applicable."""
+ return self.name + (f" - '{self.custom_name}'" if self.custom_name else "")
+
+ @property
+ def is_custom(self) -> bool:
+ """Check if this is a custom operator."""
+ return bool(self.custom_name)
+
+ @property
+ def compatibility_data(self) -> dict[str, dict[str, Any]]:
+ """Get the compatibility data (builtin or custom ops)."""
+ return (
+ Operator.CUSTOM_COMPATIBILITY
+ if self.is_custom
+ else Operator.BUILTIN_COMPATIBILITY
+ )
+
+ @property
+ def supported_activation_functions(self) -> list[str]:
+ """Return a list of fused activation functions supported by this op."""
+ op_name = self.custom_name if self.custom_name else self.name
+ return self.compatibility_data[op_name].get("supported_fused_activation", [])
+
+ @classmethod
+ def from_tflite_op(cls, tfl_op: Op, location: str) -> Operator:
+ """Create a new instance from TensorFlow Lite operator and location."""
+ support_type = cls._get_support_type(tfl_op)
+ activation_func = (
+ tfl_op.builtin_options["fused_activation_function"]
+ if (
+ tfl_op.builtin_options
+ and "fused_activation_function" in tfl_op.builtin_options
+ )
+ else TFL_ACTIVATION_FUNCTION.NONE
+ )
+ return Operator(
+ tfl_op.type,
+ location,
+ support_type,
+ activation_func=activation_func,
+ custom_name=(tfl_op.custom_type if tfl_op.is_custom else None),
+ )
+
+ @staticmethod
+ def _get_support_type(tfl_op: Op) -> Operator.SupportType:
+ """Get the support type from the TensorFlow Lite operator."""
+ compat_data = (
+ Operator.CUSTOM_COMPATIBILITY
+ if tfl_op.is_custom
+ else Operator.BUILTIN_COMPATIBILITY
+ )
+ op_type = tfl_op.custom_type if tfl_op.is_custom else tfl_op.type
+
+ if op_type not in compat_data:
+ return Operator.SupportType.OP_NOT_SUPPORTED
+
+ compat_op = compat_data[op_type]
+ if "supported_fused_activation" in compat_op:
+ assert tfl_op.builtin_options
+ assert "fused_activation_function" in tfl_op.builtin_options
+ if (
+ tfl_op.builtin_options["fused_activation_function"]
+ not in compat_op["supported_fused_activation"]
+ ):
+ return Operator.SupportType.ACTIVATION_NOT_SUPPORTED
+
+ return Operator.SupportType.COMPATIBLE
@dataclass
@@ -21,14 +118,26 @@ class CortexACompatibilityInfo:
"""Model's operators."""
cortex_a_compatible: bool
- operators: list[Operator] | None = None
+ operators: list[Operator]
+ backend_info: ClassVar[str] = (
+ f"{TFLITE_DELEGATE_COMPAT['metadata']['backend']} "
+ f"{TFLITE_DELEGATE_COMPAT['metadata']['version']}"
+ )
-def get_cortex_a_compatibility_info(
- _model_path: Path,
-) -> CortexACompatibilityInfo | None:
+def get_cortex_a_compatibility_info(model_path: Path) -> CortexACompatibilityInfo:
"""Return list of model's operators."""
- return None
+ model = parse_subgraphs(model_path)
+
+ op_list = [
+ Operator.from_tflite_op(oper, f"subgraph:{g_idx},oper:{op_idx}")
+ for g_idx, g in enumerate(model)
+ for op_idx, oper in enumerate(g)
+ ]
+ all_compatible = all(oper.is_cortex_a_compatible for oper in op_list)
+ compat_info = CortexACompatibilityInfo(all_compatible, op_list)
+
+ return compat_info
def report() -> None:
diff --git a/src/mlia/devices/cortexa/reporters.py b/src/mlia/devices/cortexa/reporters.py
index a55caba..84de10b 100644
--- a/src/mlia/devices/cortexa/reporters.py
+++ b/src/mlia/devices/cortexa/reporters.py
@@ -5,6 +5,7 @@ from __future__ import annotations
from typing import Any
from typing import Callable
+from typing import cast
from mlia.core.advice_generation import Advice
from mlia.core.reporters import report_advice
@@ -96,21 +97,22 @@ def report_cortex_a_operators(ops: list[Operator]) -> Report:
),
Column("Operator name", alias="operator_name", fmt=Format(wrap_width=20)),
Column(
- "Cortex-A compatibility",
+ "Arm NN TFLite Delegate compatibility",
alias="cortex_a_compatible",
- fmt=Format(wrap_width=25),
+ fmt=Format(wrap_width=40),
),
],
[
(
index + 1,
op.location,
- op.name,
+ op.full_name,
Cell(
- op.is_cortex_a_compatible,
+ op.support_type,
Format(
+ wrap_width=30,
style=style_improvement(op.is_cortex_a_compatible),
- str_fmt=lambda v: "Compatible" if v else "Not compatible",
+ str_fmt=lambda v: cast(str, v.value),
),
),
)
diff --git a/src/mlia/nn/tensorflow/tflite_graph.py b/src/mlia/nn/tensorflow/tflite_graph.py
new file mode 100644
index 0000000..4f5e85f
--- /dev/null
+++ b/src/mlia/nn/tensorflow/tflite_graph.py
@@ -0,0 +1,139 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Utilities for TensorFlow Lite graphs."""
+from __future__ import annotations
+
+import enum
+import json
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+from typing import cast
+
+from tensorflow.lite.python import schema_py_generated as schema_fb
+from tensorflow.lite.tools import visualize
+
+
+def _enum_from_class(cls: Any) -> Any:
+ """Create an enum from the public class variables."""
+ return enum.Enum(
+ cls.__name__,
+ {key: value for key, value in vars(cls).items() if not key.startswith("_")},
+ )
+
+
+TFL_TYPE = _enum_from_class(schema_fb.TensorType)
+TFL_OP = _enum_from_class(schema_fb.BuiltinOperator)
+TFL_ACTIVATION_FUNCTION = _enum_from_class(schema_fb.ActivationFunctionType)
+
+
+def _ascii_list_to_string(ascii_list: list[int]) -> str:
+ return "".join(chr(i) for i in ascii_list)
+
+
+@dataclass
+class TensorInfo:
+ """Collection of tensor information parsed from a TensorFlow Lite file."""
+
+ name: str
+ type: str
+ shape: tuple | list
+ is_variable: bool
+
+ def __str__(self) -> str:
+ """Create a text represenation of this TensorInfo instance."""
+ return f"{self.name}: {self.type}, {self.shape}, is_variable={self.is_variable}"
+
+ def __repr__(self) -> str:
+ """Convert this instance to JSON."""
+ return json.dumps(vars(self))
+
+ @classmethod
+ def from_dict(cls, tensor: dict[str, Any]) -> TensorInfo:
+ """
+ Create a new instance from a dictionary.
+
+ The expected dict is the one contained in the dict returned by
+ visualize.CreateDictFromFlatbuffer().
+ """
+ return TensorInfo(
+ _ascii_list_to_string(tensor["name"]),
+ TFL_TYPE(tensor["type"]).name,
+ tensor["shape"],
+ tensor["is_variable"],
+ )
+
+
+@dataclass
+class Op:
+ """
+ Representation of an operator from a TensorFlow Lite file.
+
+ E.g. collects the operator type, input/output tensors etc.
+ """
+
+ type: str
+ builtin_options: dict
+ inputs: list[TensorInfo]
+ outputs: list[TensorInfo]
+ custom_type: str | None = None
+
+ def __post_init__(self) -> None:
+ """Convert the builtin option 'fused_activation_function' to string."""
+ if "fused_activation_function" in self.builtin_options:
+ # Convert the fused activation function ID to a string
+ self.builtin_options["fused_activation_function"] = TFL_ACTIVATION_FUNCTION(
+ self.builtin_options["fused_activation_function"]
+ ).name
+
+ def __str__(self) -> str:
+ """Create a text represenation of this Op instance."""
+ return f"""{self.type}
+ builtin_options: {self.builtin_options}
+ inputs: {self.inputs}
+ outputs: {self.outputs}"""
+
+ @property
+ def is_custom(self) -> bool:
+ """Check if this Op is a custom operator."""
+ return self.type == cast(str, TFL_OP.CUSTOM.name)
+
+ @classmethod
+ def from_model_info(cls, oper: dict, graph: dict, model: dict) -> Op:
+ """Create a new Op from the model information."""
+ op_code_idx = oper["opcode_index"]
+ op_code_obj = model["operator_codes"][op_code_idx]
+ op_code = max(
+ op_code_obj["builtin_code"], op_code_obj["deprecated_builtin_code"]
+ )
+ custom_code = op_code_obj.get("custom_code")
+ return cls(
+ type=cast(str, TFL_OP(op_code).name),
+ builtin_options=oper["builtin_options"] if oper["builtin_options"] else {},
+ inputs=[
+ TensorInfo.from_dict(graph["tensors"][idx]) for idx in oper["inputs"]
+ ],
+ outputs=[
+ TensorInfo.from_dict(graph["tensors"][idx]) for idx in oper["outputs"]
+ ],
+ custom_type=_ascii_list_to_string(custom_code) if custom_code else None,
+ )
+
+
+def load_tflite(file: Path) -> bytes:
+ """Load a TensorFlow Lite file from disk."""
+ return file.read_bytes()
+
+
+def parse_subgraphs(tflite_file: Path) -> list[list[Op]]:
+ """Load the TensorFlow Lite file and parse the subgraphs."""
+ tflite_model = load_tflite(tflite_file)
+ model = cast(dict, visualize.CreateDictFromFlatbuffer(tflite_model))
+ assert isinstance(model, dict)
+
+ graphs = [
+ [Op.from_model_info(oper, g, model) for oper in g["operators"]]
+ for g in model["subgraphs"]
+ ]
+
+ return graphs