about summary refs log tree commit diff
path: root/src/mlia/nn/rewrite/core/graph_edit/join.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit/join.py')
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/join.py  111
1 files changed, 72 insertions, 39 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/join.py b/src/mlia/nn/rewrite/core/graph_edit/join.py
index 758f4cf..14a7347 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/join.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/join.py
@@ -1,16 +1,31 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Join module."""
+from __future__ import annotations
+
import os
+from pathlib import Path
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
+from tensorflow.lite.python.schema_py_generated import ModelT
+from tensorflow.lite.python.schema_py_generated import OperatorCodeT
+from tensorflow.lite.python.schema_py_generated import SubGraphT
-tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+from mlia.nn.rewrite.core.utils.utils import load
+from mlia.nn.rewrite.core.utils.utils import save
-from mlia.nn.rewrite.core.utils.utils import load, save
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-def join_models(input_src, input_dst, output_file, subgraph_src=0, subgraph_dst=0):
+def join_models(
+ input_src: str | Path,
+ input_dst: str | Path,
+ output_file: str | Path,
+ subgraph_src: SubGraphT = 0,
+ subgraph_dst: SubGraphT = 0,
+) -> None:
+ """Join two models and save the result into a given model file path."""
src_model = load(input_src)
dst_model = load(input_dst)
src_subgraph = src_model.subgraphs[subgraph_src]
@@ -19,8 +34,13 @@ def join_models(input_src, input_dst, output_file, subgraph_src=0, subgraph_dst=
save(dst_model, output_file)
-def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph):
- """Copy subgraph src into subgraph dst from model, connecting tensors with the same names"""
+def join_subgraphs(
+ src_model: ModelT,
+ src_subgraph: SubGraphT,
+ dst_model: ModelT,
+ dst_subgraph: SubGraphT,
+) -> None:
+ """Join two subgraphs, connecting tensors with the same names."""
# Find inputs that match outputs in the other graph and vice versa
dst_to_src = {
i: o
@@ -36,10 +56,11 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph):
if dst_subgraph.tensors[i].name == src_subgraph.tensors[o].name
}
- assert not (src_to_dst and dst_to_src), (
- "Source and destination subgraphs appear to connect in a loop: %d tensors from src to dst, %d tensors from dst to src"
- % (len(src_to_dst), len(dst_to_src))
- )
+ assert not (
+ src_to_dst and dst_to_src
+ ), f"Source and destination subgraphs appear to connect in a loop: \
+ {len(src_to_dst)} tensors from src to dst, {len(dst_to_src)} \
+ tensors from dst to src"
# Relabel matched input/output tensors between graphs
tensor_relabel = src_to_dst if src_to_dst else dst_to_src
@@ -47,36 +68,46 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph):
# Remove matched inputs/outputs as these will now become internal tensors
if src_to_dst:
src_subgraph.outputs = [
- o for o in src_subgraph.outputs if not o in tensor_relabel.keys()
+ output
+ for output in src_subgraph.outputs
+ if output not in tensor_relabel.keys()
]
dst_subgraph.inputs = [
- i for i in dst_subgraph.inputs if not i in tensor_relabel.values()
+ input
+ for input in dst_subgraph.inputs
+ if input not in tensor_relabel.values()
]
else:
src_subgraph.inputs = [
- i for i in src_subgraph.inputs if not i in tensor_relabel.keys()
+ input for input in src_subgraph.inputs if input not in tensor_relabel.keys()
]
dst_subgraph.outputs = [
- o for o in dst_subgraph.outputs if not o in tensor_relabel.values()
+ output
+ for output in dst_subgraph.outputs
+ if output not in tensor_relabel.values()
]
buffer_relabel = {
- src_subgraph.tensors[i].buffer: dst_subgraph.tensors[o].buffer
- for i, o in tensor_relabel.items()
+ src_subgraph.tensors[input].buffer: dst_subgraph.tensors[output].buffer
+ for input, output in tensor_relabel.items()
}
used_tensors = [
- t for i, t in enumerate(src_subgraph.tensors) if not i in tensor_relabel
+ tensor
+ for i, tensor in enumerate(src_subgraph.tensors)
+ if i not in tensor_relabel
]
- used_buffer_ids = [t.buffer for t in used_tensors]
+ used_buffer_ids = [tensor.buffer for tensor in used_tensors]
+
+ def opcode_data(code: OperatorCodeT) -> tuple:
+ return (
+ code.builtinCode,
+ code.deprecatedBuiltinCode,
+ code.customCode,
+ code.version,
+ )
- opcode_data = lambda c: (
- c.builtinCode,
- c.deprecatedBuiltinCode,
- c.customCode,
- c.version,
- )
opcode_relabel = {
s: d
for s in range(len(src_model.operatorCodes))
@@ -85,7 +116,8 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph):
== opcode_data(dst_model.operatorCodes[d])
}
- # operator order defines execution schedule so must reflect the inputs/outputs dependencies
+ # operator order defines execution schedule so must reflect
+ # the inputs/outputs dependencies
if dst_to_src:
dst_subgraph.operators += src_subgraph.operators
else:
@@ -99,17 +131,17 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph):
] = -1 # Some files have ops with -1 input tensors; leave unchanged
for i in used_buffer_ids:
- if not i in buffer_relabel:
+ if i not in buffer_relabel:
buffer_relabel[i] = len(dst_model.buffers)
dst_model.buffers.append(src_model.buffers[i])
- for o in src_subgraph.operators:
- o.inputs = [tensor_relabel[t] for t in o.inputs]
- o.outputs = [tensor_relabel[t] for t in o.outputs]
- o.opcodeIndex = opcode_relabel[o.opcodeIndex]
+ for operator in src_subgraph.operators:
+ operator.inputs = [tensor_relabel[tensor] for tensor in operator.inputs]
+ operator.outputs = [tensor_relabel[tensor] for tensor in operator.outputs]
+ operator.opcodeIndex = opcode_relabel[operator.opcodeIndex]
- for t in used_tensors:
- t.buffer = buffer_relabel[t.buffer]
+ for tensor in used_tensors:
+ tensor.buffer = buffer_relabel[tensor.buffer]
src_subgraph.inputs = [tensor_relabel[t] for t in src_subgraph.inputs]
src_subgraph.outputs = [tensor_relabel[t] for t in src_subgraph.outputs]
@@ -118,11 +150,12 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph):
dst_subgraph.outputs = list(set(src_subgraph.outputs).union(dst_subgraph.outputs))
-def append_relabel(src, dst, map=None):
- if map is None:
- map = {}
- for i, x in enumerate(src):
- if not i in map:
- map[i] = len(dst)
+def append_relabel(src: list, dst: list, operator_map: dict | None = None) -> dict:
+ """Return a map over relabeled tensors in a subgraph."""
+ if not operator_map:
+ operator_map = {}
+ for i, x in enumerate(src): # pylint: disable=invalid-name
+ if i not in operator_map:
+ operator_map[i] = len(dst)
dst.append(x)
- return map
+ return operator_map