diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit/join.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/join.py | 111 |
1 files changed, 72 insertions, 39 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/join.py b/src/mlia/nn/rewrite/core/graph_edit/join.py index 758f4cf..14a7347 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/join.py +++ b/src/mlia/nn/rewrite/core/graph_edit/join.py @@ -1,16 +1,31 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 +"""Join module.""" +from __future__ import annotations + import os +from pathlib import Path -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" import tensorflow as tf +from tensorflow.lite.python.schema_py_generated import ModelT +from tensorflow.lite.python.schema_py_generated import OperatorCodeT +from tensorflow.lite.python.schema_py_generated import SubGraphT -tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +from mlia.nn.rewrite.core.utils.utils import load +from mlia.nn.rewrite.core.utils.utils import save -from mlia.nn.rewrite.core.utils.utils import load, save +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) -def join_models(input_src, input_dst, output_file, subgraph_src=0, subgraph_dst=0): +def join_models( + input_src: str | Path, + input_dst: str | Path, + output_file: str | Path, + subgraph_src: SubGraphT = 0, + subgraph_dst: SubGraphT = 0, +) -> None: + """Join two models and save the result into a given model file path.""" src_model = load(input_src) dst_model = load(input_dst) src_subgraph = src_model.subgraphs[subgraph_src] @@ -19,8 +34,13 @@ def join_models(input_src, input_dst, output_file, subgraph_src=0, subgraph_dst= save(dst_model, output_file) -def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph): - """Copy subgraph src into subgraph dst from model, connecting tensors with the same names""" +def join_subgraphs( + src_model: ModelT, + src_subgraph: SubGraphT, + dst_model: ModelT, + dst_subgraph: SubGraphT, +) -> None: + """Join two subgraphs, connecting tensors with the same names.""" # Find inputs that match outputs in the other graph and vice versa dst_to_src = { i: o @@ -36,10 +56,11 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph): if dst_subgraph.tensors[i].name == src_subgraph.tensors[o].name } - assert not (src_to_dst and dst_to_src), ( - "Source and destination subgraphs appear to connect in a loop: %d tensors from src to dst, %d tensors from dst to src" - % (len(src_to_dst), len(dst_to_src)) - ) + assert not ( + src_to_dst and dst_to_src + ), f"Source and destination subgraphs appear to connect in a loop: \ + {len(src_to_dst)} tensors from src to dst, {len(dst_to_src)} \ + tensors from dst to src" # Relabel matched input/output tensors between graphs tensor_relabel = src_to_dst if src_to_dst else dst_to_src @@ -47,36 +68,46 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph): # Remove matched inputs/outputs as these will now become internal tensors if src_to_dst: src_subgraph.outputs = [ - o for o in src_subgraph.outputs if not o in tensor_relabel.keys() + output + for output in src_subgraph.outputs + if output not in tensor_relabel.keys() ] dst_subgraph.inputs = [ - i for i in dst_subgraph.inputs if not i in tensor_relabel.values() + input + for input in dst_subgraph.inputs + if input not in tensor_relabel.values() ] else: src_subgraph.inputs = [ - i for i in src_subgraph.inputs if not i in tensor_relabel.keys() + input for input in src_subgraph.inputs if input not in tensor_relabel.keys() ] dst_subgraph.outputs = [ - o for o in dst_subgraph.outputs if not o in tensor_relabel.values() + output + for output in dst_subgraph.outputs + if output not in tensor_relabel.values() ] buffer_relabel = { - src_subgraph.tensors[i].buffer: dst_subgraph.tensors[o].buffer - for i, o in tensor_relabel.items() + src_subgraph.tensors[input].buffer: dst_subgraph.tensors[output].buffer + for input, output in tensor_relabel.items() } used_tensors = [ - t for i, t in enumerate(src_subgraph.tensors) if not i in tensor_relabel + tensor + for i, tensor in enumerate(src_subgraph.tensors) + if i not in tensor_relabel ] - used_buffer_ids = [t.buffer for t in used_tensors] + used_buffer_ids = [tensor.buffer for tensor in used_tensors] + + def opcode_data(code: OperatorCodeT) -> tuple: + return ( + code.builtinCode, + code.deprecatedBuiltinCode, + code.customCode, + code.version, + ) - opcode_data = lambda c: ( - c.builtinCode, - c.deprecatedBuiltinCode, - c.customCode, - c.version, - ) opcode_relabel = { s: d for s in range(len(src_model.operatorCodes)) @@ -85,7 +116,8 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph): == opcode_data(dst_model.operatorCodes[d]) } - # operator order defines execution schedule so must reflect the inputs/outputs dependencies + # operator order defines execution schedule so must reflect + # the inputs/outputs dependencies if dst_to_src: dst_subgraph.operators += src_subgraph.operators else: @@ -99,17 +131,17 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph): ] = -1 # Some files have ops with -1 input tensors; leave unchanged for i in used_buffer_ids: - if not i in buffer_relabel: + if i not in buffer_relabel: buffer_relabel[i] = len(dst_model.buffers) dst_model.buffers.append(src_model.buffers[i]) - for o in src_subgraph.operators: - o.inputs = [tensor_relabel[t] for t in o.inputs] - o.outputs = [tensor_relabel[t] for t in o.outputs] - o.opcodeIndex = opcode_relabel[o.opcodeIndex] + for operator in src_subgraph.operators: + operator.inputs = [tensor_relabel[tensor] for tensor in operator.inputs] + operator.outputs = [tensor_relabel[tensor] for tensor in operator.outputs] + operator.opcodeIndex = opcode_relabel[operator.opcodeIndex] - for t in used_tensors: - t.buffer = buffer_relabel[t.buffer] + for tensor in used_tensors: + tensor.buffer = buffer_relabel[tensor.buffer] src_subgraph.inputs = [tensor_relabel[t] for t in src_subgraph.inputs] src_subgraph.outputs = [tensor_relabel[t] for t in src_subgraph.outputs] @@ -118,11 +150,12 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph): dst_subgraph.outputs = list(set(src_subgraph.outputs).union(dst_subgraph.outputs)) -def append_relabel(src, dst, map=None): - if map is None: - map = {} - for i, x in enumerate(src): - if not i in map: - map[i] = len(dst) +def append_relabel(src: list, dst: list, operator_map: dict | None = None) -> dict: + """Return a map over relabeled tensors in a subgraph.""" + if not operator_map: + operator_map = {} + for i, x in enumerate(src): # pylint: disable=invalid-name + if i not in operator_map: + operator_map[i] = len(dst) dst.append(x) - return map + return operator_map |