Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit')
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/__init__.py |   2
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/cut.py      | 130
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/diff.py     | 109
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/join.py     | 128
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/record.py   |  62
5 files changed, 431 insertions, 0 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/__init__.py b/src/mlia/nn/rewrite/core/graph_edit/__init__.py
new file mode 100644
index 0000000..48b1622
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/graph_edit/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
\ No newline at end of file
diff --git a/src/mlia/nn/rewrite/core/graph_edit/cut.py b/src/mlia/nn/rewrite/core/graph_edit/cut.py
new file mode 100644
index 0000000..a323b7b
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/graph_edit/cut.py
@@ -0,0 +1,130 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+import os
+from collections import defaultdict
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+from mlia.nn.rewrite.core.utils.utils import load, save
+
+
+def cut_subgraph(subgraph, input_tensor_names, output_tensor_names):
+ """Change the global inputs and outputs of a graph to the provided named tensors"""
+
+ def tensors_by_name(names):
+ seek = frozenset([name.encode("utf-8") for name in names])
+ tensors = [
+ i for i, tensor in enumerate(subgraph.tensors) if tensor.name in seek
+ ]
+ return tensors
+
+ if input_tensor_names is not None:
+ subgraph.inputs = tensors_by_name(input_tensor_names)
+ assert len(subgraph.inputs) == len(
+ input_tensor_names
+ ), "Expected %d input tensors: %s\nFound: %s" % (
+ len(subgraph.inputs),
+ ", ".join(input_tensor_names),
+ ", ".join(subgraph.tensors[i].name for i in subgraph.inputs),
+ )
+
+ if output_tensor_names is not None:
+ subgraph.outputs = tensors_by_name(output_tensor_names)
+ assert len(subgraph.outputs) == len(
+ output_tensor_names
+ ), "Expected %d output tensors: %s\nFound: %s" % (
+ len(subgraph.outputs),
+ ", ".join(output_tensor_names),
+ ", ".join(subgraph.tensors[i].name for i in subgraph.outputs),
+ )
+
+
+def simplify(model):
+    """Remove any unused operators, tensors and buffers from a model."""
+    for subgraph in model.subgraphs:
+        simplify_subgraph(subgraph)
+
+    # Collect every buffer still referenced by a tensor or a metadata entry.
+    # Note the loop order: iterate the subgraphs first, then their tensors.
+    used_buffers = {
+        tensor.buffer for subgraph in model.subgraphs for tensor in subgraph.tensors
+    }
+    used_buffers = used_buffers.union({m.buffer for m in model.metadata})
+    # Buffer zero is always expected to be a zero-sized nullptr buffer
+    # by the TFLite runtime
+    used_buffers.add(0)
+    model.buffers, buf_relabel = filter_relabel(model.buffers, used_buffers)
+
+    for subgraph in model.subgraphs:
+        for tensor in subgraph.tensors:
+            tensor.buffer = buf_relabel[tensor.buffer]
+
+    for m in model.metadata:
+        m.buffer = buf_relabel[m.buffer]
+
+
+def simplify_subgraph(subgraph):
+    """Remove unused operators and tensors from a subgraph."""
+    # Map each tensor index to the set of operators that produce it
+    requires = defaultdict(set)
+
+    for o, operator in enumerate(subgraph.operators):
+        for t in operator.outputs:
+            if t not in subgraph.inputs:
+                requires[t].add(o)
+
+ op_set, ten_set = find_required(subgraph, requires, subgraph.outputs)
+
+ subgraph.operators, op_relabel = filter_relabel(subgraph.operators, op_set)
+ subgraph.tensors, ten_relabel = filter_relabel(subgraph.tensors, ten_set)
+
+ ten_relabel[-1] = -1 # Some files have ops with -1 input tensors; leave unchanged
+
+ for op in subgraph.operators:
+ op.inputs = [ten_relabel[t] for t in op.inputs]
+ op.outputs = [ten_relabel[t] for t in op.outputs]
+
+ subgraph.inputs = [ten_relabel[t] for t in subgraph.inputs]
+ subgraph.outputs = [ten_relabel[t] for t in subgraph.outputs]
+
+
+def find_required(subgraph, requires, tensors):
+    """Find the operators and tensors needed to compute the given tensors."""
+    visited_operators = set()
+    visited_tensors = set(tensors)
+    stop_tensors = set(subgraph.inputs)
+
+    # Breadth-first backwards traversal from the requested tensors
+    next_tensors = visited_tensors
+ while next_tensors:
+ loop_tensors = next_tensors
+ next_tensors = set()
+ for t in loop_tensors:
+ candidate_operators = set(requires[t])
+ new_operators = candidate_operators - visited_operators
+ visited_operators = visited_operators.union(new_operators)
+ for op in new_operators:
+ candidate_tensors = set(subgraph.operators[op].inputs)
+ new_tensors = candidate_tensors - (visited_tensors.union(next_tensors))
+ next_tensors = next_tensors.union(new_tensors)
+ visited_tensors = visited_tensors.union(candidate_tensors)
+ visited_tensors = visited_tensors.union(
+ subgraph.operators[op].outputs
+ ) # include stub outputs but do not traverse them
+ next_tensors = next_tensors - stop_tensors
+
+ return visited_operators, visited_tensors
+
+
+def filter_relabel(src, keep):
+    """Keep only the items of src whose index is in keep.
+
+    Return the filtered list and a map from old to new indices.
+    """
+    relabel = {}
+    output = []
+    for i, x in enumerate(src):
+        if i in keep:
+            relabel[i] = len(output)
+            output.append(x)
+    return output, relabel
+
+
+def cut_model(model_file, input_names, output_names, subgraph_index, output_file):
+    """Cut out the named section of a model's subgraph and save the simplified result."""
+    model = load(model_file)
+ subgraph = model.subgraphs[subgraph_index]
+ cut_subgraph(subgraph, input_names, output_names)
+ simplify(model)
+ save(model, output_file)
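
For orientation, a minimal usage sketch of cut_model (the file and tensor names below are hypothetical, not part of this commit):

    from mlia.nn.rewrite.core.graph_edit.cut import cut_model

    # Extract the part of subgraph 0 between two named tensors; anything not
    # required to compute the new outputs is pruned by simplify().
    cut_model(
        model_file="model.tflite",
        input_names=["sequential/conv2d/Relu"],
        output_names=["sequential/dense/MatMul"],
        subgraph_index=0,
        output_file="model_cut.tflite",
    )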
diff --git a/src/mlia/nn/rewrite/core/graph_edit/diff.py b/src/mlia/nn/rewrite/core/graph_edit/diff.py
new file mode 100644
index 0000000..b6c9616
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/graph_edit/diff.py
@@ -0,0 +1,109 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+import os
+from collections import defaultdict
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+import numpy as np
+
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFReader
+
+
+def diff(file1, file2):
+    """Return per-tensor mean absolute errors between two tfrec files, plus their overall mean."""
+    results = []
+
+ dataset1 = NumpyTFReader(file1)
+ dataset2 = NumpyTFReader(file2)
+
+ for i, (x1, x2) in enumerate(zip(dataset1, dataset2)):
+ assert x1.keys() == x2.keys(), (
+ "At input %d the files have different sets of tensors.\n%s: %s\n%s: %s\n"
+ % (
+ i,
+ file1,
+ ", ".join(x1.keys()),
+ file2,
+ ", ".join(x2.keys()),
+ )
+ )
+ results.append({})
+ for k in x1.keys():
+ v1 = x1[k].numpy().astype(np.double)
+ v2 = x2[k].numpy().astype(np.double)
+ mae = abs(v1 - v2).mean()
+ results[-1][k] = mae
+
+ total = sum(sum(x.values()) for x in results)
+ count = sum(len(x.values()) for x in results)
+ mean = total / count
+ return results, mean
+
+
+def diff_stats(file1, file2, per_tensor_and_channel=False):
+    """Compare two tfrec files, returning per-tensor MAE and NRMSE (or their overall means)."""
+ dataset1 = NumpyTFReader(file1)
+ dataset2 = NumpyTFReader(file2)
+
+ totals = defaultdict(dict)
+
+ def add_total(name, key, values):
+        if key not in totals[name]:
+ totals[name][key] = values
+ else:
+ totals[name][key] += values
+
+ # First iterate through dataset1 and calculate per-channel total for each tensor
+ count = 0
+ for d in dataset1:
+ count += 1
+ for k, v in d.items():
+ value = v.numpy().astype(np.double)
+ add_total("dataset1_total", k, value)
+
+    # Use this to calculate per-channel mean for each tensor
+    def per_tensor_mean(name):
+        return {k: total / count for k, total in totals[name].items()}
+
+ dataset1_mean = per_tensor_mean("dataset1_total")
+
+ # Next iterate through both datasets and calculate per-channel total squared error
+ # between them for each tensor and dataset1 variance for each tensor using the mean from above
+ for i, (x1, x2) in enumerate(zip(dataset1, dataset2)):
+ assert x1.keys() == x2.keys(), (
+ "At input %d the files have different sets of tensors.\n%s: %s\n%s: %s\n"
+ % (
+ i,
+ file1,
+ ", ".join(x1.keys()),
+ file2,
+ ", ".join(x2.keys()),
+ )
+ )
+ for k in x1.keys():
+ v1 = x1[k].numpy().astype(np.double)
+ v2 = x2[k].numpy().astype(np.double)
+ add_total("ae", k, abs(v1 - v2))
+ add_total("se", k, (v1 - v2) ** 2)
+ add_total("dataset1_variance", k, (v1 - dataset1_mean[k]) ** 2)
+
+ # Finally average over number of inputs to get the rmse and the dataset1 variance
+ mae = per_tensor_mean("ae")
+ mse = per_tensor_mean("se")
+ rmse = {k: np.sqrt(v) for k, v in mse.items()}
+ dataset1_var = per_tensor_mean("dataset1_variance")
+ is_nonzero = {k: dataset1_var[k] > 0 for k in dataset1_var}
+
+ # Divide by target standard deviation to get the per-channel nrmse for each tensor where possible
+ nrmse = {
+ k: v[is_nonzero[k]] / np.sqrt(dataset1_var[k][is_nonzero[k]])
+ for k, v in rmse.items()
+ }
+
+    if per_tensor_and_channel:
+        return mae, nrmse
+
+    def dict_mean(d):
+        return np.mean(list(d.values()))
+
+    return dict_mean(mae), dict_mean(nrmse)
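
A minimal usage sketch of diff_stats (the paths are hypothetical); both tfrec files must contain the same tensor names at each record:

    from mlia.nn.rewrite.core.graph_edit.diff import diff_stats

    # Scalar mean absolute error and normalised RMSE across all tensors
    mae, nrmse = diff_stats("reference.tfrec", "candidate.tfrec")

    # Per-tensor, per-channel statistics instead of scalar means
    mae_by_tensor, nrmse_by_tensor = diff_stats(
        "reference.tfrec", "candidate.tfrec", per_tensor_and_channel=True
    )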
diff --git a/src/mlia/nn/rewrite/core/graph_edit/join.py b/src/mlia/nn/rewrite/core/graph_edit/join.py
new file mode 100644
index 0000000..758f4cf
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/graph_edit/join.py
@@ -0,0 +1,128 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+import os
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+from mlia.nn.rewrite.core.utils.utils import load, save
+
+
+def join_models(input_src, input_dst, output_file, subgraph_src=0, subgraph_dst=0):
+ src_model = load(input_src)
+ dst_model = load(input_dst)
+ src_subgraph = src_model.subgraphs[subgraph_src]
+ dst_subgraph = dst_model.subgraphs[subgraph_dst]
+ join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph)
+ save(dst_model, output_file)
+
+
+def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph):
+    """Copy src_subgraph into dst_subgraph, connecting tensors that share a name."""
+ # Find inputs that match outputs in the other graph and vice versa
+ dst_to_src = {
+ i: o
+ for i in src_subgraph.inputs
+ for o in dst_subgraph.outputs
+ if src_subgraph.tensors[i].name == dst_subgraph.tensors[o].name
+ }
+
+ src_to_dst = {
+ o: i
+ for i in dst_subgraph.inputs
+ for o in src_subgraph.outputs
+ if dst_subgraph.tensors[i].name == src_subgraph.tensors[o].name
+ }
+
+ assert not (src_to_dst and dst_to_src), (
+ "Source and destination subgraphs appear to connect in a loop: %d tensors from src to dst, %d tensors from dst to src"
+ % (len(src_to_dst), len(dst_to_src))
+ )
+
+ # Relabel matched input/output tensors between graphs
+ tensor_relabel = src_to_dst if src_to_dst else dst_to_src
+
+ # Remove matched inputs/outputs as these will now become internal tensors
+    if src_to_dst:
+        src_subgraph.outputs = [
+            o for o in src_subgraph.outputs if o not in tensor_relabel
+        ]
+        dst_subgraph.inputs = [
+            i for i in dst_subgraph.inputs if i not in tensor_relabel.values()
+        ]
+    else:
+        src_subgraph.inputs = [
+            i for i in src_subgraph.inputs if i not in tensor_relabel
+        ]
+        dst_subgraph.outputs = [
+            o for o in dst_subgraph.outputs if o not in tensor_relabel.values()
+        ]
+
+ buffer_relabel = {
+ src_subgraph.tensors[i].buffer: dst_subgraph.tensors[o].buffer
+ for i, o in tensor_relabel.items()
+ }
+
+    used_tensors = [
+        t for i, t in enumerate(src_subgraph.tensors) if i not in tensor_relabel
+    ]
+
+ used_buffer_ids = [t.buffer for t in used_tensors]
+
+    def opcode_data(c):
+        return (
+            c.builtinCode,
+            c.deprecatedBuiltinCode,
+            c.customCode,
+            c.version,
+        )
+
+ opcode_relabel = {
+ s: d
+ for s in range(len(src_model.operatorCodes))
+ for d in range(len(dst_model.operatorCodes))
+ if opcode_data(src_model.operatorCodes[s])
+ == opcode_data(dst_model.operatorCodes[d])
+ }
+
+    # Operator order defines the execution schedule, so it must reflect the
+    # input/output dependencies between the two subgraphs
+ if dst_to_src:
+ dst_subgraph.operators += src_subgraph.operators
+ else:
+ dst_subgraph.operators = src_subgraph.operators + dst_subgraph.operators
+
+ append_relabel(src_subgraph.tensors, dst_subgraph.tensors, tensor_relabel)
+ append_relabel(src_model.operatorCodes, dst_model.operatorCodes, opcode_relabel)
+
+    # Some files have ops with -1 input tensors; leave those unchanged
+    tensor_relabel[-1] = -1
+
+    for i in used_buffer_ids:
+        if i not in buffer_relabel:
+ buffer_relabel[i] = len(dst_model.buffers)
+ dst_model.buffers.append(src_model.buffers[i])
+
+ for o in src_subgraph.operators:
+ o.inputs = [tensor_relabel[t] for t in o.inputs]
+ o.outputs = [tensor_relabel[t] for t in o.outputs]
+ o.opcodeIndex = opcode_relabel[o.opcodeIndex]
+
+ for t in used_tensors:
+ t.buffer = buffer_relabel[t.buffer]
+
+ src_subgraph.inputs = [tensor_relabel[t] for t in src_subgraph.inputs]
+ src_subgraph.outputs = [tensor_relabel[t] for t in src_subgraph.outputs]
+
+ dst_subgraph.inputs = list(set(src_subgraph.inputs).union(dst_subgraph.inputs))
+ dst_subgraph.outputs = list(set(src_subgraph.outputs).union(dst_subgraph.outputs))
+
+
+def append_relabel(src, dst, relabel=None):
+    """Append src items to dst, recording the old-to-new index mapping in relabel."""
+    if relabel is None:
+        relabel = {}
+    for i, x in enumerate(src):
+        if i not in relabel:
+            relabel[i] = len(dst)
+            dst.append(x)
+    return relabel
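
A minimal usage sketch of join_models (hypothetical file names): the two models are stitched together wherever an output tensor of one has the same name as an input tensor of the other:

    from mlia.nn.rewrite.core.graph_edit.join import join_models

    # Reconnect a rewritten section to the remainder of the network;
    # matching tensor names become internal edges of the joined graph.
    join_models(
        input_src="replacement_section.tflite",
        input_dst="model_remainder.tflite",
        output_file="model_joined.tflite",
    )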
diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py
new file mode 100644
index 0000000..03cd3f9
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/graph_edit/record.py
@@ -0,0 +1,62 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+import math
+import os
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+from tqdm import tqdm
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import (
+    NumpyTFReader,
+    NumpyTFWriter,
+    numpytf_count,
+)
+from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
+
+
+def record_model(
+ input_filename,
+ model_filename,
+ output_filename,
+ batch_size=None,
+ show_progress=False,
+ num_procs=1,
+ num_threads=0,
+):
+ """num_procs: 0 => detect real cores on system
+ num_threads: 0 => TFLite impl. specific setting, usually 3"""
+ model = ParallelTFLiteModel(model_filename, num_procs, num_threads, batch_size)
+    if not batch_size:
+        # Automatically batch to the minimum effective size if not specified
+        batch_size = model.num_procs * model.batch_size
+
+ total = numpytf_count(input_filename)
+ dataset = NumpyTFReader(input_filename)
+ writer = NumpyTFWriter(output_filename)
+
+ if batch_size > 1:
+        # The tfrec stores batch-size-1 records; collapse them into batches of size n
+ dataset = dataset.map(
+ lambda d: {k: tf.squeeze(v, axis=0) for k, v in d.items()}
+ )
+ dataset = dataset.batch(batch_size, drop_remainder=False)
+ total = int(math.ceil(total / batch_size))
+
+ for j, named_x in enumerate(
+ tqdm(dataset.as_numpy_iterator(), total=total, disable=not show_progress)
+ ):
+ named_y = model(named_x)
+ if batch_size > 1:
+ for i in range(batch_size):
+                # Recreate each item as a batch-size 1 dict for the tfrec output,
+                # skipping indices beyond a final partial batch
+ d = {k: v[i : i + 1] for k, v in named_y.items() if i < v.shape[0]}
+ if d:
+ writer.write(d)
+ else:
+ writer.write(named_y)
+ model.close()
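
A minimal usage sketch of record_model (hypothetical paths), recording a model's outputs over an input dataset so they can later be compared with diff_stats:

    from mlia.nn.rewrite.core.graph_edit.record import record_model

    # Run model.tflite over every record in input.tfrec and store the outputs;
    # num_procs=0 asks the parallel runner to detect the physical core count.
    record_model(
        input_filename="input.tfrec",
        model_filename="model.tflite",
        output_filename="outputs.tfrec",
        show_progress=True,
        num_procs=0,
    )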