diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit/cut.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/cut.py | 139 |
1 files changed, 77 insertions, 62 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/cut.py b/src/mlia/nn/rewrite/core/graph_edit/cut.py index a323b7b..2707eb1 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/cut.py +++ b/src/mlia/nn/rewrite/core/graph_edit/cut.py @@ -1,128 +1,143 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 +"""Cut module.""" import os from collections import defaultdict +from typing import Optional -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" import tensorflow as tf +from tensorflow.lite.python.schema_py_generated import ModelT +from tensorflow.lite.python.schema_py_generated import SubGraphT -tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +from mlia.nn.rewrite.core.utils.utils import load +from mlia.nn.rewrite.core.utils.utils import save -from mlia.nn.rewrite.core.utils.utils import load, save +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) -def cut_subgraph(subgraph, input_tensor_names, output_tensor_names): - """Change the global inputs and outputs of a graph to the provided named tensors""" +def tensors_by_name(subgraph: SubGraphT, names: list) -> list: + """Seek out tensors from a subgraph and return the result.""" + seek = frozenset([name.encode("utf-8") for name in names]) + tensors = [i for i, tensor in enumerate(subgraph.tensors) if tensor.name in seek] + return tensors - def tensors_by_name(names): - seek = frozenset([name.encode("utf-8") for name in names]) - tensors = [ - i for i, tensor in enumerate(subgraph.tensors) if tensor.name in seek - ] - return tensors +def cut_subgraph( + subgraph: SubGraphT, + input_tensor_names: Optional[list], + output_tensor_names: Optional[list], +) -> None: + """Change the global inputs and outputs of a graph to the provided named tensors.""" if input_tensor_names is not None: - subgraph.inputs = tensors_by_name(input_tensor_names) + subgraph.inputs = tensors_by_name(subgraph, input_tensor_names) assert len(subgraph.inputs) == len( input_tensor_names - ), "Expected %d input tensors: %s\nFound: %s" % ( - len(subgraph.inputs), - ", ".join(input_tensor_names), - ", ".join(subgraph.tensors[i].name for i in subgraph.inputs), - ) - + ), f"Expected {len(subgraph.inputs)} input tensors: \ + {', '.join(input_tensor_names)}\nFound: \ + {', '.join(subgraph.tensors[i].name for i in subgraph.inputs)}" if output_tensor_names is not None: - subgraph.outputs = tensors_by_name(output_tensor_names) + subgraph.outputs = tensors_by_name(subgraph, output_tensor_names) assert len(subgraph.outputs) == len( output_tensor_names - ), "Expected %d output tensors: %s\nFound: %s" % ( - len(subgraph.outputs), - ", ".join(output_tensor_names), - ", ".join(subgraph.tensors[i].name for i in subgraph.outputs), - ) + ), f"Expected {len(subgraph.outputs)} output tensors: \ + {', '.join(output_tensor_names)}\nFound: \ + {', '.join(subgraph.tensors[i].name for i in subgraph.outputs)}" -def simplify(model): - """Remove any unused operators, tensors and buffers from a model""" - for s in model.subgraphs: - simplify_subgraph(s) +def simplify(model: ModelT) -> None: + """Remove any unused operators, tensors and buffers from a model.""" + for subgraph in model.subgraphs: + simplify_subgraph(subgraph) - used_buffers = {t.buffer for t in s.tensors for s in model.subgraphs} - used_buffers = used_buffers.union({m.buffer for m in model.metadata}) + used_buffers = { + tensor.buffer for tensor in subgraph.tensors for subgraph in model.subgraphs + } + used_buffers = used_buffers.union({metadata.buffer for metadata in model.metadata}) used_buffers.add( 0 - ) # Buffer zero is always expected to be a zero-sized nullptr buffer by the TFLite runtime + ) # Buffer zero is always expected to be a zero-sized nullptr buffer by the + # TFLite runtime model.buffers, buf_relabel = filter_relabel(model.buffers, used_buffers) - for s in model.subgraphs: - for t in s.tensors: - t.buffer = buf_relabel[t.buffer] + for subgraph in model.subgraphs: + for tensor in subgraph.tensors: + tensor.buffer = buf_relabel[tensor.buffer] - for m in model.metadata: - m.buffer = buf_relabel[m.buffer] + for metadata in model.metadata: + metadata.buffer = buf_relabel[metadata.buffer] -def simplify_subgraph(subgraph): +def simplify_subgraph(subgraph: SubGraphT) -> None: + """Simplify a subgraph given its operators.""" requires = defaultdict(set) - for o, operator in enumerate(subgraph.operators): - for t in operator.outputs: - if not t in subgraph.inputs: - requires[t].add(o) + for output, operator in enumerate(subgraph.operators): + for tensor in operator.outputs: + if tensor not in subgraph.inputs: + requires[tensor].add(output) op_set, ten_set = find_required(subgraph, requires, subgraph.outputs) - subgraph.operators, op_relabel = filter_relabel(subgraph.operators, op_set) + subgraph.operators, _ = filter_relabel(subgraph.operators, op_set) subgraph.tensors, ten_relabel = filter_relabel(subgraph.tensors, ten_set) ten_relabel[-1] = -1 # Some files have ops with -1 input tensors; leave unchanged - for op in subgraph.operators: - op.inputs = [ten_relabel[t] for t in op.inputs] - op.outputs = [ten_relabel[t] for t in op.outputs] + for operator in subgraph.operators: + operator.inputs = [ten_relabel[tensor] for tensor in operator.inputs] + operator.outputs = [ten_relabel[tensor] for tensor in operator.outputs] - subgraph.inputs = [ten_relabel[t] for t in subgraph.inputs] - subgraph.outputs = [ten_relabel[t] for t in subgraph.outputs] + subgraph.inputs = [ten_relabel[tensor] for tensor in subgraph.inputs] + subgraph.outputs = [ten_relabel[tensors] for tensors in subgraph.outputs] -def find_required(subgraph, requires, tensors): - visited_operators = set() +def find_required(subgraph: SubGraphT, requires: dict, tensors: dict) -> tuple: + """Find required operators in a given subgraph.""" + visited_operators: set = set() visited_tensors = set(tensors) stop_tensors = set(subgraph.inputs) - changed = True next_tensors = visited_tensors while next_tensors: loop_tensors = next_tensors next_tensors = set() - for t in loop_tensors: - candidate_operators = set(requires[t]) + for tensor in loop_tensors: + candidate_operators = set(requires[tensor]) new_operators = candidate_operators - visited_operators visited_operators = visited_operators.union(new_operators) - for op in new_operators: - candidate_tensors = set(subgraph.operators[op].inputs) + for operator in new_operators: + candidate_tensors = set(subgraph.operators[operator].inputs) new_tensors = candidate_tensors - (visited_tensors.union(next_tensors)) next_tensors = next_tensors.union(new_tensors) visited_tensors = visited_tensors.union(candidate_tensors) visited_tensors = visited_tensors.union( - subgraph.operators[op].outputs + subgraph.operators[operator].outputs ) # include stub outputs but do not traverse them next_tensors = next_tensors - stop_tensors return visited_operators, visited_tensors -def filter_relabel(src, filter): - relabel = {} - output = [] - for i, x in enumerate(src): - if i in filter: +def filter_relabel(src_subgraph: SubGraphT, relabel_filter: set) -> tuple: + """Relabel tensors in a subgraph based on a filter.""" + relabel: dict = {} + output: list = [] + for i, out in enumerate(src_subgraph): + if i in relabel_filter: relabel[i] = len(output) - output.append(x) + output.append(out) return output, relabel -def cut_model(model_file, input_names, output_names, subgraph_index, output_file): +def cut_model( + model_file: str, + input_names: Optional[list], + output_names: Optional[list], + subgraph_index: int, + output_file: str, +) -> None: + """Cut subgraphs and simplify a given model.""" model = load(model_file) subgraph = model.subgraphs[subgraph_index] cut_subgraph(subgraph, input_names, output_names) |