aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/graph_edit/cut.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit/cut.py')
-rw-r--r--src/mlia/nn/rewrite/core/graph_edit/cut.py139
1 files changed, 77 insertions, 62 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/cut.py b/src/mlia/nn/rewrite/core/graph_edit/cut.py
index a323b7b..2707eb1 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/cut.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/cut.py
@@ -1,128 +1,143 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Cut module."""
import os
from collections import defaultdict
+from typing import Optional
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
+from tensorflow.lite.python.schema_py_generated import ModelT
+from tensorflow.lite.python.schema_py_generated import SubGraphT
-tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+from mlia.nn.rewrite.core.utils.utils import load
+from mlia.nn.rewrite.core.utils.utils import save
-from mlia.nn.rewrite.core.utils.utils import load, save
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-def cut_subgraph(subgraph, input_tensor_names, output_tensor_names):
- """Change the global inputs and outputs of a graph to the provided named tensors"""
+def tensors_by_name(subgraph: SubGraphT, names: list) -> list:
+ """Seek out tensors from a subgraph and return the result."""
+ seek = frozenset([name.encode("utf-8") for name in names])
+ tensors = [i for i, tensor in enumerate(subgraph.tensors) if tensor.name in seek]
+ return tensors
- def tensors_by_name(names):
- seek = frozenset([name.encode("utf-8") for name in names])
- tensors = [
- i for i, tensor in enumerate(subgraph.tensors) if tensor.name in seek
- ]
- return tensors
+def cut_subgraph(
+ subgraph: SubGraphT,
+ input_tensor_names: Optional[list],
+ output_tensor_names: Optional[list],
+) -> None:
+ """Change the global inputs and outputs of a graph to the provided named tensors."""
if input_tensor_names is not None:
- subgraph.inputs = tensors_by_name(input_tensor_names)
+ subgraph.inputs = tensors_by_name(subgraph, input_tensor_names)
assert len(subgraph.inputs) == len(
input_tensor_names
- ), "Expected %d input tensors: %s\nFound: %s" % (
- len(subgraph.inputs),
- ", ".join(input_tensor_names),
- ", ".join(subgraph.tensors[i].name for i in subgraph.inputs),
- )
-
+ ), f"Expected {len(subgraph.inputs)} input tensors: \
+ {', '.join(input_tensor_names)}\nFound: \
+ {', '.join(subgraph.tensors[i].name for i in subgraph.inputs)}"
if output_tensor_names is not None:
- subgraph.outputs = tensors_by_name(output_tensor_names)
+ subgraph.outputs = tensors_by_name(subgraph, output_tensor_names)
assert len(subgraph.outputs) == len(
output_tensor_names
- ), "Expected %d output tensors: %s\nFound: %s" % (
- len(subgraph.outputs),
- ", ".join(output_tensor_names),
- ", ".join(subgraph.tensors[i].name for i in subgraph.outputs),
- )
+ ), f"Expected {len(subgraph.outputs)} output tensors: \
+ {', '.join(output_tensor_names)}\nFound: \
+ {', '.join(subgraph.tensors[i].name for i in subgraph.outputs)}"
-def simplify(model):
- """Remove any unused operators, tensors and buffers from a model"""
- for s in model.subgraphs:
- simplify_subgraph(s)
+def simplify(model: ModelT) -> None:
+ """Remove any unused operators, tensors and buffers from a model."""
+ for subgraph in model.subgraphs:
+ simplify_subgraph(subgraph)
- used_buffers = {t.buffer for t in s.tensors for s in model.subgraphs}
- used_buffers = used_buffers.union({m.buffer for m in model.metadata})
+ used_buffers = {
+ tensor.buffer for tensor in subgraph.tensors for subgraph in model.subgraphs
+ }
+ used_buffers = used_buffers.union({metadata.buffer for metadata in model.metadata})
used_buffers.add(
0
- ) # Buffer zero is always expected to be a zero-sized nullptr buffer by the TFLite runtime
+ ) # Buffer zero is always expected to be a zero-sized nullptr buffer by the
+ # TFLite runtime
model.buffers, buf_relabel = filter_relabel(model.buffers, used_buffers)
- for s in model.subgraphs:
- for t in s.tensors:
- t.buffer = buf_relabel[t.buffer]
+ for subgraph in model.subgraphs:
+ for tensor in subgraph.tensors:
+ tensor.buffer = buf_relabel[tensor.buffer]
- for m in model.metadata:
- m.buffer = buf_relabel[m.buffer]
+ for metadata in model.metadata:
+ metadata.buffer = buf_relabel[metadata.buffer]
-def simplify_subgraph(subgraph):
+def simplify_subgraph(subgraph: SubGraphT) -> None:
+ """Simplify a subgraph given its operators."""
requires = defaultdict(set)
- for o, operator in enumerate(subgraph.operators):
- for t in operator.outputs:
- if not t in subgraph.inputs:
- requires[t].add(o)
+ for output, operator in enumerate(subgraph.operators):
+ for tensor in operator.outputs:
+ if tensor not in subgraph.inputs:
+ requires[tensor].add(output)
op_set, ten_set = find_required(subgraph, requires, subgraph.outputs)
- subgraph.operators, op_relabel = filter_relabel(subgraph.operators, op_set)
+ subgraph.operators, _ = filter_relabel(subgraph.operators, op_set)
subgraph.tensors, ten_relabel = filter_relabel(subgraph.tensors, ten_set)
ten_relabel[-1] = -1 # Some files have ops with -1 input tensors; leave unchanged
- for op in subgraph.operators:
- op.inputs = [ten_relabel[t] for t in op.inputs]
- op.outputs = [ten_relabel[t] for t in op.outputs]
+ for operator in subgraph.operators:
+ operator.inputs = [ten_relabel[tensor] for tensor in operator.inputs]
+ operator.outputs = [ten_relabel[tensor] for tensor in operator.outputs]
- subgraph.inputs = [ten_relabel[t] for t in subgraph.inputs]
- subgraph.outputs = [ten_relabel[t] for t in subgraph.outputs]
+ subgraph.inputs = [ten_relabel[tensor] for tensor in subgraph.inputs]
+ subgraph.outputs = [ten_relabel[tensors] for tensors in subgraph.outputs]
-def find_required(subgraph, requires, tensors):
- visited_operators = set()
+def find_required(subgraph: SubGraphT, requires: dict, tensors: dict) -> tuple:
+ """Find required operators in a given subgraph."""
+ visited_operators: set = set()
visited_tensors = set(tensors)
stop_tensors = set(subgraph.inputs)
- changed = True
next_tensors = visited_tensors
while next_tensors:
loop_tensors = next_tensors
next_tensors = set()
- for t in loop_tensors:
- candidate_operators = set(requires[t])
+ for tensor in loop_tensors:
+ candidate_operators = set(requires[tensor])
new_operators = candidate_operators - visited_operators
visited_operators = visited_operators.union(new_operators)
- for op in new_operators:
- candidate_tensors = set(subgraph.operators[op].inputs)
+ for operator in new_operators:
+ candidate_tensors = set(subgraph.operators[operator].inputs)
new_tensors = candidate_tensors - (visited_tensors.union(next_tensors))
next_tensors = next_tensors.union(new_tensors)
visited_tensors = visited_tensors.union(candidate_tensors)
visited_tensors = visited_tensors.union(
- subgraph.operators[op].outputs
+ subgraph.operators[operator].outputs
) # include stub outputs but do not traverse them
next_tensors = next_tensors - stop_tensors
return visited_operators, visited_tensors
-def filter_relabel(src, filter):
- relabel = {}
- output = []
- for i, x in enumerate(src):
- if i in filter:
+def filter_relabel(src_subgraph: SubGraphT, relabel_filter: set) -> tuple:
+ """Relabel tensors in a subgraph based on a filter."""
+ relabel: dict = {}
+ output: list = []
+ for i, out in enumerate(src_subgraph):
+ if i in relabel_filter:
relabel[i] = len(output)
- output.append(x)
+ output.append(out)
return output, relabel
-def cut_model(model_file, input_names, output_names, subgraph_index, output_file):
+def cut_model(
+ model_file: str,
+ input_names: Optional[list],
+ output_names: Optional[list],
+ subgraph_index: int,
+ output_file: str,
+) -> None:
+ """Cut subgraphs and simplify a given model."""
model = load(model_file)
subgraph = model.subgraphs[subgraph_index]
cut_subgraph(subgraph, input_names, output_names)