author     Annie Tallund <annie.tallund@arm.com>          2023-03-13 17:00:31 +0100
committer  Benjamin Klimczak <benjamin.klimczak@arm.com>  2023-10-11 15:41:48 +0100
commit     f0b8ed75fed9dc69ab1f6313339f9f7e38bfc725 (patch)
tree       bc353fad664040b44915b5cf7ae807894b0b87e8
parent     b236127b9a18ec2668271c6b5baafa6a7c1dde51 (diff)
download   mlia-f0b8ed75fed9dc69ab1f6313339f9f7e38bfc725.tar.gz
MLIA-845 Migrate rewrite code
- Add required files for rewriting of TensorFlow Lite graphs
- Adapt rewrite dependency paths and project name
- Add license headers

Change-Id: I19c5f63215fe2af2fa7d7d44af08144c6c5f911c
Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
-rw-r--r--  setup.cfg                                         |   1
-rw-r--r--  src/mlia/nn/rewrite/__init__.py                   |   2
-rw-r--r--  src/mlia/nn/rewrite/core/__init__.py              |   2
-rw-r--r--  src/mlia/nn/rewrite/core/extract.py               |  87
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/__init__.py   |   2
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/cut.py        | 130
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/diff.py       | 109
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/join.py       | 128
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/record.py     |  62
-rw-r--r--  src/mlia/nn/rewrite/core/train.py                 | 487
-rw-r--r--  src/mlia/nn/rewrite/core/utils/__init__.py        |   2
-rw-r--r--  src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py  | 172
-rw-r--r--  src/mlia/nn/rewrite/core/utils/parallel.py        | 113
-rw-r--r--  src/mlia/nn/rewrite/core/utils/utils.py           |  26
14 files changed, 1323 insertions(+), 0 deletions(-)
diff --git a/setup.cfg b/setup.cfg
index 5a68b6b..7cdd3c5 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -35,6 +35,7 @@ install_requires =
requests~=2.31.0
rich~=13.5.2
tomli~=2.0.1 ; python_version<"3.11"
+ tqdm~=4.65.0
[options.packages.find]
where = src
diff --git a/src/mlia/nn/rewrite/__init__.py b/src/mlia/nn/rewrite/__init__.py
new file mode 100644
index 0000000..48b1622
--- /dev/null
+++ b/src/mlia/nn/rewrite/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
\ No newline at end of file
diff --git a/src/mlia/nn/rewrite/core/__init__.py b/src/mlia/nn/rewrite/core/__init__.py
new file mode 100644
index 0000000..48b1622
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
\ No newline at end of file
diff --git a/src/mlia/nn/rewrite/core/extract.py b/src/mlia/nn/rewrite/core/extract.py
new file mode 100644
index 0000000..5fcd348
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/extract.py
@@ -0,0 +1,87 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+import os
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+from mlia.nn.rewrite.core.graph_edit.cut import cut_model
+from mlia.nn.rewrite.core.graph_edit.record import record_model
+
+
+def extract(
+ output_path,
+ model_file,
+ input_data,
+ input_names,
+ output_names,
+ subgraph=0,
+ skip_outputs=False,
+ show_progress=False,
+ num_procs=1,
+ num_threads=0,
+):
+ try:
+ os.mkdir(output_path)
+ except FileExistsError:
+ pass
+
+ start_file = os.path.join(output_path, "start.tflite")
+ cut_model(
+ model_file,
+ input_names=None,
+ output_names=input_names,
+ subgraph_index=subgraph,
+ output_file=start_file,
+ )
+
+ input_tfrec = os.path.join(output_path, "input.tfrec")
+ record_model(
+ input_data,
+ start_file,
+ input_tfrec,
+ show_progress=show_progress,
+ num_procs=num_procs,
+ num_threads=num_threads,
+ )
+
+ replace_file = os.path.join(output_path, "replace.tflite")
+ cut_model(
+ model_file,
+ input_names=input_names,
+ output_names=output_names,
+ subgraph_index=subgraph,
+ output_file=replace_file,
+ )
+
+ end_file = os.path.join(output_path, "end.tflite")
+ cut_model(
+ model_file,
+ input_names=output_names,
+ output_names=None,
+ subgraph_index=subgraph,
+ output_file=end_file,
+ )
+
+ if not skip_outputs:
+ output_tfrec = os.path.join(output_path, "output.tfrec")
+ record_model(
+ input_tfrec,
+ replace_file,
+ output_tfrec,
+ show_progress=show_progress,
+ num_procs=num_procs,
+ num_threads=num_threads,
+ )
+
+ end_tfrec = os.path.join(output_path, "end.tfrec")
+ record_model(
+ output_tfrec,
+ end_file,
+ end_tfrec,
+ show_progress=show_progress,
+ num_procs=num_procs,
+ num_threads=num_threads,
+ )
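
Note: extract() cuts the source model into three pieces (start.tflite, replace.tflite, end.tflite) and records the tensor data flowing between them. A minimal usage sketch; the paths and tensor names below are hypothetical:

    from mlia.nn.rewrite.core.extract import extract

    extract(
        output_path="extract_out",          # receives the cut models and tfrec files
        model_file="model.tflite",          # source model to cut
        input_data="dataset.tfrec",         # recorded inputs for the whole model
        input_names=["sequential/conv1"],   # tensors where the replaced section starts
        output_names=["sequential/conv2"],  # tensors where it ends
        show_progress=True,
    )
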
diff --git a/src/mlia/nn/rewrite/core/graph_edit/__init__.py b/src/mlia/nn/rewrite/core/graph_edit/__init__.py
new file mode 100644
index 0000000..48b1622
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/graph_edit/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
\ No newline at end of file
diff --git a/src/mlia/nn/rewrite/core/graph_edit/cut.py b/src/mlia/nn/rewrite/core/graph_edit/cut.py
new file mode 100644
index 0000000..a323b7b
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/graph_edit/cut.py
@@ -0,0 +1,130 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+import os
+from collections import defaultdict
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+from mlia.nn.rewrite.core.utils.utils import load, save
+
+
+def cut_subgraph(subgraph, input_tensor_names, output_tensor_names):
+ """Change the global inputs and outputs of a graph to the provided named tensors"""
+
+ def tensors_by_name(names):
+ seek = frozenset([name.encode("utf-8") for name in names])
+ tensors = [
+ i for i, tensor in enumerate(subgraph.tensors) if tensor.name in seek
+ ]
+ return tensors
+
+ if input_tensor_names is not None:
+ subgraph.inputs = tensors_by_name(input_tensor_names)
+ assert len(subgraph.inputs) == len(
+ input_tensor_names
+ ), "Expected %d input tensors: %s\nFound: %s" % (
+            len(input_tensor_names),
+ ", ".join(input_tensor_names),
+ ", ".join(subgraph.tensors[i].name for i in subgraph.inputs),
+ )
+
+ if output_tensor_names is not None:
+ subgraph.outputs = tensors_by_name(output_tensor_names)
+ assert len(subgraph.outputs) == len(
+ output_tensor_names
+ ), "Expected %d output tensors: %s\nFound: %s" % (
+            len(output_tensor_names),
+ ", ".join(output_tensor_names),
+ ", ".join(subgraph.tensors[i].name for i in subgraph.outputs),
+ )
+
+
+def simplify(model):
+ """Remove any unused operators, tensors and buffers from a model"""
+ for s in model.subgraphs:
+ simplify_subgraph(s)
+
+    used_buffers = {t.buffer for s in model.subgraphs for t in s.tensors}
+ used_buffers = used_buffers.union({m.buffer for m in model.metadata})
+ used_buffers.add(
+ 0
+ ) # Buffer zero is always expected to be a zero-sized nullptr buffer by the TFLite runtime
+ model.buffers, buf_relabel = filter_relabel(model.buffers, used_buffers)
+
+ for s in model.subgraphs:
+ for t in s.tensors:
+ t.buffer = buf_relabel[t.buffer]
+
+ for m in model.metadata:
+ m.buffer = buf_relabel[m.buffer]
+
+
+def simplify_subgraph(subgraph):
+ requires = defaultdict(set)
+
+ for o, operator in enumerate(subgraph.operators):
+ for t in operator.outputs:
+            if t not in subgraph.inputs:
+ requires[t].add(o)
+
+ op_set, ten_set = find_required(subgraph, requires, subgraph.outputs)
+
+ subgraph.operators, op_relabel = filter_relabel(subgraph.operators, op_set)
+ subgraph.tensors, ten_relabel = filter_relabel(subgraph.tensors, ten_set)
+
+ ten_relabel[-1] = -1 # Some files have ops with -1 input tensors; leave unchanged
+
+ for op in subgraph.operators:
+ op.inputs = [ten_relabel[t] for t in op.inputs]
+ op.outputs = [ten_relabel[t] for t in op.outputs]
+
+ subgraph.inputs = [ten_relabel[t] for t in subgraph.inputs]
+ subgraph.outputs = [ten_relabel[t] for t in subgraph.outputs]
+
+
+def find_required(subgraph, requires, tensors):
+ visited_operators = set()
+ visited_tensors = set(tensors)
+ stop_tensors = set(subgraph.inputs)
+ changed = True
+
+ next_tensors = visited_tensors
+ while next_tensors:
+ loop_tensors = next_tensors
+ next_tensors = set()
+ for t in loop_tensors:
+ candidate_operators = set(requires[t])
+ new_operators = candidate_operators - visited_operators
+ visited_operators = visited_operators.union(new_operators)
+ for op in new_operators:
+ candidate_tensors = set(subgraph.operators[op].inputs)
+ new_tensors = candidate_tensors - (visited_tensors.union(next_tensors))
+ next_tensors = next_tensors.union(new_tensors)
+ visited_tensors = visited_tensors.union(candidate_tensors)
+ visited_tensors = visited_tensors.union(
+ subgraph.operators[op].outputs
+ ) # include stub outputs but do not traverse them
+ next_tensors = next_tensors - stop_tensors
+
+ return visited_operators, visited_tensors
+
+
+def filter_relabel(src, keep):
+    relabel = {}
+    output = []
+    for i, x in enumerate(src):
+        if i in keep:
+            relabel[i] = len(output)
+            output.append(x)
+    return output, relabel
+
+
+def cut_model(model_file, input_names, output_names, subgraph_index, output_file):
+ model = load(model_file)
+ subgraph = model.subgraphs[subgraph_index]
+ cut_subgraph(subgraph, input_names, output_names)
+ simplify(model)
+ save(model, output_file)
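
Note: filter_relabel() is the workhorse of simplify(): it drops unselected items and returns an old-to-new index map so that references can be rewritten. A standalone illustration, independent of any model:

    from mlia.nn.rewrite.core.graph_edit.cut import filter_relabel

    items = ["a", "b", "c", "d"]
    kept, relabel = filter_relabel(items, {0, 2, 3})
    assert kept == ["a", "c", "d"]
    assert relabel == {0: 0, 2: 1, 3: 2}
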
diff --git a/src/mlia/nn/rewrite/core/graph_edit/diff.py b/src/mlia/nn/rewrite/core/graph_edit/diff.py
new file mode 100644
index 0000000..b6c9616
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/graph_edit/diff.py
@@ -0,0 +1,109 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+import os
+from collections import defaultdict
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+import numpy as np
+from tensorflow.lite.python import interpreter as interpreter_wrapper
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFReader, NumpyTFWriter
+
+
+def diff(file1, file2):
+ results = []
+
+ dataset1 = NumpyTFReader(file1)
+ dataset2 = NumpyTFReader(file2)
+
+ for i, (x1, x2) in enumerate(zip(dataset1, dataset2)):
+ assert x1.keys() == x2.keys(), (
+ "At input %d the files have different sets of tensors.\n%s: %s\n%s: %s\n"
+ % (
+ i,
+ file1,
+ ", ".join(x1.keys()),
+ file2,
+ ", ".join(x2.keys()),
+ )
+ )
+ results.append({})
+ for k in x1.keys():
+ v1 = x1[k].numpy().astype(np.double)
+ v2 = x2[k].numpy().astype(np.double)
+ mae = abs(v1 - v2).mean()
+ results[-1][k] = mae
+
+ total = sum(sum(x.values()) for x in results)
+ count = sum(len(x.values()) for x in results)
+ mean = total / count
+ return results, mean
+
+
+def diff_stats(file1, file2, per_tensor_and_channel=False):
+ dataset1 = NumpyTFReader(file1)
+ dataset2 = NumpyTFReader(file2)
+
+ totals = defaultdict(dict)
+
+ def add_total(name, key, values):
+        if key not in totals[name]:
+ totals[name][key] = values
+ else:
+ totals[name][key] += values
+
+ # First iterate through dataset1 and calculate per-channel total for each tensor
+ count = 0
+ for d in dataset1:
+ count += 1
+ for k, v in d.items():
+ value = v.numpy().astype(np.double)
+ add_total("dataset1_total", k, value)
+
+ # Use this to calculate per-channel mean for each tensor
+ per_tensor_mean = lambda name: {
+ k: total / count for k, total in totals[name].items()
+ }
+ dataset1_mean = per_tensor_mean("dataset1_total")
+
+ # Next iterate through both datasets and calculate per-channel total squared error
+ # between them for each tensor and dataset1 variance for each tensor using the mean from above
+ for i, (x1, x2) in enumerate(zip(dataset1, dataset2)):
+ assert x1.keys() == x2.keys(), (
+ "At input %d the files have different sets of tensors.\n%s: %s\n%s: %s\n"
+ % (
+ i,
+ file1,
+ ", ".join(x1.keys()),
+ file2,
+ ", ".join(x2.keys()),
+ )
+ )
+ for k in x1.keys():
+ v1 = x1[k].numpy().astype(np.double)
+ v2 = x2[k].numpy().astype(np.double)
+ add_total("ae", k, abs(v1 - v2))
+ add_total("se", k, (v1 - v2) ** 2)
+ add_total("dataset1_variance", k, (v1 - dataset1_mean[k]) ** 2)
+
+ # Finally average over number of inputs to get the rmse and the dataset1 variance
+ mae = per_tensor_mean("ae")
+ mse = per_tensor_mean("se")
+ rmse = {k: np.sqrt(v) for k, v in mse.items()}
+ dataset1_var = per_tensor_mean("dataset1_variance")
+ is_nonzero = {k: dataset1_var[k] > 0 for k in dataset1_var}
+
+ # Divide by target standard deviation to get the per-channel nrmse for each tensor where possible
+ nrmse = {
+ k: v[is_nonzero[k]] / np.sqrt(dataset1_var[k][is_nonzero[k]])
+ for k, v in rmse.items()
+ }
+
+ if per_tensor_and_channel:
+ return mae, nrmse
+ else:
+ dict_mean = lambda d: np.mean(list(d.values()))
+ return dict_mean(mae), dict_mean(nrmse)
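
Note: for a single tensor, the per-channel NRMSE computed by diff_stats() reduces to the RMSE between the two datasets divided by the standard deviation of the reference dataset. A plain-numpy sketch of the statistic (assuming no zero-variance channels):

    import numpy as np

    ref = np.random.rand(100, 8)                 # reference outputs ("dataset1")
    test = ref + 0.05 * np.random.randn(100, 8)  # outputs under comparison

    rmse = np.sqrt(((ref - test) ** 2).mean(axis=0))  # per-channel RMSE
    nrmse = rmse / ref.std(axis=0)                    # normalise by reference std
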
diff --git a/src/mlia/nn/rewrite/core/graph_edit/join.py b/src/mlia/nn/rewrite/core/graph_edit/join.py
new file mode 100644
index 0000000..758f4cf
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/graph_edit/join.py
@@ -0,0 +1,128 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+import os
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+from mlia.nn.rewrite.core.utils.utils import load, save
+
+
+def join_models(input_src, input_dst, output_file, subgraph_src=0, subgraph_dst=0):
+ src_model = load(input_src)
+ dst_model = load(input_dst)
+ src_subgraph = src_model.subgraphs[subgraph_src]
+ dst_subgraph = dst_model.subgraphs[subgraph_dst]
+ join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph)
+ save(dst_model, output_file)
+
+
+def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph):
+ """Copy subgraph src into subgraph dst from model, connecting tensors with the same names"""
+ # Find inputs that match outputs in the other graph and vice versa
+ dst_to_src = {
+ i: o
+ for i in src_subgraph.inputs
+ for o in dst_subgraph.outputs
+ if src_subgraph.tensors[i].name == dst_subgraph.tensors[o].name
+ }
+
+ src_to_dst = {
+ o: i
+ for i in dst_subgraph.inputs
+ for o in src_subgraph.outputs
+ if dst_subgraph.tensors[i].name == src_subgraph.tensors[o].name
+ }
+
+ assert not (src_to_dst and dst_to_src), (
+ "Source and destination subgraphs appear to connect in a loop: %d tensors from src to dst, %d tensors from dst to src"
+ % (len(src_to_dst), len(dst_to_src))
+ )
+
+ # Relabel matched input/output tensors between graphs
+ tensor_relabel = src_to_dst if src_to_dst else dst_to_src
+
+ # Remove matched inputs/outputs as these will now become internal tensors
+ if src_to_dst:
+        src_subgraph.outputs = [
+            o for o in src_subgraph.outputs if o not in tensor_relabel
+        ]
+        dst_subgraph.inputs = [
+            i for i in dst_subgraph.inputs if i not in tensor_relabel.values()
+        ]
+    else:
+        src_subgraph.inputs = [
+            i for i in src_subgraph.inputs if i not in tensor_relabel
+        ]
+        dst_subgraph.outputs = [
+            o for o in dst_subgraph.outputs if o not in tensor_relabel.values()
+        ]
+
+ buffer_relabel = {
+ src_subgraph.tensors[i].buffer: dst_subgraph.tensors[o].buffer
+ for i, o in tensor_relabel.items()
+ }
+
+ used_tensors = [
+        t for i, t in enumerate(src_subgraph.tensors) if i not in tensor_relabel
+ ]
+
+ used_buffer_ids = [t.buffer for t in used_tensors]
+
+ opcode_data = lambda c: (
+ c.builtinCode,
+ c.deprecatedBuiltinCode,
+ c.customCode,
+ c.version,
+ )
+ opcode_relabel = {
+ s: d
+ for s in range(len(src_model.operatorCodes))
+ for d in range(len(dst_model.operatorCodes))
+ if opcode_data(src_model.operatorCodes[s])
+ == opcode_data(dst_model.operatorCodes[d])
+ }
+
+    # Operator order defines the execution schedule, so it must reflect the input/output dependencies
+ if dst_to_src:
+ dst_subgraph.operators += src_subgraph.operators
+ else:
+ dst_subgraph.operators = src_subgraph.operators + dst_subgraph.operators
+
+ append_relabel(src_subgraph.tensors, dst_subgraph.tensors, tensor_relabel)
+ append_relabel(src_model.operatorCodes, dst_model.operatorCodes, opcode_relabel)
+
+ tensor_relabel[
+ -1
+ ] = -1 # Some files have ops with -1 input tensors; leave unchanged
+
+ for i in used_buffer_ids:
+        if i not in buffer_relabel:
+ buffer_relabel[i] = len(dst_model.buffers)
+ dst_model.buffers.append(src_model.buffers[i])
+
+ for o in src_subgraph.operators:
+ o.inputs = [tensor_relabel[t] for t in o.inputs]
+ o.outputs = [tensor_relabel[t] for t in o.outputs]
+ o.opcodeIndex = opcode_relabel[o.opcodeIndex]
+
+ for t in used_tensors:
+ t.buffer = buffer_relabel[t.buffer]
+
+ src_subgraph.inputs = [tensor_relabel[t] for t in src_subgraph.inputs]
+ src_subgraph.outputs = [tensor_relabel[t] for t in src_subgraph.outputs]
+
+ dst_subgraph.inputs = list(set(src_subgraph.inputs).union(dst_subgraph.inputs))
+ dst_subgraph.outputs = list(set(src_subgraph.outputs).union(dst_subgraph.outputs))
+
+
+def append_relabel(src, dst, relabel=None):
+    if relabel is None:
+        relabel = {}
+    for i, x in enumerate(src):
+        if i not in relabel:
+            relabel[i] = len(dst)
+            dst.append(x)
+    return relabel
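
Note: join_models() is the inverse of the cutting done in extract.py. Rebuilding a full model from the three extracted pieces takes two joins, mirroring join_in_dir() in train.py (paths hypothetical):

    from mlia.nn.rewrite.core.graph_edit.join import join_models

    join_models("replace.tflite", "end.tflite", "tail.tflite")    # replace -> end
    join_models("start.tflite", "tail.tflite", "rebuilt.tflite")  # start -> tail
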
diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py
new file mode 100644
index 0000000..03cd3f9
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/graph_edit/record.py
@@ -0,0 +1,62 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+import math
+import os
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+from tqdm import tqdm
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import (
+ NumpyTFReader,
+ NumpyTFWriter,
+ TFLiteModel,
+ numpytf_count,
+)
+from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
+
+
+def record_model(
+ input_filename,
+ model_filename,
+ output_filename,
+ batch_size=None,
+ show_progress=False,
+ num_procs=1,
+ num_threads=0,
+):
+ """num_procs: 0 => detect real cores on system
+ num_threads: 0 => TFLite impl. specific setting, usually 3"""
+ model = ParallelTFLiteModel(model_filename, num_procs, num_threads, batch_size)
+ if not batch_size:
+ batch_size = (
+ model.num_procs * model.batch_size
+ ) # automatically batch to the minimum effective size if not specified
+
+ total = numpytf_count(input_filename)
+ dataset = NumpyTFReader(input_filename)
+ writer = NumpyTFWriter(output_filename)
+
+ if batch_size > 1:
+        # Collapse batch-size-1 items into batches of size n (the tfrec files store one batch-size-1 item per record)
+ dataset = dataset.map(
+ lambda d: {k: tf.squeeze(v, axis=0) for k, v in d.items()}
+ )
+ dataset = dataset.batch(batch_size, drop_remainder=False)
+ total = int(math.ceil(total / batch_size))
+
+ for j, named_x in enumerate(
+ tqdm(dataset.as_numpy_iterator(), total=total, disable=not show_progress)
+ ):
+ named_y = model(named_x)
+ if batch_size > 1:
+ for i in range(batch_size):
+ # Expand the batches and recreate each dict as a batch-size 1 item for the tfrec output
+ d = {k: v[i : i + 1] for k, v in named_y.items() if i < v.shape[0]}
+ if d:
+ writer.write(d)
+ else:
+ writer.write(named_y)
+ model.close()
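
Note: record_model() streams a tfrec of named inputs through a TFLite model and writes the outputs to a new tfrec. A usage sketch with hypothetical file names:

    from mlia.nn.rewrite.core.graph_edit.record import record_model

    record_model(
        "input.tfrec",    # one batch-size-1 item per record
        "model.tflite",   # model to run the items through
        "output.tfrec",   # where the outputs are written
        show_progress=True,
        num_procs=0,      # 0 => use all physical cores
    )
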
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
new file mode 100644
index 0000000..a929b14
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -0,0 +1,487 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+import math
+import os
+import tempfile
+from collections import defaultdict
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+try:
+ from tensorflow.keras.optimizers.schedules import CosineDecay
+except ImportError:
+ # In TF 2.4 CosineDecay was still experimental
+ from tensorflow.keras.experimental import CosineDecay
+
+import numpy as np
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import (
+ NumpyTFReader,
+ NumpyTFWriter,
+ TFLiteModel,
+ numpytf_count,
+)
+from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
+from mlia.nn.rewrite.core.graph_edit.record import record_model
+from mlia.nn.rewrite.core.utils.utils import load, save
+from mlia.nn.rewrite.core.extract import extract
+from mlia.nn.rewrite.core.graph_edit.join import join_models
+from mlia.nn.rewrite.core.graph_edit.diff import diff_stats
+
+
+augmentation_presets = {
+ "none": (None, None),
+ "gaussian": (None, 1.0),
+ "mixup": (1.0, None),
+ "mixout": (1.6, None),
+ "mix_gaussian_large": (2.0, 1.0),
+ "mix_gaussian_small": (1.6, 0.3),
+}
+
+
+class SequentialTrainer:
+ def __init__(
+ self,
+ source_model,
+ output_model,
+ input_tfrec,
+ augment="gaussian",
+ steps=6000,
+ lr=1e-3,
+ batch_size=32,
+ show_progress=True,
+ eval_fn=None,
+ num_procs=1,
+ num_threads=0,
+ ):
+ self.source_model = source_model
+ self.output_model = output_model
+ self.input_tfrec = input_tfrec
+ self.default_augment = augment
+ self.default_steps = steps
+ self.default_lr = lr
+ self.default_batch_size = batch_size
+ self.show_progress = show_progress
+ self.num_procs = num_procs
+ self.num_threads = num_threads
+ self.first_replace = True
+ self.eval_fn = eval_fn
+
+ def replace(
+ self,
+ model_fn,
+ input_tensors,
+ output_tensors,
+ augment=None,
+ steps=None,
+ lr=None,
+ batch_size=None,
+ ):
+ augment = self.default_augment if augment is None else augment
+ steps = self.default_steps if steps is None else steps
+ lr = self.default_lr if lr is None else lr
+ batch_size = self.default_batch_size if batch_size is None else batch_size
+
+ if isinstance(augment, str):
+ augment = augmentation_presets[augment]
+
+ if self.first_replace:
+ source_model = self.source_model
+ unmodified_model = None
+ else:
+ source_model = self.output_model
+ unmodified_model = self.source_model
+
+ mae, nrmse = train(
+ source_model,
+ unmodified_model,
+ self.output_model,
+ self.input_tfrec,
+ model_fn,
+ input_tensors,
+ output_tensors,
+ augment,
+ steps,
+ lr,
+ batch_size,
+ False,
+ self.show_progress,
+ None,
+ 0,
+ self.num_procs,
+ self.num_threads,
+ )
+
+ self.first_replace = False
+ if self.eval_fn:
+ return self.eval_fn(mae, nrmse, self.output_model)
+ else:
+ return mae, nrmse
+
+
+def train(
+ source_model,
+ unmodified_model,
+ output_model,
+ input_tfrec,
+ replace_fn,
+ input_tensors,
+ output_tensors,
+ augment,
+ steps,
+ lr,
+ batch_size,
+ verbose,
+ show_progress,
+ checkpoint_at=None,
+ checkpoint_decay_steps=0,
+ num_procs=1,
+ num_threads=0,
+):
+ if unmodified_model:
+ unmodified_model_dir = tempfile.TemporaryDirectory()
+ unmodified_model_dir_path = unmodified_model_dir.name
+ extract(
+ unmodified_model_dir_path,
+ source_model,
+ input_tfrec,
+ input_tensors,
+ output_tensors,
+ )
+ else:
+ unmodified_model_dir = None
+ unmodified_model_dir_path = None
+
+ results = []
+ with tempfile.TemporaryDirectory() as train_dir:
+ p = lambda file: os.path.join(train_dir, file)
+
+ extract(
+ train_dir,
+ source_model,
+ input_tfrec,
+ input_tensors,
+ output_tensors,
+ num_procs=num_procs,
+ num_threads=num_threads,
+ )
+
+ tflite_filenames = train_in_dir(
+ train_dir,
+ unmodified_model_dir_path,
+ p("new.tflite"),
+ replace_fn,
+ augment,
+ steps,
+ lr,
+ batch_size,
+ checkpoint_at=checkpoint_at,
+ checkpoint_decay_steps=checkpoint_decay_steps,
+ verbose=verbose,
+ show_progress=show_progress,
+ num_procs=num_procs,
+ num_threads=num_threads,
+ )
+
+ for i, filename in enumerate(tflite_filenames):
+ results.append(eval_in_dir(train_dir, filename, num_procs, num_threads))
+
+            if output_model:
+                if i + 1 < len(tflite_filenames):
+                    # Append the same _@STEPS.tflite postfix used by intermediate checkpoints for all but the last output
+                    postfix = filename.split("_@")[-1]
+                    output_filename = output_model.split(".tflite")[0] + postfix
+                else:
+                    output_filename = output_model
+                join_in_dir(train_dir, filename, output_filename)
+
+ if unmodified_model_dir:
+ unmodified_model_dir.cleanup()
+
+ return (
+ results if checkpoint_at else results[0]
+ ) # only return a list if multiple checkpoints are asked for
+
+
+def eval_in_dir(target_dir, new_part, num_procs=1, num_threads=0):
+    p = lambda file: os.path.join(target_dir, file)
+    input_file = (
+        p("input_orig.tfrec")
+        if os.path.exists(p("input_orig.tfrec"))
+        else p("input.tfrec")
+    )
+    output_file = (
+        p("output_orig.tfrec")
+        if os.path.exists(p("output_orig.tfrec"))
+        else p("output.tfrec")
+    )
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        predict = os.path.join(tmp_dir, "predict.tfrec")
+        record_model(
+            input_file, new_part, predict, num_procs=num_procs, num_threads=num_threads
+        )
+        mae, nrmse = diff_stats(output_file, predict)
+
+    return mae, nrmse
+
+
+def join_in_dir(target_dir, new_part, output_model):
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        d = lambda file: os.path.join(target_dir, file)
+        new_end = os.path.join(tmp_dir, "new_end.tflite")
+        join_models(new_part, d("end.tflite"), new_end)
+        join_models(d("start.tflite"), new_end, output_model)
+
+
+def train_in_dir(
+ train_dir,
+ baseline_dir,
+ output_filename,
+ replace_fn,
+ augmentations,
+ steps,
+ lr=1e-3,
+ batch_size=32,
+ checkpoint_at=None,
+ checkpoint_decay_steps=0,
+ schedule="cosine",
+ verbose=False,
+ show_progress=False,
+ num_procs=None,
+ num_threads=1,
+):
+ """Train a replacement for replace.tflite using the input.tfrec and output.tfrec in train_dir.
+ If baseline_dir is provided, train the replacement to match baseline outputs for train_dir inputs.
+ Result saved as new.tflite in train_dir.
+ """
+ teacher_dir = baseline_dir if baseline_dir else train_dir
+ teacher = ParallelTFLiteModel(
+ "%s/replace.tflite" % teacher_dir, num_procs, num_threads, batch_size=batch_size
+ )
+ replace = TFLiteModel("%s/replace.tflite" % train_dir)
+ assert len(teacher.input_tensors()) == 1, (
+ "Can only train replacements with a single input tensor right now, found %s"
+ % teacher.input_tensors()
+ )
+ assert len(teacher.output_tensors()) == 1, (
+ "Can only train replacements with a single output tensor right now, found %s"
+ % teacher.output_tensors()
+ )
+ input_name = teacher.input_tensors()[0]
+ output_name = teacher.output_tensors()[0]
+
+ assert len(teacher.shape_from_name) == len(
+ replace.shape_from_name
+ ), "Baseline and train models must have the same number of inputs and outputs. Teacher: {}\nTrain dir: {}".format(
+ teacher.shape_from_name, replace.shape_from_name
+ )
+ assert all(
+ tn == rn and (ts[1:] == rs[1:]).all()
+ for (tn, ts), (rn, rs) in zip(
+ teacher.shape_from_name.items(), replace.shape_from_name.items()
+ )
+ ), "Baseline and train models must have the same input and output shapes for the subgraph being replaced. Teacher: {}\nTrain dir: {}".format(
+ teacher.shape_from_name, replace.shape_from_name
+ )
+
+ input_filename = os.path.join(train_dir, "input.tfrec")
+ total = numpytf_count(input_filename)
+ dict_inputs = NumpyTFReader(input_filename)
+ inputs = dict_inputs.map(lambda d: tf.squeeze(d[input_name], axis=0))
+ if any(augmentations):
+ # Map the teacher inputs here because the augmentation stage passes these through a TFLite model to get the outputs
+ teacher_outputs = NumpyTFReader(os.path.join(teacher_dir, "input.tfrec")).map(
+ lambda d: tf.squeeze(d[input_name], axis=0)
+ )
+ else:
+ teacher_outputs = NumpyTFReader(os.path.join(teacher_dir, "output.tfrec")).map(
+ lambda d: tf.squeeze(d[output_name], axis=0)
+ )
+
+ steps_per_epoch = math.ceil(total / batch_size)
+ epochs = int(math.ceil(steps / steps_per_epoch))
+ if verbose:
+ print(
+ "Training on %d items for %d steps (%d epochs with batch size %d)"
+ % (total, epochs * steps_per_epoch, epochs, batch_size)
+ )
+
+ dataset = tf.data.Dataset.zip((inputs, teacher_outputs))
+ if epochs > 1:
+ dataset = dataset.cache()
+ dataset = dataset.shuffle(total).repeat().batch(batch_size)
+
+ if any(augmentations):
+ augment_train, augment_teacher = augment_fn_twins(dict_inputs, augmentations)
+        apply_augment = lambda train, teach: (
+            augment_train({input_name: train})[input_name],
+            teacher(augment_teacher({input_name: teach}))[output_name],
+        )
+        dataset = dataset.map(
+            lambda train, teach: tf.py_function(
+                apply_augment, inp=[train, teach], Tout=[tf.float32, tf.float32]
+            )
+        )
+
+ dataset = dataset.prefetch(tf.data.AUTOTUNE)
+
+ input_shape = teacher.shape_from_name[input_name][1:]
+ output_shape = teacher.shape_from_name[output_name][1:]
+ model = replace_fn(input_shape, output_shape)
+
+ optimizer = tf.keras.optimizers.Nadam(learning_rate=lr)
+ loss_fn = tf.keras.losses.MeanSquaredError()
+ model.compile(optimizer=optimizer, loss=loss_fn)
+
+ if verbose:
+ model.summary()
+
+ steps_so_far = 0
+
+ def cosine_decay(epoch_step, logs):
+ """Cosine decay from lr at start of the run to zero at the end"""
+ current_step = epoch_step + steps_so_far
+ learning_rate = lr * (math.cos(math.pi * current_step / steps) + 1) / 2.0
+ tf.keras.backend.set_value(optimizer.learning_rate, learning_rate)
+
+ def late_decay(epoch_step, logs):
+ """Constant until the last 20% of the run, then linear decay to zero"""
+ current_step = epoch_step + steps_so_far
+ steps_remaining = steps - current_step
+ decay_length = steps // 5
+ decay_fraction = min(steps_remaining, decay_length) / decay_length
+ learning_rate = lr * decay_fraction
+ tf.keras.backend.set_value(optimizer.learning_rate, learning_rate)
+
+ if schedule == "cosine":
+ callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=cosine_decay)]
+ elif schedule == "late":
+ callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=late_decay)]
+ elif schedule == "constant":
+ callbacks = []
+ else:
+ assert False, (
+ 'LR schedule "%s" not implemented - expected "cosine", "constant" or "late"'
+ % schedule
+ )
+
+ output_filenames = []
+ checkpoints = (checkpoint_at if checkpoint_at else []) + [steps]
+ while steps_so_far < steps:
+ steps_to_train = checkpoints.pop(0) - steps_so_far
+ lr_start = optimizer.learning_rate.numpy()
+ model.fit(
+ dataset,
+ epochs=1,
+ steps_per_epoch=steps_to_train,
+ callbacks=callbacks,
+ verbose=show_progress,
+ )
+ steps_so_far += steps_to_train
+ print(
+ "lr decayed from %f to %f over %d steps"
+ % (lr_start, optimizer.learning_rate.numpy(), steps_to_train)
+ )
+
+ if steps_so_far < steps:
+ filename, ext = os.path.splitext(output_filename)
+ checkpoint_filename = filename + ("_@%d" % steps_so_far) + ext
+ else:
+ checkpoint_filename = output_filename
+ print("%d/%d: Saved as %s" % (steps_so_far, steps, checkpoint_filename))
+ save_as_tflite(
+ model,
+ checkpoint_filename,
+ input_name,
+ replace.shape_from_name[input_name],
+ output_name,
+ replace.shape_from_name[output_name],
+ )
+ output_filenames.append(checkpoint_filename)
+
+ teacher.close()
+ return output_filenames
+
+
+def save_as_tflite(
+ keras_model, filename, input_name, input_shape, output_name, output_shape
+):
+ converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+ tflite_model = converter.convert()
+
+ with open(filename, "wb") as f:
+ f.write(tflite_model)
+
+ # Now fix the shapes and names to match those we expect
+ fb = load(filename)
+ i = fb.subgraphs[0].inputs[0]
+ fb.subgraphs[0].tensors[i].shape = np.array(input_shape, dtype=np.int32)
+ fb.subgraphs[0].tensors[i].name = input_name.encode("utf-8")
+ o = fb.subgraphs[0].outputs[0]
+ fb.subgraphs[0].tensors[o].shape = np.array(output_shape, dtype=np.int32)
+ fb.subgraphs[0].tensors[o].name = output_name.encode("utf-8")
+ save(fb, filename)
+
+
+def augment_fn_twins(inputs, augmentations):
+ """Return a pair of twinned augmentation functions with the same sequence of random numbers"""
+ seed = np.random.randint(2**32 - 1)
+ rng1 = np.random.default_rng(seed)
+ rng2 = np.random.default_rng(seed)
+ return augment_fn(inputs, augmentations, rng1), augment_fn(
+ inputs, augmentations, rng2
+ )
+
+
+def augment_fn(inputs, augmentations, rng):
+ mixup_strength, gaussian_strength = augmentations
+
+ augments = []
+
+ if mixup_strength:
+ mixup_range = (0.5 - mixup_strength / 2, 0.5 + mixup_strength / 2)
+ augment = lambda d: {
+ k: mixup(rng, v.numpy(), mixup_range) for k, v in d.items()
+ }
+ augments.append(augment)
+
+ if gaussian_strength:
+ values = defaultdict(list)
+ for d in inputs.as_numpy_iterator():
+ for k, v in d.items():
+ values[k].append(v)
+ noise_scale = {
+ k: np.std(v, axis=0).astype(np.float32) for k, v in values.items()
+ }
+ augment = lambda d: {
+ k: v
+ + rng.standard_normal(v.shape).astype(np.float32)
+ * gaussian_strength
+ * noise_scale[k]
+ for k, v in d.items()
+ }
+ augments.append(augment)
+
+ if len(augments) == 0:
+ return lambda x: x
+ elif len(augments) == 1:
+ return augments[0]
+ elif len(augments) == 2:
+ return lambda x: augments[1](augments[0](x))
+ else:
+ assert False, "Unexpected number of augmentation functions (%d)" % len(augments)
+
+
+def mixup(rng, batch, beta_range=(0.0, 1.0)):
+ """Each tensor in the batch becomes a linear combination of it and one other tensor"""
+ a = batch
+ b = np.array(batch)
+ rng.shuffle(b) # randomly pair up tensors in the batch
+ # random mixing coefficient for each pair
+ beta = rng.uniform(
+ low=beta_range[0], high=beta_range[1], size=batch.shape[0]
+ ).astype(np.float32)
+ return (a.T * beta).T + (b.T * (1.0 - beta)).T # return linear combinations
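
Note: the intended entry point is SequentialTrainer.replace(), which trains a Keras replacement for the subgraph between the named tensors. A sketch under stated assumptions: the paths and tensor names are hypothetical, and the placeholder block assumes a 4D activation whose spatial dimensions the replacement preserves:

    import tensorflow as tf
    from mlia.nn.rewrite.core.train import SequentialTrainer

    def replacement_fn(input_shape, output_shape):
        # Placeholder block; train() compiles and fits the returned model itself.
        return tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=tuple(input_shape)),
            tf.keras.layers.Conv2D(int(output_shape[-1]), 3, padding="same"),
        ])

    trainer = SequentialTrainer(
        "model.tflite",      # source model
        "rewritten.tflite",  # output model
        "input.tfrec",       # recorded model inputs
    )
    mae, nrmse = trainer.replace(
        replacement_fn,
        ["sequential/block1/conv"],  # input tensors of the section to replace
        ["sequential/block2/conv"],  # output tensors of the section
    )
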
diff --git a/src/mlia/nn/rewrite/core/utils/__init__.py b/src/mlia/nn/rewrite/core/utils/__init__.py
new file mode 100644
index 0000000..48b1622
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/utils/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
\ No newline at end of file
diff --git a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
new file mode 100644
index 0000000..ac3e875
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
@@ -0,0 +1,172 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+import json
+import os
+import random
+import tempfile
+from collections import defaultdict
+
+import numpy as np
+
+from mlia.nn.rewrite.core.utils.utils import load
+from mlia.nn.rewrite.core.utils.utils import save
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+from tensorflow.lite.python import interpreter as interpreter_wrapper
+
+
+def make_decode_fn(filename):
+ def decode_fn(record_bytes, type_map):
+ parse_dict = {
+ name: tf.io.FixedLenFeature([], tf.string) for name in type_map.keys()
+ }
+ example = tf.io.parse_single_example(record_bytes, parse_dict)
+ features = {
+ n: tf.io.parse_tensor(example[n], tf.as_dtype(t))
+ for n, t in type_map.items()
+ }
+ return features
+
+ meta_filename = filename + ".meta"
+ with open(meta_filename) as f:
+ type_map = json.load(f)["type_map"]
+ return lambda record_bytes: decode_fn(record_bytes, type_map)
+
+
+def NumpyTFReader(filename):
+ decode_fn = make_decode_fn(filename)
+ dataset = tf.data.TFRecordDataset(filename)
+ return dataset.map(decode_fn)
+
+
+def numpytf_count(filename):
+ meta_filename = filename + ".meta"
+ with open(meta_filename) as f:
+ return json.load(f)["count"]
+
+
+class NumpyTFWriter:
+ def __init__(self, filename):
+ self.filename = filename
+ self.meta_filename = filename + ".meta"
+ self.writer = tf.io.TFRecordWriter(filename)
+ self.type_map = {}
+ self.count = 0
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type, value, traceback):
+ self.close()
+
+ def __del__(self):
+ self.close()
+
+ def write(self, array_dict):
+ type_map = {n: str(a.dtype.name) for n, a in array_dict.items()}
+ self.type_map.update(type_map)
+ self.count += 1
+
+ feature = {
+ n: tf.train.Feature(
+ bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(a).numpy()])
+ )
+ for n, a in array_dict.items()
+ }
+ example = tf.train.Example(features=tf.train.Features(feature=feature))
+ self.writer.write(example.SerializeToString())
+
+ def close(self):
+ with open(self.meta_filename, "w") as f:
+ meta = {"type_map": self.type_map, "count": self.count}
+ json.dump(meta, f)
+ self.writer.close()
+
+
+class TFLiteModel:
+ def __init__(self, filename, batch_size=None, num_threads=None):
+ if num_threads == 0:
+ num_threads = None
+        if batch_size is None:
+ self.interpreter = interpreter_wrapper.Interpreter(
+ model_path=filename, num_threads=num_threads
+ )
+ else: # if a batch size is specified, modify the TFLite model to use this size
+ with tempfile.TemporaryDirectory() as tmp:
+ fb = load(filename)
+ for sg in fb.subgraphs:
+ for t in list(sg.inputs) + list(sg.outputs):
+ sg.tensors[t].shape = np.array(
+ [batch_size] + list(sg.tensors[t].shape[1:]), dtype=np.int32
+ )
+ tempname = os.path.join(tmp, "rewrite_tmp.tflite")
+ save(fb, tempname)
+ self.interpreter = interpreter_wrapper.Interpreter(
+ model_path=tempname, num_threads=num_threads
+ )
+
+ try:
+ self.interpreter.allocate_tensors()
+ except RuntimeError:
+ self.interpreter = interpreter_wrapper.Interpreter(
+ model_path=filename, num_threads=num_threads
+ )
+ self.interpreter.allocate_tensors()
+
+ # Get input and output tensors.
+ self.input_details = self.interpreter.get_input_details()
+ self.output_details = self.interpreter.get_output_details()
+ details = list(self.input_details) + list(self.output_details)
+ self.handle_from_name = {d["name"]: d["index"] for d in details}
+ self.shape_from_name = {d["name"]: d["shape"] for d in details}
+ self.batch_size = next(iter(self.shape_from_name.values()))[0]
+
+ def __call__(self, named_input):
+ """Execute the model on one or a batch of named inputs (a dict of name: numpy array)"""
+ input_len = next(iter(named_input.values())).shape[0]
+ full_steps = input_len // self.batch_size
+ remainder = input_len % self.batch_size
+
+ named_ys = defaultdict(list)
+ for i in range(full_steps):
+ for name, x_batch in named_input.items():
+                x = x_batch[i * self.batch_size : (i + 1) * self.batch_size]
+ self.interpreter.set_tensor(self.handle_from_name[name], x)
+ self.interpreter.invoke()
+ for d in self.output_details:
+ named_ys[d["name"]].append(self.interpreter.get_tensor(d["index"]))
+ if remainder:
+ for name, x_batch in named_input.items():
+ x = np.zeros(self.shape_from_name[name]).astype(x_batch.dtype)
+ x[:remainder] = x_batch[-remainder:]
+ self.interpreter.set_tensor(self.handle_from_name[name], x)
+ self.interpreter.invoke()
+ for d in self.output_details:
+ named_ys[d["name"]].append(
+ self.interpreter.get_tensor(d["index"])[:remainder]
+ )
+ return {k: np.concatenate(v) for k, v in named_ys.items()}
+
+ def input_tensors(self):
+ return [d["name"] for d in self.input_details]
+
+ def output_tensors(self):
+ return [d["name"] for d in self.output_details]
+
+
+def sample_tfrec(input_file, k, output_file):
+ total = numpytf_count(input_file)
+    next_indices = sorted(random.sample(range(total), k=k), reverse=True)
+
+ reader = NumpyTFReader(input_file)
+ with NumpyTFWriter(output_file) as writer:
+ for i, data in enumerate(reader):
+            if i == next_indices[-1]:
+                next_indices.pop()
+                writer.write(data)
+            if not next_indices:
+ break
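
Note: NumpyTFWriter and NumpyTFReader round-trip dicts of named numpy arrays through a TFRecord file plus a JSON .meta sidecar. A self-contained sketch (file name hypothetical):

    import numpy as np
    from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFReader, NumpyTFWriter

    with NumpyTFWriter("sample.tfrec") as writer:
        for _ in range(4):
            # each record is a dict of named batch-size-1 arrays
            writer.write({"x": np.random.rand(1, 8).astype(np.float32)})

    for item in NumpyTFReader("sample.tfrec"):
        print(item["x"].shape)  # (1, 8)
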
diff --git a/src/mlia/nn/rewrite/core/utils/parallel.py b/src/mlia/nn/rewrite/core/utils/parallel.py
new file mode 100644
index 0000000..5affc03
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/utils/parallel.py
@@ -0,0 +1,113 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+import math
+import os
+from collections import defaultdict
+from multiprocessing import Pool
+
+import numpy as np
+from psutil import cpu_count
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel
+
+
+class ParallelTFLiteModel(TFLiteModel):
+ def __init__(self, filename, num_procs=1, num_threads=0, batch_size=None):
+ """num_procs: 0 => detect real cores on system
+ num_threads: 0 => TFLite impl. specific setting, usually 3
+ batch_size: None => automatic (num_procs or file-determined)
+ """
+ self.pool = None
+ self.filename = filename
+ if not num_procs:
+ self.num_procs = cpu_count(logical=False)
+ else:
+ self.num_procs = int(num_procs)
+
+ self.num_threads = num_threads
+
+ if self.num_procs > 1:
+ if not batch_size:
+ batch_size = self.num_procs # default to min effective batch size
+ local_batch_size = int(math.ceil(batch_size / self.num_procs))
+ super().__init__(filename, batch_size=local_batch_size)
+ del self.interpreter
+ self.pool = Pool(
+ processes=self.num_procs,
+ initializer=_pool_create_worker,
+ initargs=[filename, self.batch_size, self.num_threads],
+ )
+        else:  # single process: run serially in-process, avoiding multiprocessing overhead
+ super().__init__(
+ filename, batch_size=batch_size, num_threads=self.num_threads
+ )
+
+ self.total_batches = 0
+ self.partial_batches = 0
+ self.warned = False
+
+ def close(self):
+ if self.pool:
+ self.pool.close()
+ self.pool.terminate()
+
+ def __del__(self):
+ self.close()
+
+ def __call__(self, named_input):
+ if self.pool:
+ global_batch_size = next(iter(named_input.values())).shape[0]
+ # Note: self.batch_size comes from superclass and is local batch size
+ chunks = int(math.ceil(global_batch_size / self.batch_size))
+ self.total_batches += 1
+ if chunks != self.num_procs:
+ self.partial_batches += 1
+ if (
+ not self.warned
+ and self.total_batches > 10
+ and self.partial_batches / self.total_batches >= 0.5
+ ):
+ print(
+ "ParallelTFLiteModel(%s): warning - %.1f%% of batches do not use all %d processes, set batch size to a multiple of this"
+ % (
+ self.filename,
+ 100 * self.partial_batches / self.total_batches,
+ self.num_procs,
+ )
+ )
+ self.warned = True
+
+ local_batches = [
+ {
+ key: values[i * self.batch_size : (i + 1) * self.batch_size]
+ for key, values in named_input.items()
+ }
+ for i in range(chunks)
+ ]
+ chunk_results = self.pool.map(_pool_run, local_batches)
+ named_ys = defaultdict(list)
+ for chunk in chunk_results:
+ for k, v in chunk.items():
+ named_ys[k].append(v)
+ return {k: np.concatenate(v) for k, v in named_ys.items()}
+ else:
+ return super().__call__(named_input)
+
+
+_local_model = None
+
+
+def _pool_create_worker(filename, local_batch_size=None, num_threads=None):
+ global _local_model
+ _local_model = TFLiteModel(
+ filename, batch_size=local_batch_size, num_threads=num_threads
+ )
+
+
+def _pool_run(named_inputs):
+ return _local_model(named_inputs)
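
Note: ParallelTFLiteModel presents the same callable interface as TFLiteModel but shards each input batch across worker processes. A sketch; the model path and input shape are illustrative and must match the actual model:

    import numpy as np
    from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel

    model = ParallelTFLiteModel("model.tflite", num_procs=0)  # 0 => all physical cores
    name = model.input_tensors()[0]
    outputs = model({name: np.zeros((8, 224, 224, 3), dtype=np.float32)})
    print({k: v.shape for k, v in outputs.items()})
    model.close()
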
diff --git a/src/mlia/nn/rewrite/core/utils/utils.py b/src/mlia/nn/rewrite/core/utils/utils.py
new file mode 100644
index 0000000..ed6c81d
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/utils/utils.py
@@ -0,0 +1,26 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+import os
+
+import flatbuffers
+from tensorflow.lite.python import schema_py_generated as schema_fb
+
+
+def load(input_tflite_file):
+ if not os.path.exists(input_tflite_file):
+        raise RuntimeError("TFLite file not found at %r" % input_tflite_file)
+ with open(input_tflite_file, "rb") as file_handle:
+ file_data = bytearray(file_handle.read())
+ model_obj = schema_fb.Model.GetRootAsModel(file_data, 0)
+ model = schema_fb.ModelT.InitFromObj(model_obj)
+ return model
+
+
+def save(model, output_tflite_file):
+ builder = flatbuffers.Builder(1024) # Initial size of the buffer, which
+ # will grow automatically if needed
+ model_offset = model.Pack(builder)
+ builder.Finish(model_offset, file_identifier=b"TFL3")
+ model_data = builder.Output()
+ with open(output_tflite_file, "wb") as out_file:
+ out_file.write(model_data)
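
Note: load() and save() are the flatbuffer round trip that every graph edit in this patch relies on: load() returns a mutable ModelT object tree and save() repacks it. A sketch with a hypothetical path:

    from mlia.nn.rewrite.core.utils.utils import load, save

    model = load("model.tflite")
    print(len(model.subgraphs[0].tensors))  # the tree can be edited in place
    save(model, "model_copy.tflite")
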