diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit/cut.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/cut.py | 130 |
1 files changed, 130 insertions, 0 deletions
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Cut a TFLite model down to a chosen set of input and output tensors.

The functions here operate on an in-memory model object exposing
``subgraphs``, ``buffers`` and ``metadata``; each subgraph exposes
``tensors``, ``operators``, ``inputs`` and ``outputs``.
"""
import os
from collections import defaultdict

# The log level is set before TensorFlow is imported so that it is already in
# place when the library initialises; do not reorder these statements.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

from mlia.nn.rewrite.core.utils.utils import load, save


def _display_name(name):
    """Return a tensor name as text, whether it is stored as bytes or str."""
    if isinstance(name, bytes):
        return name.decode("utf-8", errors="replace")
    return str(name)


def cut_subgraph(subgraph, input_tensor_names, output_tensor_names):
    """Change the global inputs and outputs of a graph to the provided named tensors.

    Args:
        subgraph: Subgraph to modify in place.
        input_tensor_names: Names (str) of the tensors that become the graph
            inputs, or None to leave the inputs untouched.
        output_tensor_names: Names (str) of the tensors that become the graph
            outputs, or None to leave the outputs untouched.

    Raises:
        AssertionError: If the number of tensors found does not match the
            number of names requested. The subgraph is not modified for the
            list that failed validation.

    Note:
        The resulting index lists follow the order of ``subgraph.tensors``,
        not the order in which the names were given.
    """

    def tensors_by_name(names):
        # Tensor names are compared against the UTF-8 encoded form of the
        # requested names.
        seek = frozenset(name.encode("utf-8") for name in names)
        return [i for i, tensor in enumerate(subgraph.tensors) if tensor.name in seek]

    def select(names, kind):
        found = tensors_by_name(names)
        if len(found) != len(names):
            # Raised explicitly (rather than via ``assert``) so the check is
            # not stripped when Python runs with -O; the exception type is the
            # one an ``assert`` would have produced.
            raise AssertionError(
                "Expected %d %s tensors: %s\nFound: %s"
                % (
                    len(names),
                    kind,
                    ", ".join(names),
                    ", ".join(_display_name(subgraph.tensors[i].name) for i in found),
                )
            )
        return found

    if input_tensor_names is not None:
        subgraph.inputs = select(input_tensor_names, "input")

    if output_tensor_names is not None:
        subgraph.outputs = select(output_tensor_names, "output")


def simplify(model):
    """Remove any unused operators, tensors and buffers from a model.

    Every subgraph is simplified first, then buffers that are no longer
    referenced by any tensor or metadata entry are dropped and the remaining
    references are renumbered. The model is modified in place.
    """
    for subgraph in model.subgraphs:
        simplify_subgraph(subgraph)

    # NOTE(review): metadata is treated as optional here (None or empty);
    # confirm against the model loader.
    metadata = model.metadata or []

    # Collect buffers from *every* subgraph: the outer loop must be over the
    # subgraphs, the inner one over that subgraph's tensors.
    used_buffers = {
        tensor.buffer for subgraph in model.subgraphs for tensor in subgraph.tensors
    }
    used_buffers.update(entry.buffer for entry in metadata)
    # Buffer zero is always expected to be a zero-sized nullptr buffer by the
    # TFLite runtime, so it is kept (and stays at index zero, since relabelling
    # preserves order).
    used_buffers.add(0)

    model.buffers, buf_relabel = filter_relabel(model.buffers, used_buffers)

    for subgraph in model.subgraphs:
        for tensor in subgraph.tensors:
            tensor.buffer = buf_relabel[tensor.buffer]

    for entry in metadata:
        entry.buffer = buf_relabel[entry.buffer]


def simplify_subgraph(subgraph):
    """Drop operators and tensors not needed to compute the subgraph outputs.

    The subgraph is modified in place: ``operators`` and ``tensors`` are
    filtered, and every tensor index held by the remaining operators and by
    the subgraph's ``inputs``/``outputs`` is renumbered to match.
    """
    graph_inputs = set(subgraph.inputs)

    # Map each tensor to the operators producing it. Tensors declared as graph
    # inputs are excluded so the traversal never walks past them.
    requires = defaultdict(set)
    for op_index, operator in enumerate(subgraph.operators):
        for tensor in operator.outputs:
            if tensor not in graph_inputs:
                requires[tensor].add(op_index)

    op_set, ten_set = find_required(subgraph, requires, subgraph.outputs)

    # Declared inputs are always retained, even when no required operator
    # consumes them; otherwise relabelling ``subgraph.inputs`` below would fail
    # with a KeyError.
    ten_set = set(ten_set) | graph_inputs

    subgraph.operators, _ = filter_relabel(subgraph.operators, op_set)
    subgraph.tensors, ten_relabel = filter_relabel(subgraph.tensors, ten_set)

    ten_relabel[-1] = -1  # Some files have ops with -1 input tensors; leave unchanged

    for operator in subgraph.operators:
        operator.inputs = [ten_relabel[t] for t in operator.inputs]
        operator.outputs = [ten_relabel[t] for t in operator.outputs]

    subgraph.inputs = [ten_relabel[t] for t in subgraph.inputs]
    subgraph.outputs = [ten_relabel[t] for t in subgraph.outputs]


def find_required(subgraph, requires, tensors):
    """Find the operators and tensors needed to produce ``tensors``.

    Walks backwards from ``tensors`` through the producing operators, one
    layer of newly discovered tensors at a time, and does not traverse past
    the subgraph's declared inputs.

    Args:
        subgraph: Subgraph whose ``operators`` and ``inputs`` are consulted.
        requires: Mapping from tensor index to the indices of the operators
            that produce it. It is only read, never modified.
        tensors: Tensor indices to start from.

    Returns:
        A ``(operator_indices, tensor_indices)`` tuple of sets.
    """
    visited_operators = set()
    visited_tensors = set(tensors)
    stop_tensors = set(subgraph.inputs)

    frontier = set(visited_tensors)
    while frontier:
        discovered = set()
        for tensor in frontier:
            new_operators = set(requires.get(tensor, ())) - visited_operators
            visited_operators |= new_operators
            for op_index in new_operators:
                operator = subgraph.operators[op_index]
                # Only inputs not seen before need to be traversed next round.
                discovered.update(
                    t for t in operator.inputs if t not in visited_tensors
                )
                visited_tensors.update(operator.inputs)
                # include stub outputs but do not traverse them
                visited_tensors.update(operator.outputs)
        frontier = discovered - stop_tensors

    return visited_operators, visited_tensors


def filter_relabel(src, filter):
    """Keep the items of ``src`` whose index is in ``filter``.

    Args:
        src: Sequence to filter.
        filter: Container of indices to keep. The name shadows the builtin but
            is retained because it is part of the callable interface.

    Returns:
        A ``(kept_items, relabel)`` tuple where ``kept_items`` preserves the
        original order and ``relabel`` maps each kept old index to its new
        index in ``kept_items``.
    """
    relabel = {}
    output = []
    for index, item in enumerate(src):
        if index in filter:
            relabel[index] = len(output)
            output.append(item)
    return output, relabel


def cut_model(model_file, input_names, output_names, subgraph_index, output_file):
    """Load a model, cut one subgraph to the named tensors, simplify and save.

    Args:
        model_file: Location passed to ``load``.
        input_names: Names of the new input tensors, or None to keep them.
        output_names: Names of the new output tensors, or None to keep them.
        subgraph_index: Index into ``model.subgraphs`` of the subgraph to cut.
        output_file: Location passed to ``save``.
    """
    model = load(model_file)
    subgraph = model.subgraphs[subgraph_index]
    cut_subgraph(subgraph, input_names, output_names)
    simplify(model)
    save(model, output_file)