From f0b8ed75fed9dc69ab1f6313339f9f7e38bfc725 Mon Sep 17 00:00:00 2001 From: Annie Tallund Date: Mon, 13 Mar 2023 17:00:31 +0100 Subject: MLIA-845 Migrate rewrite code - Add required files for rewriting of TensorFlow Lite graphs - Adapt rewrite dependency paths and project name - Add license headers Change-Id: I19c5f63215fe2af2fa7d7d44af08144c6c5f911c Signed-off-by: Benjamin Klimczak --- setup.cfg | 1 + src/mlia/nn/rewrite/__init__.py | 2 + src/mlia/nn/rewrite/core/__init__.py | 2 + src/mlia/nn/rewrite/core/extract.py | 87 ++++ src/mlia/nn/rewrite/core/graph_edit/__init__.py | 2 + src/mlia/nn/rewrite/core/graph_edit/cut.py | 130 ++++++ src/mlia/nn/rewrite/core/graph_edit/diff.py | 109 +++++ src/mlia/nn/rewrite/core/graph_edit/join.py | 128 ++++++ src/mlia/nn/rewrite/core/graph_edit/record.py | 62 +++ src/mlia/nn/rewrite/core/train.py | 487 +++++++++++++++++++++++ src/mlia/nn/rewrite/core/utils/__init__.py | 2 + src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py | 172 ++++++++ src/mlia/nn/rewrite/core/utils/parallel.py | 113 ++++++ src/mlia/nn/rewrite/core/utils/utils.py | 26 ++ 14 files changed, 1323 insertions(+) create mode 100644 src/mlia/nn/rewrite/__init__.py create mode 100644 src/mlia/nn/rewrite/core/__init__.py create mode 100644 src/mlia/nn/rewrite/core/extract.py create mode 100644 src/mlia/nn/rewrite/core/graph_edit/__init__.py create mode 100644 src/mlia/nn/rewrite/core/graph_edit/cut.py create mode 100644 src/mlia/nn/rewrite/core/graph_edit/diff.py create mode 100644 src/mlia/nn/rewrite/core/graph_edit/join.py create mode 100644 src/mlia/nn/rewrite/core/graph_edit/record.py create mode 100644 src/mlia/nn/rewrite/core/train.py create mode 100644 src/mlia/nn/rewrite/core/utils/__init__.py create mode 100644 src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py create mode 100644 src/mlia/nn/rewrite/core/utils/parallel.py create mode 100644 src/mlia/nn/rewrite/core/utils/utils.py diff --git a/setup.cfg b/setup.cfg index 5a68b6b..7cdd3c5 100644 --- a/setup.cfg 
+++ b/setup.cfg @@ -35,6 +35,7 @@ install_requires = requests~=2.31.0 rich~=13.5.2 tomli~=2.0.1 ; python_version<"3.11" + tqdm~=4.65.0 [options.packages.find] where = src diff --git a/src/mlia/nn/rewrite/__init__.py b/src/mlia/nn/rewrite/__init__.py new file mode 100644 index 0000000..48b1622 --- /dev/null +++ b/src/mlia/nn/rewrite/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file diff --git a/src/mlia/nn/rewrite/core/__init__.py b/src/mlia/nn/rewrite/core/__init__.py new file mode 100644 index 0000000..48b1622 --- /dev/null +++ b/src/mlia/nn/rewrite/core/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file diff --git a/src/mlia/nn/rewrite/core/extract.py b/src/mlia/nn/rewrite/core/extract.py new file mode 100644 index 0000000..5fcd348 --- /dev/null +++ b/src/mlia/nn/rewrite/core/extract.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. 
def extract(
    output_path,
    model_file,
    input_data,
    input_names,
    output_names,
    subgraph=0,
    skip_outputs=False,
    show_progress=False,
    num_procs=1,
    num_threads=0,
):
    """Split a TFLite model into start/replace/end parts and record their data.

    Files written into ``output_path``:

    * ``start.tflite``   - model cut so that ``input_names`` are its outputs
    * ``input.tfrec``    - ``input_data`` recorded through ``start.tflite``
    * ``replace.tflite`` - model cut between ``input_names`` and ``output_names``
    * ``end.tflite``     - model cut so that ``output_names`` are its inputs
    * ``output.tfrec``   - ``input.tfrec`` recorded through ``replace.tflite``
    * ``end.tfrec``      - ``output.tfrec`` recorded through ``end.tflite``

    The last two files are only produced when ``skip_outputs`` is false.

    Args:
        output_path: Directory receiving the files above; created if missing.
        model_file: Path of the TFLite model to split.
        input_data: Recorded inputs passed to ``record_model`` for the start part.
        input_names: Tensor names marking the start of the extracted section.
        output_names: Tensor names marking the end of the extracted section.
        subgraph: Index of the subgraph to cut.
        skip_outputs: When true, do not record ``output.tfrec``/``end.tfrec``.
        show_progress: Forwarded to ``record_model``.
        num_procs: Forwarded to ``record_model``.
        num_threads: Forwarded to ``record_model``.
    """
    # exist_ok keeps the original "directory may already exist" behaviour and
    # additionally creates missing parent directories instead of failing.
    os.makedirs(output_path, exist_ok=True)

    def part(name):
        return os.path.join(output_path, name)

    # Every recording step shares the same execution settings.
    record_options = {
        "show_progress": show_progress,
        "num_procs": num_procs,
        "num_threads": num_threads,
    }

    start_file = part("start.tflite")
    cut_model(
        model_file,
        input_names=None,
        output_names=input_names,
        subgraph_index=subgraph,
        output_file=start_file,
    )

    input_tfrec = part("input.tfrec")
    record_model(input_data, start_file, input_tfrec, **record_options)

    replace_file = part("replace.tflite")
    cut_model(
        model_file,
        input_names=input_names,
        output_names=output_names,
        subgraph_index=subgraph,
        output_file=replace_file,
    )

    end_file = part("end.tflite")
    cut_model(
        model_file,
        input_names=output_names,
        output_names=None,
        subgraph_index=subgraph,
        output_file=end_file,
    )

    if skip_outputs:
        return

    output_tfrec = part("output.tfrec")
    record_model(input_tfrec, replace_file, output_tfrec, **record_options)

    end_tfrec = part("end.tfrec")
    record_model(output_tfrec, end_file, end_tfrec, **record_options)
def cut_subgraph(subgraph, input_tensor_names, output_tensor_names):
    """Change the global inputs and outputs of a graph to the provided named tensors.

    Args:
        subgraph: Subgraph object whose ``inputs``/``outputs`` index lists are
            replaced in place. Tensor names are compared as UTF-8 bytes.
        input_tensor_names: Names of the new input tensors, or None to keep
            the current inputs.
        output_tensor_names: Names of the new output tensors, or None to keep
            the current outputs.

    Raises:
        AssertionError: If the number of tensors found differs from the
            number of names requested.
    """

    def tensors_by_name(names):
        seek = frozenset(name.encode("utf-8") for name in names)
        return [i for i, tensor in enumerate(subgraph.tensors) if tensor.name in seek]

    def check_all_found(kind, names, found):
        if len(found) == len(names):
            return
        # Tensor names are bytes; decode them so the message can be built
        # (joining bytes into a str would raise TypeError and hide the error).
        found_names = ", ".join(
            subgraph.tensors[i].name.decode("utf-8") for i in found
        )
        raise AssertionError(
            "Expected %d %s tensors: %s\nFound: %s"
            % (len(names), kind, ", ".join(names), found_names)
        )

    if input_tensor_names is not None:
        subgraph.inputs = tensors_by_name(input_tensor_names)
        check_all_found("input", input_tensor_names, subgraph.inputs)

    if output_tensor_names is not None:
        subgraph.outputs = tensors_by_name(output_tensor_names)
        check_all_found("output", output_tensor_names, subgraph.outputs)


def simplify(model):
    """Remove any unused operators, tensors and buffers from a model."""
    for subgraph in model.subgraphs:
        simplify_subgraph(subgraph)

    # Collect the buffers of *every* subgraph: the outer loop must come first
    # in the comprehension, otherwise only the last subgraph is inspected.
    used_buffers = {t.buffer for s in model.subgraphs for t in s.tensors}
    used_buffers.update(m.buffer for m in model.metadata)
    # Buffer zero is always expected to be a zero-sized nullptr buffer by the
    # TFLite runtime
    used_buffers.add(0)
    model.buffers, buf_relabel = filter_relabel(model.buffers, used_buffers)

    for subgraph in model.subgraphs:
        for tensor in subgraph.tensors:
            tensor.buffer = buf_relabel[tensor.buffer]

    for meta in model.metadata:
        meta.buffer = buf_relabel[meta.buffer]


def simplify_subgraph(subgraph):
    """Drop operators and tensors that the subgraph outputs do not depend on.

    The subgraph is modified in place and all tensor indices held by the
    remaining operators, inputs and outputs are renumbered.
    """
    # Map each tensor to the operators producing it. Producers of the
    # subgraph inputs are deliberately not recorded, which stops the
    # backwards traversal at the inputs.
    requires = defaultdict(set)
    for index, operator in enumerate(subgraph.operators):
        for t in operator.outputs:
            if t not in subgraph.inputs:
                requires[t].add(index)

    op_set, ten_set = find_required(subgraph, requires, subgraph.outputs)

    subgraph.operators, _ = filter_relabel(subgraph.operators, op_set)
    subgraph.tensors, ten_relabel = filter_relabel(subgraph.tensors, ten_set)

    ten_relabel[-1] = -1  # Some files have ops with -1 input tensors; leave unchanged

    for op in subgraph.operators:
        op.inputs = [ten_relabel[t] for t in op.inputs]
        op.outputs = [ten_relabel[t] for t in op.outputs]

    subgraph.inputs = [ten_relabel[t] for t in subgraph.inputs]
    subgraph.outputs = [ten_relabel[t] for t in subgraph.outputs]


def find_required(subgraph, requires, tensors):
    """Walk backwards from ``tensors`` and collect what is needed to compute them.

    Args:
        subgraph: Subgraph providing ``operators`` and ``inputs``.
        requires: Mapping of tensor index to the set of operator indices
            that produce it.
        tensors: Tensor indices to start the traversal from.

    Returns:
        Tuple ``(operator_indices, tensor_indices)`` of the visited sets.
    """
    visited_operators = set()
    visited_tensors = set(tensors)
    stop_tensors = set(subgraph.inputs)

    next_tensors = set(visited_tensors)
    while next_tensors:
        loop_tensors = next_tensors
        next_tensors = set()
        for t in loop_tensors:
            new_operators = set(requires[t]) - visited_operators
            visited_operators |= new_operators
            for op in new_operators:
                candidate_tensors = set(subgraph.operators[op].inputs)
                next_tensors |= candidate_tensors - visited_tensors
                visited_tensors |= candidate_tensors
                # include stub outputs but do not traverse them
                visited_tensors |= set(subgraph.operators[op].outputs)
        # Never traverse past the subgraph inputs.
        next_tensors -= stop_tensors

    return visited_operators, visited_tensors


def filter_relabel(src, filter):
    """Keep the items of ``src`` whose index is in ``filter``.

    Returns:
        Tuple ``(kept_items, relabel)`` where ``relabel`` maps each kept old
        index to its new index in ``kept_items``.
    """
    relabel = {}
    output = []
    for i, x in enumerate(src):
        if i in filter:
            relabel[i] = len(output)
            output.append(x)
    return output, relabel


def cut_model(model_file, input_names, output_names, subgraph_index, output_file):
    """Load a model, cut one subgraph to the named tensors and save the result."""
    model = load(model_file)
    subgraph = model.subgraphs[subgraph_index]
    cut_subgraph(subgraph, input_names, output_names)
    simplify(model)
    save(model, output_file)
def diff_stats(file1, file2, per_tensor_and_channel=False):
    """Compare two recorded datasets tensor by tensor.

    The first dataset is read once to obtain the per-element mean of each
    tensor, then both datasets are read in lockstep to accumulate absolute
    error, squared error and the variance of the first dataset.

    Args:
        file1: Path of the reference recording.
        file2: Path of the recording compared against the reference.
        per_tensor_and_channel: When true, return the per-tensor dictionaries
            instead of their overall means.

    Returns:
        ``(mae, nrmse)``: either two dicts keyed by tensor name or two scalar
        means over those dicts. The nrmse only covers elements whose
        reference variance is greater than zero.
    """
    reference = NumpyTFReader(file1)
    candidate = NumpyTFReader(file2)

    totals = defaultdict(dict)

    def accumulate(name, key, values):
        bucket = totals[name]
        if key in bucket:
            bucket[key] += values
        else:
            bucket[key] = values

    def mean_of(name):
        # Average the accumulated totals over the number of reference items.
        return {key: total / count for key, total in totals[name].items()}

    # Pass 1: per-element totals of the reference dataset.
    count = 0
    for record in reference:
        count += 1
        for key, tensor in record.items():
            accumulate("dataset1_total", key, tensor.numpy().astype(np.double))

    dataset1_mean = mean_of("dataset1_total")

    # Pass 2: error between the datasets and variance of the reference.
    for index, (x1, x2) in enumerate(zip(reference, candidate)):
        assert x1.keys() == x2.keys(), (
            "At input %d the files have different sets of tensors.\n%s: %s\n%s: %s\n"
            % (
                index,
                file1,
                ", ".join(x1.keys()),
                file2,
                ", ".join(x2.keys()),
            )
        )
        for key in x1:
            v1 = x1[key].numpy().astype(np.double)
            v2 = x2[key].numpy().astype(np.double)
            delta = v1 - v2
            accumulate("ae", key, abs(delta))
            accumulate("se", key, delta**2)
            accumulate("dataset1_variance", key, (v1 - dataset1_mean[key]) ** 2)

    mae = mean_of("ae")
    rmse = {key: np.sqrt(value) for key, value in mean_of("se").items()}
    dataset1_var = mean_of("dataset1_variance")

    # Normalise by the reference standard deviation wherever it is non-zero.
    nrmse = {}
    for key, value in rmse.items():
        nonzero = dataset1_var[key] > 0
        nrmse[key] = value[nonzero] / np.sqrt(dataset1_var[key][nonzero])

    if per_tensor_and_channel:
        return mae, nrmse
    return np.mean(list(mae.values())), np.mean(list(nrmse.values()))
src_subgraph.outputs = [ + o for o in src_subgraph.outputs if not o in tensor_relabel.keys() + ] + dst_subgraph.inputs = [ + i for i in dst_subgraph.inputs if not i in tensor_relabel.values() + ] + else: + src_subgraph.inputs = [ + i for i in src_subgraph.inputs if not i in tensor_relabel.keys() + ] + dst_subgraph.outputs = [ + o for o in dst_subgraph.outputs if not o in tensor_relabel.values() + ] + + buffer_relabel = { + src_subgraph.tensors[i].buffer: dst_subgraph.tensors[o].buffer + for i, o in tensor_relabel.items() + } + + used_tensors = [ + t for i, t in enumerate(src_subgraph.tensors) if not i in tensor_relabel + ] + + used_buffer_ids = [t.buffer for t in used_tensors] + + opcode_data = lambda c: ( + c.builtinCode, + c.deprecatedBuiltinCode, + c.customCode, + c.version, + ) + opcode_relabel = { + s: d + for s in range(len(src_model.operatorCodes)) + for d in range(len(dst_model.operatorCodes)) + if opcode_data(src_model.operatorCodes[s]) + == opcode_data(dst_model.operatorCodes[d]) + } + + # operator order defines execution schedule so must reflect the inputs/outputs dependencies + if dst_to_src: + dst_subgraph.operators += src_subgraph.operators + else: + dst_subgraph.operators = src_subgraph.operators + dst_subgraph.operators + + append_relabel(src_subgraph.tensors, dst_subgraph.tensors, tensor_relabel) + append_relabel(src_model.operatorCodes, dst_model.operatorCodes, opcode_relabel) + + tensor_relabel[ + -1 + ] = -1 # Some files have ops with -1 input tensors; leave unchanged + + for i in used_buffer_ids: + if not i in buffer_relabel: + buffer_relabel[i] = len(dst_model.buffers) + dst_model.buffers.append(src_model.buffers[i]) + + for o in src_subgraph.operators: + o.inputs = [tensor_relabel[t] for t in o.inputs] + o.outputs = [tensor_relabel[t] for t in o.outputs] + o.opcodeIndex = opcode_relabel[o.opcodeIndex] + + for t in used_tensors: + t.buffer = buffer_relabel[t.buffer] + + src_subgraph.inputs = [tensor_relabel[t] for t in src_subgraph.inputs] + 
def record_model(
    input_filename,
    model_filename,
    output_filename,
    batch_size=None,
    show_progress=False,
    num_procs=1,
    num_threads=0,
):
    """Run a recorded dataset through a TFLite model and record the outputs.

    Args:
        input_filename: Path of the recording used as model input.
        model_filename: Path of the TFLite model to execute.
        output_filename: Path of the recording to write.
        batch_size: Number of items evaluated per model call. When falsy,
            ``model.num_procs * model.batch_size`` is used.
        show_progress: Show a tqdm progress bar while recording.
        num_procs: 0 => detect real cores on system.
        num_threads: 0 => TFLite impl. specific setting, usually 3.
    """
    model = ParallelTFLiteModel(model_filename, num_procs, num_threads, batch_size)
    try:
        if not batch_size:
            # automatically batch to the minimum effective size if not specified
            batch_size = model.num_procs * model.batch_size

        total = numpytf_count(input_filename)
        dataset = NumpyTFReader(input_filename)

        if batch_size > 1:
            # Collapse batch-size 1 items into batch-size n.
            dataset = dataset.map(
                lambda d: {k: tf.squeeze(v, axis=0) for k, v in d.items()}
            )
            dataset = dataset.batch(batch_size, drop_remainder=False)
            total = int(math.ceil(total / batch_size))

        # Close the writer deterministically instead of relying on __del__,
        # so the output is finalised before the caller reads it back.
        with NumpyTFWriter(output_filename) as writer:
            for named_x in tqdm(
                dataset.as_numpy_iterator(), total=total, disable=not show_progress
            ):
                named_y = model(named_x)
                if batch_size > 1:
                    for i in range(batch_size):
                        # Expand the batches and recreate each dict as a
                        # batch-size 1 item for the tfrec output; the last
                        # batch may hold fewer than batch_size items.
                        d = {
                            k: v[i : i + 1]
                            for k, v in named_y.items()
                            if i < v.shape[0]
                        }
                        if d:
                            writer.write(d)
                else:
                    writer.write(named_y)
    finally:
        # Release the model even when recording fails part-way.
        model.close()
"mix_gaussian_large": (2.0, 1.0), + "mix_gaussian_small": (1.6, 0.3), +} + + +class SequentialTrainer: + def __init__( + self, + source_model, + output_model, + input_tfrec, + augment="gaussian", + steps=6000, + lr=1e-3, + batch_size=32, + show_progress=True, + eval_fn=None, + num_procs=1, + num_threads=0, + ): + self.source_model = source_model + self.output_model = output_model + self.input_tfrec = input_tfrec + self.default_augment = augment + self.default_steps = steps + self.default_lr = lr + self.default_batch_size = batch_size + self.show_progress = show_progress + self.num_procs = num_procs + self.num_threads = num_threads + self.first_replace = True + self.eval_fn = eval_fn + + def replace( + self, + model_fn, + input_tensors, + output_tensors, + augment=None, + steps=None, + lr=None, + batch_size=None, + ): + augment = self.default_augment if augment is None else augment + steps = self.default_steps if steps is None else steps + lr = self.default_lr if lr is None else lr + batch_size = self.default_batch_size if batch_size is None else batch_size + + if isinstance(augment, str): + augment = augmentation_presets[augment] + + if self.first_replace: + source_model = self.source_model + unmodified_model = None + else: + source_model = self.output_model + unmodified_model = self.source_model + + mae, nrmse = train( + source_model, + unmodified_model, + self.output_model, + self.input_tfrec, + model_fn, + input_tensors, + output_tensors, + augment, + steps, + lr, + batch_size, + False, + self.show_progress, + None, + 0, + self.num_procs, + self.num_threads, + ) + + self.first_replace = False + if self.eval_fn: + return self.eval_fn(mae, nrmse, self.output_model) + else: + return mae, nrmse + + +def train( + source_model, + unmodified_model, + output_model, + input_tfrec, + replace_fn, + input_tensors, + output_tensors, + augment, + steps, + lr, + batch_size, + verbose, + show_progress, + checkpoint_at=None, + checkpoint_decay_steps=0, + num_procs=1, + 
num_threads=0, +): + if unmodified_model: + unmodified_model_dir = tempfile.TemporaryDirectory() + unmodified_model_dir_path = unmodified_model_dir.name + extract( + unmodified_model_dir_path, + source_model, + input_tfrec, + input_tensors, + output_tensors, + ) + else: + unmodified_model_dir = None + unmodified_model_dir_path = None + + results = [] + with tempfile.TemporaryDirectory() as train_dir: + p = lambda file: os.path.join(train_dir, file) + + extract( + train_dir, + source_model, + input_tfrec, + input_tensors, + output_tensors, + num_procs=num_procs, + num_threads=num_threads, + ) + + tflite_filenames = train_in_dir( + train_dir, + unmodified_model_dir_path, + p("new.tflite"), + replace_fn, + augment, + steps, + lr, + batch_size, + checkpoint_at=checkpoint_at, + checkpoint_decay_steps=checkpoint_decay_steps, + verbose=verbose, + show_progress=show_progress, + num_procs=num_procs, + num_threads=num_threads, + ) + + for i, filename in enumerate(tflite_filenames): + results.append(eval_in_dir(train_dir, filename, num_procs, num_threads)) + + if output_model: + if i + 1 < len(tflite_filenames): + # Append the same _@STEPS.tflite postfix used by intermediate checkpoints for all but the last output + postfix = filename.split("_@")[-1] + output_filename = output_model.split(".tflite")[0] + postfix + else: + output_filename = output_model + join_in_dir(train_dir, filename, output_filename) + + if unmodified_model_dir: + unmodified_model_dir.cleanup() + + return ( + results if checkpoint_at else results[0] + ) # only return a list if multiple checkpoints are asked for + + +def eval_in_dir(dir, new_part, num_procs=1, num_threads=0): + p = lambda file: os.path.join(dir, file) + input = ( + p("input_orig.tfrec") + if os.path.exists(p("input_orig.tfrec")) + else p("input.tfrec") + ) + output = ( + p("output_orig.tfrec") + if os.path.exists(p("output_orig.tfrec")) + else p("output.tfrec") + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + predict = 
def join_in_dir(dir, new_part, output_model):
    """Join ``new_part`` with the start and end models stored in ``dir``.

    ``new_part`` is first joined with ``end.tflite`` into a temporary file,
    which is then joined with ``start.tflite`` and written to
    ``output_model``.
    """
    start_model = os.path.join(dir, "start.tflite")
    end_model = os.path.join(dir, "end.tflite")

    with tempfile.TemporaryDirectory() as scratch:
        # Intermediate result only needed until the final join is done.
        joined_tail = os.path.join(scratch, "new_end.tflite")
        join_models(new_part, end_model, joined_tail)
        join_models(start_model, joined_tail, output_model)
Teacher: {}\nTrain dir: {}".format( + teacher.shape_from_name, replace.shape_from_name + ) + assert all( + tn == rn and (ts[1:] == rs[1:]).all() + for (tn, ts), (rn, rs) in zip( + teacher.shape_from_name.items(), replace.shape_from_name.items() + ) + ), "Baseline and train models must have the same input and output shapes for the subgraph being replaced. Teacher: {}\nTrain dir: {}".format( + teacher.shape_from_name, replace.shape_from_name + ) + + input_filename = os.path.join(train_dir, "input.tfrec") + total = numpytf_count(input_filename) + dict_inputs = NumpyTFReader(input_filename) + inputs = dict_inputs.map(lambda d: tf.squeeze(d[input_name], axis=0)) + if any(augmentations): + # Map the teacher inputs here because the augmentation stage passes these through a TFLite model to get the outputs + teacher_outputs = NumpyTFReader(os.path.join(teacher_dir, "input.tfrec")).map( + lambda d: tf.squeeze(d[input_name], axis=0) + ) + else: + teacher_outputs = NumpyTFReader(os.path.join(teacher_dir, "output.tfrec")).map( + lambda d: tf.squeeze(d[output_name], axis=0) + ) + + steps_per_epoch = math.ceil(total / batch_size) + epochs = int(math.ceil(steps / steps_per_epoch)) + if verbose: + print( + "Training on %d items for %d steps (%d epochs with batch size %d)" + % (total, epochs * steps_per_epoch, epochs, batch_size) + ) + + dataset = tf.data.Dataset.zip((inputs, teacher_outputs)) + if epochs > 1: + dataset = dataset.cache() + dataset = dataset.shuffle(total).repeat().batch(batch_size) + + if any(augmentations): + augment_train, augment_teacher = augment_fn_twins(dict_inputs, augmentations) + augment_fn = lambda train, teach: ( + augment_train({input_name: train})[input_name], + teacher(augment_teacher({input_name: teach}))[output_name], + ) + dataset = dataset.map( + lambda train, teach: tf.py_function( + augment_fn, inp=[train, teach], Tout=[tf.float32, tf.float32] + ) + ) + + dataset = dataset.prefetch(tf.data.AUTOTUNE) + + input_shape = 
teacher.shape_from_name[input_name][1:] + output_shape = teacher.shape_from_name[output_name][1:] + model = replace_fn(input_shape, output_shape) + + optimizer = tf.keras.optimizers.Nadam(learning_rate=lr) + loss_fn = tf.keras.losses.MeanSquaredError() + model.compile(optimizer=optimizer, loss=loss_fn) + + if verbose: + model.summary() + + steps_so_far = 0 + + def cosine_decay(epoch_step, logs): + """Cosine decay from lr at start of the run to zero at the end""" + current_step = epoch_step + steps_so_far + learning_rate = lr * (math.cos(math.pi * current_step / steps) + 1) / 2.0 + tf.keras.backend.set_value(optimizer.learning_rate, learning_rate) + + def late_decay(epoch_step, logs): + """Constant until the last 20% of the run, then linear decay to zero""" + current_step = epoch_step + steps_so_far + steps_remaining = steps - current_step + decay_length = steps // 5 + decay_fraction = min(steps_remaining, decay_length) / decay_length + learning_rate = lr * decay_fraction + tf.keras.backend.set_value(optimizer.learning_rate, learning_rate) + + if schedule == "cosine": + callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=cosine_decay)] + elif schedule == "late": + callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=late_decay)] + elif schedule == "constant": + callbacks = [] + else: + assert False, ( + 'LR schedule "%s" not implemented - expected "cosine", "constant" or "late"' + % schedule + ) + + output_filenames = [] + checkpoints = (checkpoint_at if checkpoint_at else []) + [steps] + while steps_so_far < steps: + steps_to_train = checkpoints.pop(0) - steps_so_far + lr_start = optimizer.learning_rate.numpy() + model.fit( + dataset, + epochs=1, + steps_per_epoch=steps_to_train, + callbacks=callbacks, + verbose=show_progress, + ) + steps_so_far += steps_to_train + print( + "lr decayed from %f to %f over %d steps" + % (lr_start, optimizer.learning_rate.numpy(), steps_to_train) + ) + + if steps_so_far < steps: + filename, ext = 
def augment_fn_twins(inputs, augmentations):
    """Return a pair of twinned augmentation functions with the same sequence of random numbers."""
    # Request a 64-bit draw explicitly: the default integer type of
    # np.random.randint is 32-bit on some platforms, where an upper bound of
    # 2**32 - 1 is out of range.
    seed = np.random.randint(2**32 - 1, dtype=np.int64)
    rng1 = np.random.default_rng(seed)
    rng2 = np.random.default_rng(seed)
    return (
        augment_fn(inputs, augmentations, rng1),
        augment_fn(inputs, augmentations, rng2),
    )


def augment_fn(inputs, augmentations, rng):
    """Build a function applying the requested augmentations to a tensor dict.

    Args:
        inputs: Dataset exposing ``as_numpy_iterator()``; only read when
            gaussian noise is requested, to measure the per-element standard
            deviation of every tensor.
        augmentations: Pair ``(mixup_strength, gaussian_strength)``; a falsy
            entry disables that augmentation.
        rng: Random generator providing ``shuffle``, ``uniform`` and
            ``standard_normal``.

    Returns:
        A function mapping a dict of tensors to a dict of augmented tensors.
        Mixup is applied before gaussian noise; with no augmentation enabled
        the input is returned unchanged.
    """
    mixup_strength, gaussian_strength = augmentations

    augments = []

    if mixup_strength:
        mixup_range = (0.5 - mixup_strength / 2, 0.5 + mixup_strength / 2)

        def apply_mixup(d):
            # Values must provide .numpy() here (they are converted before mixing).
            return {k: mixup(rng, v.numpy(), mixup_range) for k, v in d.items()}

        augments.append(apply_mixup)

    if gaussian_strength:
        values = defaultdict(list)
        for d in inputs.as_numpy_iterator():
            for k, v in d.items():
                values[k].append(v)
        # Scale the noise by the spread of each tensor element over the dataset.
        noise_scale = {
            k: np.std(v, axis=0).astype(np.float32) for k, v in values.items()
        }

        def apply_gaussian(d):
            return {
                k: v
                + rng.standard_normal(v.shape).astype(np.float32)
                * gaussian_strength
                * noise_scale[k]
                for k, v in d.items()
            }

        augments.append(apply_gaussian)

    if not augments:
        return lambda x: x
    if len(augments) == 1:
        return augments[0]

    def apply_all(x):
        # Apply the augmentations in the order they were registered.
        for augment in augments:
            x = augment(x)
        return x

    return apply_all


def mixup(rng, batch, beta_range=(0.0, 1.0)):
    """Each tensor in the batch becomes a linear combination of it and one other tensor."""
    a = batch
    b = np.array(batch)  # copy, so the caller's batch is not shuffled
    rng.shuffle(b)  # randomly pair up tensors in the batch
    # random mixing coefficient for each pair
    beta = rng.uniform(
        low=beta_range[0], high=beta_range[1], size=batch.shape[0]
    ).astype(np.float32)
    # Transposing lets beta broadcast along the batch axis.
    return (a.T * beta).T + (b.T * (1.0 - beta)).T  # return linear combinations
# SPDX-License-Identifier: Apache-2.0
import json
import os
import random
import tempfile
from collections import defaultdict

import numpy as np

from mlia.nn.rewrite.core.utils.utils import load
from mlia.nn.rewrite.core.utils.utils import save

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

from tensorflow.lite.python import interpreter as interpreter_wrapper


def make_decode_fn(filename):
    """Return a function decoding one serialized record of *filename*.

    The tensor names and dtypes are read from the ``type_map`` entry of the
    JSON file ``<filename>.meta``. The returned function maps the bytes of one
    record to a ``{name: tensor}`` dict.
    """
    meta_filename = filename + ".meta"
    with open(meta_filename) as f:
        type_map = json.load(f)["type_map"]

    # Every tensor is stored as one serialized string feature.
    parse_dict = {name: tf.io.FixedLenFeature([], tf.string) for name in type_map}

    def decode_fn(record_bytes):
        example = tf.io.parse_single_example(record_bytes, parse_dict)
        return {
            n: tf.io.parse_tensor(example[n], tf.as_dtype(t))
            for n, t in type_map.items()
        }

    return decode_fn


def NumpyTFReader(filename):
    """Return a ``TFRecordDataset`` over *filename* whose records are decoded to tensor dicts."""
    decode_fn = make_decode_fn(filename)
    dataset = tf.data.TFRecordDataset(filename)
    return dataset.map(decode_fn)


def numpytf_count(filename):
    """Return the record count stored in ``<filename>.meta``."""
    meta_filename = filename + ".meta"
    with open(meta_filename) as f:
        return json.load(f)["count"]


class NumpyTFWriter:
    """Write dicts of arrays to a TFRecord file plus a ``.meta`` JSON sidecar.

    The sidecar records the dtype name of every tensor written (``type_map``)
    and the number of records (``count``). It is written by ``close()``.
    Instances can be used as context managers.
    """

    def __init__(self, filename):
        self.filename = filename
        self.meta_filename = filename + ".meta"
        self.type_map = {}
        self.count = 0
        # Nothing to close until the record writer has been created; this
        # keeps __del__ harmless if the constructor below raises.
        self._closed = True
        self.writer = tf.io.TFRecordWriter(filename)
        self._closed = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        self.close()

    def write(self, array_dict):
        """Append one record built from a ``{name: array}`` dict."""
        type_map = {n: str(a.dtype.name) for n, a in array_dict.items()}
        self.type_map.update(type_map)
        self.count += 1

        feature = {
            n: tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(a).numpy()])
            )
            for n, a in array_dict.items()
        }
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        self.writer.write(example.SerializeToString())

    def close(self):
        """Write the ``.meta`` sidecar and close the record writer.

        Safe to call more than once: only the first call has an effect, so the
        implicit call from ``__del__`` after ``__exit__`` does nothing.
        """
        if getattr(self, "_closed", True):
            return
        self._closed = True
        try:
            with open(self.meta_filename, "w") as f:
                meta = {"type_map": self.type_map, "count": self.count}
                json.dump(meta, f)
        finally:
            # Close the record file even if the sidecar could not be written.
            self.writer.close()


class TFLiteModel:
    """Run a TFLite model on dicts of named inputs, one fixed-size batch at a time."""

    def __init__(self, filename, batch_size=None, num_threads=None):
        """Create the interpreter for *filename* and index its I/O tensors by name.

        Args:
            filename: Path of the TFLite model file.
            batch_size: If not None, the first dimension of every subgraph
                input and output tensor is set to this value in a temporary
                copy of the model, and that copy is loaded instead.
            num_threads: Forwarded to the interpreter; 0 is treated as None.
        """
        if num_threads == 0:
            num_threads = None
        if batch_size is None:
            self.interpreter = interpreter_wrapper.Interpreter(
                model_path=filename, num_threads=num_threads
            )
        else:  # if a batch size is specified, modify the TFLite model to use this size
            with tempfile.TemporaryDirectory() as tmp:
                fb = load(filename)
                for sg in fb.subgraphs:
                    for t in list(sg.inputs) + list(sg.outputs):
                        sg.tensors[t].shape = np.array(
                            [batch_size] + list(sg.tensors[t].shape[1:]), dtype=np.int32
                        )
                tempname = os.path.join(tmp, "rewrite_tmp.tflite")
                save(fb, tempname)
                self.interpreter = interpreter_wrapper.Interpreter(
                    model_path=tempname, num_threads=num_threads
                )

        try:
            self.interpreter.allocate_tensors()
        except RuntimeError:
            # Allocation failed; fall back to the unmodified model file.
            self.interpreter = interpreter_wrapper.Interpreter(
                model_path=filename, num_threads=num_threads
            )
            self.interpreter.allocate_tensors()

        # Get input and output tensors.
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()
        details = list(self.input_details) + list(self.output_details)
        self.handle_from_name = {d["name"]: d["index"] for d in details}
        self.shape_from_name = {d["name"]: d["shape"] for d in details}
        # The batch size is the leading dimension of the first listed tensor.
        self.batch_size = next(iter(self.shape_from_name.values()))[0]

    def __call__(self, named_input):
        """Execute the model on one or a batch of named inputs (a dict of name: numpy array).

        The input is processed in consecutive slices of ``self.batch_size``
        rows. A final partial slice is zero-padded to the tensor shape and the
        padded rows are dropped from its outputs.

        Returns:
            Dict mapping each output tensor name to the concatenated outputs.
        """
        input_len = next(iter(named_input.values())).shape[0]
        full_steps = input_len // self.batch_size
        remainder = input_len % self.batch_size

        named_ys = defaultdict(list)
        for i in range(full_steps):
            # Step i covers rows [i * batch_size, (i + 1) * batch_size).
            start = i * self.batch_size
            stop = start + self.batch_size
            for name, x_batch in named_input.items():
                x = x_batch[start:stop]
                self.interpreter.set_tensor(self.handle_from_name[name], x)
            self.interpreter.invoke()
            for d in self.output_details:
                named_ys[d["name"]].append(self.interpreter.get_tensor(d["index"]))
        if remainder:
            for name, x_batch in named_input.items():
                x = np.zeros(self.shape_from_name[name]).astype(x_batch.dtype)
                x[:remainder] = x_batch[-remainder:]
                self.interpreter.set_tensor(self.handle_from_name[name], x)
            self.interpreter.invoke()
            for d in self.output_details:
                named_ys[d["name"]].append(
                    self.interpreter.get_tensor(d["index"])[:remainder]
                )
        return {k: np.concatenate(v) for k, v in named_ys.items()}

    def input_tensors(self):
        """Return the names of the model's input tensors."""
        return [d["name"] for d in self.input_details]

    def output_tensors(self):
        """Return the names of the model's output tensors."""
        return [d["name"] for d in self.output_details]


def sample_tfrec(input_file, k, output_file):
    """Copy *k* randomly chosen records of *input_file* to *output_file*.

    Records are sampled without replacement and written in their original
    order. With ``k == 0`` an empty output file is produced.
    """
    total = numpytf_count(input_file)
    # Descending order, so the next wanted index is always the last element.
    pending = sorted(random.sample(range(total), k=k), reverse=True)

    reader = NumpyTFReader(input_file)
    with NumpyTFWriter(output_file) as writer:
        if not pending:
            return
        for i, data in enumerate(reader):
            if i == pending[-1]:
                pending.pop()
                writer.write(data)
                if not pending:
                    break


# diff --git a/src/mlia/nn/rewrite/core/utils/parallel.py
# b/src/mlia/nn/rewrite/core/utils/parallel.py
# new file mode 100644
# index 0000000..5affc03
# --- /dev/null
# +++ b/src/mlia/nn/rewrite/core/utils/parallel.py
# @@ -0,0 +1,113 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import math
import os
from collections import defaultdict
from multiprocessing import Pool

import numpy as np
from psutil import cpu_count

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel


class ParallelTFLiteModel(TFLiteModel):
    """TFLiteModel that splits each input batch across a pool of worker processes.

    With a single process it behaves like the serial ``TFLiteModel``.
    """

    def __init__(self, filename, num_procs=1, num_threads=0, batch_size=None):
        """num_procs: 0 => detect real cores on system
        num_threads: 0 => TFLite impl. specific setting, usually 3
        batch_size: None => automatic (num_procs or file-determined)
        """
        self.pool = None
        self.filename = filename
        if not num_procs:
            # psutil returns None when the physical core count cannot be
            # determined; fall back to a single (serial) process then.
            self.num_procs = cpu_count(logical=False) or 1
        else:
            self.num_procs = int(num_procs)

        self.num_threads = num_threads

        if self.num_procs > 1:
            if not batch_size:
                batch_size = self.num_procs  # default to min effective batch size
            local_batch_size = int(math.ceil(batch_size / self.num_procs))
            # The superclass is initialised only to populate the tensor
            # metadata and self.batch_size; inference runs in the workers,
            # so the interpreter created here is discarded.
            super().__init__(filename, batch_size=local_batch_size)
            del self.interpreter
            self.pool = Pool(
                processes=self.num_procs,
                initializer=_pool_create_worker,
                initargs=[filename, self.batch_size, self.num_threads],
            )
        else:  # fall back to serial implementation for max performance
            super().__init__(
                filename, batch_size=batch_size, num_threads=self.num_threads
            )

        # Statistics behind the one-off "partial batch" warning in __call__.
        self.total_batches = 0
        self.partial_batches = 0
        self.warned = False

    def close(self):
        """Shut down the worker pool, if one was created."""
        if self.pool:
            self.pool.close()
            self.pool.terminate()

    def __del__(self):
        self.close()

    def __call__(self, named_input):
        """Run the model on a dict of named inputs and return a dict of named outputs.

        With a worker pool, the input is cut into chunks of ``self.batch_size``
        rows, the chunks are mapped over the pool and the per-chunk outputs
        are concatenated. Without a pool the serial implementation is used.
        """
        if self.pool:
            global_batch_size = next(iter(named_input.values())).shape[0]
            # Note: self.batch_size comes from superclass and is local batch size
            chunks = int(math.ceil(global_batch_size / self.batch_size))
            self.total_batches += 1
            if chunks != self.num_procs:
                self.partial_batches += 1
            # Warn once when at least half of the calls (after the first ten)
            # did not produce exactly one chunk per process.
            if (
                not self.warned
                and self.total_batches > 10
                and self.partial_batches / self.total_batches >= 0.5
            ):
                print(
                    "ParallelTFLiteModel(%s): warning - %.1f%% of batches do not use all %d processes, set batch size to a multiple of this"
                    % (
                        self.filename,
                        100 * self.partial_batches / self.total_batches,
                        self.num_procs,
                    )
                )
                self.warned = True

            local_batches = [
                {
                    key: values[i * self.batch_size : (i + 1) * self.batch_size]
                    for key, values in named_input.items()
                }
                for i in range(chunks)
            ]
            chunk_results = self.pool.map(_pool_run, local_batches)
            named_ys = defaultdict(list)
            for chunk in chunk_results:
                for k, v in chunk.items():
                    named_ys[k].append(v)
            return {k: np.concatenate(v) for k, v in named_ys.items()}
        else:
            return super().__call__(named_input)


# Model instance owned by a pool worker process; set by _pool_create_worker.
_local_model = None


def _pool_create_worker(filename, local_batch_size=None, num_threads=None):
    """Pool initializer: create the TFLiteModel stored in the module-level ``_local_model``."""
    global _local_model
    _local_model = TFLiteModel(
        filename, batch_size=local_batch_size, num_threads=num_threads
    )


def _pool_run(named_inputs):
    """Pool task: run ``_local_model`` on one chunk of named inputs."""
    return _local_model(named_inputs)


# diff --git a/src/mlia/nn/rewrite/core/utils/utils.py b/src/mlia/nn/rewrite/core/utils/utils.py
# new file mode 100644
# index 0000000..ed6c81d
# --- /dev/null
# +++ b/src/mlia/nn/rewrite/core/utils/utils.py
# @@ -0,0 +1,26 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import os

import flatbuffers
from tensorflow.lite.python import schema_py_generated as schema_fb


def load(input_tflite_file):
    """Read a TFLite file and return it unpacked as a ``schema_fb.ModelT`` object.

    Raises:
        RuntimeError: If no file exists at *input_tflite_file*.
    """
    if os.path.exists(input_tflite_file):
        with open(input_tflite_file, "rb") as tflite_file:
            raw_bytes = bytearray(tflite_file.read())
        root = schema_fb.Model.GetRootAsModel(raw_bytes, 0)
        return schema_fb.ModelT.InitFromObj(root)
    raise RuntimeError("TFLite file not found at %r\n" % input_tflite_file)


def save(model, output_tflite_file):
    """Pack *model* into a flatbuffer with identifier ``TFL3`` and write it to *output_tflite_file*."""
    # 1024 is only the initial buffer size; the builder grows it
    # automatically if needed.
    fb_builder = flatbuffers.Builder(1024)
    root_offset = model.Pack(fb_builder)
    fb_builder.Finish(root_offset, file_identifier=b"TFL3")
    serialized = fb_builder.Output()
    with open(output_tflite_file, "wb") as tflite_file:
        tflite_file.write(serialized)


# -- cgit v1.2.1