Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit/cut.py')
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/cut.py  130
1 file changed, 130 insertions(+), 0 deletions(-)
diff --git a/src/mlia/nn/rewrite/core/graph_edit/cut.py b/src/mlia/nn/rewrite/core/graph_edit/cut.py
new file mode 100644
index 0000000..a323b7b
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/graph_edit/cut.py
@@ -0,0 +1,130 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Cut TFLite subgraphs and strip unused operators, tensors and buffers."""
+import os
+from collections import defaultdict
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+from mlia.nn.rewrite.core.utils.utils import load, save
+
+
+def cut_subgraph(subgraph, input_tensor_names, output_tensor_names):
+ """Change the global inputs and outputs of a graph to the provided named tensors"""
+
+    def tensors_by_name(names):
+        """Return the indices of the tensors with the given names."""
+        seek = frozenset(name.encode("utf-8") for name in names)
+        indices = [
+            i for i, tensor in enumerate(subgraph.tensors) if tensor.name in seek
+        ]
+        return indices
+
+ if input_tensor_names is not None:
+ subgraph.inputs = tensors_by_name(input_tensor_names)
+ assert len(subgraph.inputs) == len(
+ input_tensor_names
+ ), "Expected %d input tensors: %s\nFound: %s" % (
+ len(subgraph.inputs),
+ ", ".join(input_tensor_names),
+ ", ".join(subgraph.tensors[i].name for i in subgraph.inputs),
+ )
+
+ if output_tensor_names is not None:
+ subgraph.outputs = tensors_by_name(output_tensor_names)
+ assert len(subgraph.outputs) == len(
+ output_tensor_names
+ ), "Expected %d output tensors: %s\nFound: %s" % (
+ len(subgraph.outputs),
+ ", ".join(output_tensor_names),
+ ", ".join(subgraph.tensors[i].name for i in subgraph.outputs),
+ )
+
+
+def simplify(model):
+ """Remove any unused operators, tensors and buffers from a model"""
+ for s in model.subgraphs:
+ simplify_subgraph(s)
+
+ used_buffers = {t.buffer for t in s.tensors for s in model.subgraphs}
+ used_buffers = used_buffers.union({m.buffer for m in model.metadata})
+    # Buffer zero is always expected to be a zero-sized nullptr buffer by the TFLite runtime.
+    used_buffers.add(0)
+ model.buffers, buf_relabel = filter_relabel(model.buffers, used_buffers)
+
+ for s in model.subgraphs:
+ for t in s.tensors:
+ t.buffer = buf_relabel[t.buffer]
+
+ for m in model.metadata:
+ m.buffer = buf_relabel[m.buffer]
+
+
+def simplify_subgraph(subgraph):
+    """Remove unused operators and tensors from a subgraph."""
+    # Map each tensor index to the set of operators that produce it.
+    requires = defaultdict(set)
+
+    for o, operator in enumerate(subgraph.operators):
+        for t in operator.outputs:
+            if t not in subgraph.inputs:
+                requires[t].add(o)
+
+ op_set, ten_set = find_required(subgraph, requires, subgraph.outputs)
+
+ subgraph.operators, op_relabel = filter_relabel(subgraph.operators, op_set)
+ subgraph.tensors, ten_relabel = filter_relabel(subgraph.tensors, ten_set)
+
+ ten_relabel[-1] = -1 # Some files have ops with -1 input tensors; leave unchanged
+
+ for op in subgraph.operators:
+ op.inputs = [ten_relabel[t] for t in op.inputs]
+ op.outputs = [ten_relabel[t] for t in op.outputs]
+
+ subgraph.inputs = [ten_relabel[t] for t in subgraph.inputs]
+ subgraph.outputs = [ten_relabel[t] for t in subgraph.outputs]
+
+
+def find_required(subgraph, requires, tensors):
+    """Find the operators and tensors needed to compute the given tensors."""
+    visited_operators = set()
+    visited_tensors = set(tensors)
+    stop_tensors = set(subgraph.inputs)
+
+    # Walk backwards from the requested tensors: visit the operators that
+    # produce them and, transitively, those operators' input tensors,
+    # stopping at the subgraph's global inputs.
+    next_tensors = visited_tensors
+ while next_tensors:
+ loop_tensors = next_tensors
+ next_tensors = set()
+ for t in loop_tensors:
+ candidate_operators = set(requires[t])
+ new_operators = candidate_operators - visited_operators
+ visited_operators = visited_operators.union(new_operators)
+ for op in new_operators:
+ candidate_tensors = set(subgraph.operators[op].inputs)
+ new_tensors = candidate_tensors - (visited_tensors.union(next_tensors))
+ next_tensors = next_tensors.union(new_tensors)
+ visited_tensors = visited_tensors.union(candidate_tensors)
+ visited_tensors = visited_tensors.union(
+ subgraph.operators[op].outputs
+ ) # include stub outputs but do not traverse them
+ next_tensors = next_tensors - stop_tensors
+
+ return visited_operators, visited_tensors
+
+
+def filter_relabel(src, keep):
+    """Filter src by index; return kept items and an old-to-new index map."""
+    relabel = {}
+    output = []
+    for i, x in enumerate(src):
+        if i in keep:
+            relabel[i] = len(output)
+            output.append(x)
+    return output, relabel
+
+
+def cut_model(model_file, input_names, output_names, subgraph_index, output_file):
+    """Cut the model to the given inputs/outputs and save the result."""
+    model = load(model_file)
+ subgraph = model.subgraphs[subgraph_index]
+ cut_subgraph(subgraph, input_names, output_names)
+ simplify(model)
+ save(model, output_file)
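
For reference, a minimal usage sketch of the new cut_model() entry point. The
model path and tensor names below are hypothetical placeholders; in practice
they must match tensor names present in the loaded TFLite graph.

    from mlia.nn.rewrite.core.graph_edit.cut import cut_model

    cut_model(
        model_file="model.tflite",                  # source TFLite flatbuffer (placeholder path)
        input_names=["sequential/conv2d/Relu"],     # new global input tensors (placeholder names)
        output_names=["sequential/dense/MatMul"],   # new global output tensors (placeholder names)
        subgraph_index=0,                           # operate on the first subgraph
        output_file="model_cut.tflite",             # simplified model written here (placeholder path)
    )

Passing None for input_names or output_names leaves the corresponding set of
global tensors unchanged; after the cut, simplify() drops any operators,
tensors and buffers no longer reachable from the new outputs.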