From 867f37d643e66c0223457c28f5345f2f21db97f2 Mon Sep 17 00:00:00 2001 From: Annie Tallund Date: Wed, 15 Mar 2023 11:27:08 +0100 Subject: Adapt rewrite module to MLIA coding standards - Fix imports - Update variable names - Refactor helper functions - Add licence headers - Add docstrings - Use f-strings rather than % notation - Create type annotations in rewrite module - Migrate from tqdm to rich progress bar - Use logging module in rewrite module: All print statements are replaced with logging module Resolves: MLIA-831, MLIA-842, MLIA-844, MLIA-846 Signed-off-by: Benjamin Klimczak Change-Id: Idee37538d72b9f01128a894281a8d10155f7c17c --- setup.cfg | 1 - src/mlia/nn/rewrite/__init__.py | 1 + src/mlia/nn/rewrite/core/__init__.py | 1 + src/mlia/nn/rewrite/core/extract.py | 35 +- src/mlia/nn/rewrite/core/graph_edit/__init__.py | 1 + src/mlia/nn/rewrite/core/graph_edit/cut.py | 139 ++++---- src/mlia/nn/rewrite/core/graph_edit/diff.py | 102 +++--- src/mlia/nn/rewrite/core/graph_edit/join.py | 111 ++++-- src/mlia/nn/rewrite/core/graph_edit/record.py | 51 +-- src/mlia/nn/rewrite/core/rewrite.py | 3 +- src/mlia/nn/rewrite/core/train.py | 433 +++++++++++++---------- src/mlia/nn/rewrite/core/utils/__init__.py | 1 + src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py | 140 +++++--- src/mlia/nn/rewrite/core/utils/parallel.py | 90 +++-- src/mlia/nn/rewrite/core/utils/utils.py | 22 +- src/mlia/nn/select.py | 14 +- tests/test_nn_rewrite_core_graph_edit_cut.py | 4 +- tests/test_nn_rewrite_core_graph_edit_record.py | 4 +- tests/test_nn_rewrite_core_train.py | 14 +- 19 files changed, 691 insertions(+), 476 deletions(-) diff --git a/setup.cfg b/setup.cfg index 7cdd3c5..5a68b6b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,7 +35,6 @@ install_requires = requests~=2.31.0 rich~=13.5.2 tomli~=2.0.1 ; python_version<"3.11" - tqdm~=4.65.0 [options.packages.find] where = src diff --git a/src/mlia/nn/rewrite/__init__.py b/src/mlia/nn/rewrite/__init__.py index 8c1f750..74298f6 100644 --- a/src/mlia/nn/rewrite/__init__.py +++ b/src/mlia/nn/rewrite/__init__.py @@ -1,2 +1,3 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 +"""Rewrite module.""" diff --git a/src/mlia/nn/rewrite/core/__init__.py b/src/mlia/nn/rewrite/core/__init__.py index 8c1f750..8816dc1 100644 --- a/src/mlia/nn/rewrite/core/__init__.py +++ b/src/mlia/nn/rewrite/core/__init__.py @@ -1,2 +1,3 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 +"""Rewrite core module.""" diff --git a/src/mlia/nn/rewrite/core/extract.py b/src/mlia/nn/rewrite/core/extract.py index 5fcd348..f609955 100644 --- a/src/mlia/nn/rewrite/core/extract.py +++ b/src/mlia/nn/rewrite/core/extract.py @@ -1,28 +1,33 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. 
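The three mechanical migrations named in the commit message (print statements to the logging module, %-notation to f-strings, tqdm to rich) all follow the same pattern; a minimal sketch, using a hypothetical process_items helper that is not part of this patch:

import logging

from rich.progress import track

logger = logging.getLogger(__name__)


def process_items(items: list) -> None:
    """Hypothetical helper showing the post-migration idioms."""
    # rich's track() stands in for tqdm.tqdm() as the progress wrapper
    for item in track(items, total=len(items)):
        # logging with lazy %-style arguments replaces bare print()
        logger.info("processed item %s", item)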
# SPDX-License-Identifier: Apache-2.0 +"""Extract module.""" +# pylint: disable=too-many-arguments, too-many-locals import os -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" import tensorflow as tf - -tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +from tensorflow.lite.python.schema_py_generated import SubGraphT from mlia.nn.rewrite.core.graph_edit.cut import cut_model from mlia.nn.rewrite.core.graph_edit.record import record_model +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) + + def extract( - output_path, - model_file, - input_data, - input_names, - output_names, - subgraph=0, - skip_outputs=False, - show_progress=False, - num_procs=1, - num_threads=0, -): + output_path: str, + model_file: str, + input_filename: str, + input_names: list, + output_names: list, + subgraph: SubGraphT = 0, + skip_outputs: bool = False, + show_progress: bool = False, + num_procs: int = 1, + num_threads: int = 0, +) -> None: + """Extract a model after cut and record.""" try: os.mkdir(output_path) except FileExistsError: @@ -39,7 +44,7 @@ def extract( input_tfrec = os.path.join(output_path, "input.tfrec") record_model( - input_data, + input_filename, start_file, input_tfrec, show_progress=show_progress, diff --git a/src/mlia/nn/rewrite/core/graph_edit/__init__.py b/src/mlia/nn/rewrite/core/graph_edit/__init__.py index 8c1f750..273f4a4 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/__init__.py +++ b/src/mlia/nn/rewrite/core/graph_edit/__init__.py @@ -1,2 +1,3 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 +"""Rewrite core graph edit module.""" diff --git a/src/mlia/nn/rewrite/core/graph_edit/cut.py b/src/mlia/nn/rewrite/core/graph_edit/cut.py index a323b7b..2707eb1 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/cut.py +++ b/src/mlia/nn/rewrite/core/graph_edit/cut.py @@ -1,128 +1,143 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. 
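For orientation, a hypothetical call to the new extract() signature above; the paths and tensor names below are placeholders rather than values taken from this patch:

from mlia.nn.rewrite.core.extract import extract

# Cut the section of the model between the named tensors out into its own
# file and record the data flowing through it; the artifacts (such as the
# input.tfrec written by record_model() above) land in output_path.
extract(
    output_path="train_dir",
    model_file="model.tflite",
    input_filename="dataset.tfrec",
    input_names=["conv1_output"],
    output_names=["conv2_output"],
    show_progress=True,
)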
 # SPDX-License-Identifier: Apache-2.0
+"""Cut module."""
 import os
 from collections import defaultdict
+from typing import Optional
 
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 import tensorflow as tf
+from tensorflow.lite.python.schema_py_generated import ModelT
+from tensorflow.lite.python.schema_py_generated import SubGraphT
 
-tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+from mlia.nn.rewrite.core.utils.utils import load
+from mlia.nn.rewrite.core.utils.utils import save
 
-from mlia.nn.rewrite.core.utils.utils import load, save
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
 
 
-def cut_subgraph(subgraph, input_tensor_names, output_tensor_names):
-    """Change the global inputs and outputs of a graph to the provided named tensors"""
+def tensors_by_name(subgraph: SubGraphT, names: list) -> list:
+    """Seek out tensors from a subgraph and return the result."""
+    seek = frozenset([name.encode("utf-8") for name in names])
+    tensors = [i for i, tensor in enumerate(subgraph.tensors) if tensor.name in seek]
+    return tensors
 
-    def tensors_by_name(names):
-        seek = frozenset([name.encode("utf-8") for name in names])
-        tensors = [
-            i for i, tensor in enumerate(subgraph.tensors) if tensor.name in seek
-        ]
-        return tensors
 
+def cut_subgraph(
+    subgraph: SubGraphT,
+    input_tensor_names: Optional[list],
+    output_tensor_names: Optional[list],
+) -> None:
+    """Change the global inputs and outputs of a graph to the provided named tensors."""
     if input_tensor_names is not None:
-        subgraph.inputs = tensors_by_name(input_tensor_names)
+        subgraph.inputs = tensors_by_name(subgraph, input_tensor_names)
         assert len(subgraph.inputs) == len(
             input_tensor_names
-        ), "Expected %d input tensors: %s\nFound: %s" % (
-            len(subgraph.inputs),
-            ", ".join(input_tensor_names),
-            ", ".join(subgraph.tensors[i].name for i in subgraph.inputs),
-        )
-
+        ), f"Expected {len(subgraph.inputs)} input tensors: \
+            {', '.join(input_tensor_names)}\nFound: \
+            {', '.join(subgraph.tensors[i].name for i in subgraph.inputs)}"
     if output_tensor_names is not None:
-        subgraph.outputs = tensors_by_name(output_tensor_names)
+        subgraph.outputs = tensors_by_name(subgraph, output_tensor_names)
         assert len(subgraph.outputs) == len(
             output_tensor_names
-        ), "Expected %d output tensors: %s\nFound: %s" % (
-            len(subgraph.outputs),
-            ", ".join(output_tensor_names),
-            ", ".join(subgraph.tensors[i].name for i in subgraph.outputs),
-        )
+        ), f"Expected {len(subgraph.outputs)} output tensors: \
+            {', '.join(output_tensor_names)}\nFound: \
+            {', '.join(subgraph.tensors[i].name for i in subgraph.outputs)}"
 
 
-def simplify(model):
-    """Remove any unused operators, tensors and buffers from a model"""
-    for s in model.subgraphs:
-        simplify_subgraph(s)
+def simplify(model: ModelT) -> None:
+    """Remove any unused operators, tensors and buffers from a model."""
+    for subgraph in model.subgraphs:
+        simplify_subgraph(subgraph)
 
-    used_buffers = {t.buffer for t in s.tensors for s in model.subgraphs}
-    used_buffers = used_buffers.union({m.buffer for m in model.metadata})
+    used_buffers = {
+        tensor.buffer for subgraph in model.subgraphs for tensor in subgraph.tensors
+    }
+    used_buffers = used_buffers.union({metadata.buffer for metadata in model.metadata})
     used_buffers.add(
         0
-    )  # Buffer zero is always expected to be a zero-sized nullptr buffer by the TFLite runtime
+    )  # Buffer zero is always expected to be a zero-sized nullptr buffer by the
+    # TFLite runtime
     model.buffers, buf_relabel =
filter_relabel(model.buffers, used_buffers) - for s in model.subgraphs: - for t in s.tensors: - t.buffer = buf_relabel[t.buffer] + for subgraph in model.subgraphs: + for tensor in subgraph.tensors: + tensor.buffer = buf_relabel[tensor.buffer] - for m in model.metadata: - m.buffer = buf_relabel[m.buffer] + for metadata in model.metadata: + metadata.buffer = buf_relabel[metadata.buffer] -def simplify_subgraph(subgraph): +def simplify_subgraph(subgraph: SubGraphT) -> None: + """Simplify a subgraph given its operators.""" requires = defaultdict(set) - for o, operator in enumerate(subgraph.operators): - for t in operator.outputs: - if not t in subgraph.inputs: - requires[t].add(o) + for output, operator in enumerate(subgraph.operators): + for tensor in operator.outputs: + if tensor not in subgraph.inputs: + requires[tensor].add(output) op_set, ten_set = find_required(subgraph, requires, subgraph.outputs) - subgraph.operators, op_relabel = filter_relabel(subgraph.operators, op_set) + subgraph.operators, _ = filter_relabel(subgraph.operators, op_set) subgraph.tensors, ten_relabel = filter_relabel(subgraph.tensors, ten_set) ten_relabel[-1] = -1 # Some files have ops with -1 input tensors; leave unchanged - for op in subgraph.operators: - op.inputs = [ten_relabel[t] for t in op.inputs] - op.outputs = [ten_relabel[t] for t in op.outputs] + for operator in subgraph.operators: + operator.inputs = [ten_relabel[tensor] for tensor in operator.inputs] + operator.outputs = [ten_relabel[tensor] for tensor in operator.outputs] - subgraph.inputs = [ten_relabel[t] for t in subgraph.inputs] - subgraph.outputs = [ten_relabel[t] for t in subgraph.outputs] + subgraph.inputs = [ten_relabel[tensor] for tensor in subgraph.inputs] + subgraph.outputs = [ten_relabel[tensors] for tensors in subgraph.outputs] -def find_required(subgraph, requires, tensors): - visited_operators = set() +def find_required(subgraph: SubGraphT, requires: dict, tensors: dict) -> tuple: + """Find required operators in a given subgraph.""" + visited_operators: set = set() visited_tensors = set(tensors) stop_tensors = set(subgraph.inputs) - changed = True next_tensors = visited_tensors while next_tensors: loop_tensors = next_tensors next_tensors = set() - for t in loop_tensors: - candidate_operators = set(requires[t]) + for tensor in loop_tensors: + candidate_operators = set(requires[tensor]) new_operators = candidate_operators - visited_operators visited_operators = visited_operators.union(new_operators) - for op in new_operators: - candidate_tensors = set(subgraph.operators[op].inputs) + for operator in new_operators: + candidate_tensors = set(subgraph.operators[operator].inputs) new_tensors = candidate_tensors - (visited_tensors.union(next_tensors)) next_tensors = next_tensors.union(new_tensors) visited_tensors = visited_tensors.union(candidate_tensors) visited_tensors = visited_tensors.union( - subgraph.operators[op].outputs + subgraph.operators[operator].outputs ) # include stub outputs but do not traverse them next_tensors = next_tensors - stop_tensors return visited_operators, visited_tensors -def filter_relabel(src, filter): - relabel = {} - output = [] - for i, x in enumerate(src): - if i in filter: +def filter_relabel(src_subgraph: SubGraphT, relabel_filter: set) -> tuple: + """Relabel tensors in a subgraph based on a filter.""" + relabel: dict = {} + output: list = [] + for i, out in enumerate(src_subgraph): + if i in relabel_filter: relabel[i] = len(output) - output.append(x) + output.append(out) return output, relabel -def 
cut_model(model_file, input_names, output_names, subgraph_index, output_file): +def cut_model( + model_file: str, + input_names: Optional[list], + output_names: Optional[list], + subgraph_index: int, + output_file: str, +) -> None: + """Cut subgraphs and simplify a given model.""" model = load(model_file) subgraph = model.subgraphs[subgraph_index] cut_subgraph(subgraph, input_names, output_names) diff --git a/src/mlia/nn/rewrite/core/graph_edit/diff.py b/src/mlia/nn/rewrite/core/graph_edit/diff.py index 0829f0a..198e47e 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/diff.py +++ b/src/mlia/nn/rewrite/core/graph_edit/diff.py @@ -1,63 +1,79 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 +"""Diff module: compare subgraph outputs.""" +# pylint: disable=too-many-locals +from __future__ import annotations + import os from collections import defaultdict +from pathlib import Path +from typing import Any -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +import numpy as np import tensorflow as tf +from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) -import numpy as np -from tensorflow.lite.python import interpreter as interpreter_wrapper -from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFReader, NumpyTFWriter +def dict_mean(mean_dict: dict) -> Any: + """Return the mean of values in a given dict.""" + return np.mean(list(mean_dict.values())) -def diff_stats(file1, file2, per_tensor_and_channel=False): - dataset1 = NumpyTFReader(file1) - dataset2 = NumpyTFReader(file2) - totals = defaultdict(dict) +def add_total(name: str, key: str, values: list, totals: dict) -> None: + """Append values to dict totals.""" + if key not in totals[name]: + totals[name][key] = values + else: + totals[name][key] += values + - def add_total(name, key, values): - if not key in totals[name]: - totals[name][key] = values - else: - totals[name][key] += values +def diff_stats( + file1: str | Path, file2: str | Path, per_tensor_and_channel: bool = False +) -> tuple: + """Compare the statistics of outputs between two subgraphs.""" + dataset1 = numpytf_read(file1) + dataset2 = numpytf_read(file2) - # First iterate through dataset1 and calculate per-channel total for each tensor + totals: dict = defaultdict(dict) + + # First iterate through dataset and calculate per-channel total for each tensor count = 0 - for d in dataset1: + for data in dataset1: count += 1 - for k, v in d.items(): - value = v.numpy().astype(np.double) - add_total("dataset1_total", k, value) + for key, val in data.items(): + value = val.numpy().astype(np.double) + add_total("dataset1_total", key, value, totals) # Use this to calculate per-channel mean for each tensor - per_tensor_mean = lambda name: { - k: total / count for k, total in totals[name].items() - } + def per_tensor_mean(name: str) -> dict: + return {k: total / count for k, total in totals[name].items()} + dataset1_mean = per_tensor_mean("dataset1_total") - # Next iterate through both datasets and calculate per-channel total squared error - # between them for each tensor and dataset1 variance for each tensor using the mean from above - for i, (x1, x2) in enumerate(zip(dataset1, dataset2)): - assert x1.keys() == x2.keys(), ( - "At input %d the files have different sets of tensors.\n%s: %s\n%s: %s\n" - % ( - i, - file1, - ", ".join(x1.keys()), - file2, - ", ".join(x2.keys()), - ) + # Next iterate through both 
datasets and calculate per-channel total squared + # error between them for each tensor and dataset1 variance for each tensor + # using the mean from above + for i, (ds1, ds2) in enumerate(zip(dataset1, dataset2)): + assert ds1.keys() == ds2.keys(), ( + f"At input {i} the files have different sets of tensors.\n" + f"{file1}: {', '.join(ds1.keys())}\n" + f"{file2}: {', '.join(ds2.keys())}\n" ) - for k in x1.keys(): - v1 = x1[k].numpy().astype(np.double) - v2 = x2[k].numpy().astype(np.double) - add_total("ae", k, abs(v1 - v2)) - add_total("se", k, (v1 - v2) ** 2) - add_total("dataset1_variance", k, (v1 - dataset1_mean[k]) ** 2) + for key in ds1.keys(): + tensor1 = ds1[key].numpy().astype(np.double) + tensor2 = ds2[key].numpy().astype(np.double) + add_total("ae", key, abs(tensor1 - tensor2), totals) + add_total("se", key, (tensor1 - tensor2) ** 2, totals) + add_total( + "dataset1_variance", + key, + (tensor1 - dataset1_mean[key]) ** 2, + totals, + ) # Finally average over number of inputs to get the rmse and the dataset1 variance mae = per_tensor_mean("ae") @@ -66,7 +82,8 @@ def diff_stats(file1, file2, per_tensor_and_channel=False): dataset1_var = per_tensor_mean("dataset1_variance") is_nonzero = {k: dataset1_var[k] > 0 for k in dataset1_var} - # Divide by target standard deviation to get the per-channel nrmse for each tensor where possible + # Divide by target standard deviation to get the per-channel nrmse for each + # tensor where possible nrmse = { k: v[is_nonzero[k]] / np.sqrt(dataset1_var[k][is_nonzero[k]]) for k, v in rmse.items() @@ -74,6 +91,5 @@ def diff_stats(file1, file2, per_tensor_and_channel=False): if per_tensor_and_channel: return mae, nrmse - else: - dict_mean = lambda d: np.mean(list(d.values())) - return dict_mean(mae), dict_mean(nrmse) + + return dict_mean(mae), dict_mean(nrmse) diff --git a/src/mlia/nn/rewrite/core/graph_edit/join.py b/src/mlia/nn/rewrite/core/graph_edit/join.py index 758f4cf..14a7347 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/join.py +++ b/src/mlia/nn/rewrite/core/graph_edit/join.py @@ -1,16 +1,31 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. 
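A usage sketch for diff_stats() above, with placeholder .tfrec paths: the NRMSE is the per-channel RMSE normalised by the standard deviation of the first dataset, so comparing a file against itself yields zero for both statistics:

from mlia.nn.rewrite.core.graph_edit.diff import diff_stats

# Scalar means over all tensors and channels...
mae, nrmse = diff_stats("output.tfrec", "predict.tfrec")

# ...or per-tensor dicts of per-channel values.
mae_map, nrmse_map = diff_stats(
    "output.tfrec", "predict.tfrec", per_tensor_and_channel=True
)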
# SPDX-License-Identifier: Apache-2.0 +"""Join module.""" +from __future__ import annotations + import os +from pathlib import Path -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" import tensorflow as tf +from tensorflow.lite.python.schema_py_generated import ModelT +from tensorflow.lite.python.schema_py_generated import OperatorCodeT +from tensorflow.lite.python.schema_py_generated import SubGraphT -tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +from mlia.nn.rewrite.core.utils.utils import load +from mlia.nn.rewrite.core.utils.utils import save -from mlia.nn.rewrite.core.utils.utils import load, save +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) -def join_models(input_src, input_dst, output_file, subgraph_src=0, subgraph_dst=0): +def join_models( + input_src: str | Path, + input_dst: str | Path, + output_file: str | Path, + subgraph_src: SubGraphT = 0, + subgraph_dst: SubGraphT = 0, +) -> None: + """Join two models and save the result into a given model file path.""" src_model = load(input_src) dst_model = load(input_dst) src_subgraph = src_model.subgraphs[subgraph_src] @@ -19,8 +34,13 @@ def join_models(input_src, input_dst, output_file, subgraph_src=0, subgraph_dst= save(dst_model, output_file) -def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph): - """Copy subgraph src into subgraph dst from model, connecting tensors with the same names""" +def join_subgraphs( + src_model: ModelT, + src_subgraph: SubGraphT, + dst_model: ModelT, + dst_subgraph: SubGraphT, +) -> None: + """Join two subgraphs, connecting tensors with the same names.""" # Find inputs that match outputs in the other graph and vice versa dst_to_src = { i: o @@ -36,10 +56,11 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph): if dst_subgraph.tensors[i].name == src_subgraph.tensors[o].name } - assert not (src_to_dst and dst_to_src), ( - "Source and destination subgraphs appear to connect in a loop: %d tensors from src to dst, %d tensors from dst to src" - % (len(src_to_dst), len(dst_to_src)) - ) + assert not ( + src_to_dst and dst_to_src + ), f"Source and destination subgraphs appear to connect in a loop: \ + {len(src_to_dst)} tensors from src to dst, {len(dst_to_src)} \ + tensors from dst to src" # Relabel matched input/output tensors between graphs tensor_relabel = src_to_dst if src_to_dst else dst_to_src @@ -47,36 +68,46 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph): # Remove matched inputs/outputs as these will now become internal tensors if src_to_dst: src_subgraph.outputs = [ - o for o in src_subgraph.outputs if not o in tensor_relabel.keys() + output + for output in src_subgraph.outputs + if output not in tensor_relabel.keys() ] dst_subgraph.inputs = [ - i for i in dst_subgraph.inputs if not i in tensor_relabel.values() + input + for input in dst_subgraph.inputs + if input not in tensor_relabel.values() ] else: src_subgraph.inputs = [ - i for i in src_subgraph.inputs if not i in tensor_relabel.keys() + input for input in src_subgraph.inputs if input not in tensor_relabel.keys() ] dst_subgraph.outputs = [ - o for o in dst_subgraph.outputs if not o in tensor_relabel.values() + output + for output in dst_subgraph.outputs + if output not in tensor_relabel.values() ] buffer_relabel = { - src_subgraph.tensors[i].buffer: dst_subgraph.tensors[o].buffer - for i, o in tensor_relabel.items() + src_subgraph.tensors[input].buffer: dst_subgraph.tensors[output].buffer + for input, output in 
tensor_relabel.items() } used_tensors = [ - t for i, t in enumerate(src_subgraph.tensors) if not i in tensor_relabel + tensor + for i, tensor in enumerate(src_subgraph.tensors) + if i not in tensor_relabel ] - used_buffer_ids = [t.buffer for t in used_tensors] + used_buffer_ids = [tensor.buffer for tensor in used_tensors] + + def opcode_data(code: OperatorCodeT) -> tuple: + return ( + code.builtinCode, + code.deprecatedBuiltinCode, + code.customCode, + code.version, + ) - opcode_data = lambda c: ( - c.builtinCode, - c.deprecatedBuiltinCode, - c.customCode, - c.version, - ) opcode_relabel = { s: d for s in range(len(src_model.operatorCodes)) @@ -85,7 +116,8 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph): == opcode_data(dst_model.operatorCodes[d]) } - # operator order defines execution schedule so must reflect the inputs/outputs dependencies + # operator order defines execution schedule so must reflect + # the inputs/outputs dependencies if dst_to_src: dst_subgraph.operators += src_subgraph.operators else: @@ -99,17 +131,17 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph): ] = -1 # Some files have ops with -1 input tensors; leave unchanged for i in used_buffer_ids: - if not i in buffer_relabel: + if i not in buffer_relabel: buffer_relabel[i] = len(dst_model.buffers) dst_model.buffers.append(src_model.buffers[i]) - for o in src_subgraph.operators: - o.inputs = [tensor_relabel[t] for t in o.inputs] - o.outputs = [tensor_relabel[t] for t in o.outputs] - o.opcodeIndex = opcode_relabel[o.opcodeIndex] + for operator in src_subgraph.operators: + operator.inputs = [tensor_relabel[tensor] for tensor in operator.inputs] + operator.outputs = [tensor_relabel[tensor] for tensor in operator.outputs] + operator.opcodeIndex = opcode_relabel[operator.opcodeIndex] - for t in used_tensors: - t.buffer = buffer_relabel[t.buffer] + for tensor in used_tensors: + tensor.buffer = buffer_relabel[tensor.buffer] src_subgraph.inputs = [tensor_relabel[t] for t in src_subgraph.inputs] src_subgraph.outputs = [tensor_relabel[t] for t in src_subgraph.outputs] @@ -118,11 +150,12 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph): dst_subgraph.outputs = list(set(src_subgraph.outputs).union(dst_subgraph.outputs)) -def append_relabel(src, dst, map=None): - if map is None: - map = {} - for i, x in enumerate(src): - if not i in map: - map[i] = len(dst) +def append_relabel(src: list, dst: list, operator_map: dict | None = None) -> dict: + """Return a map over relabeled tensors in a subgraph.""" + if not operator_map: + operator_map = {} + for i, x in enumerate(src): # pylint: disable=invalid-name + if i not in operator_map: + operator_map[i] = len(dst) dst.append(x) - return map + return operator_map diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py index ae13313..90f3db8 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/record.py +++ b/src/mlia/nn/rewrite/core/graph_edit/record.py @@ -1,34 +1,39 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. 
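A sketch of how join_models() above stitches a trained replacement back between the two retained halves of the original network, mirroring join_in_dir() in train.py further down; the file names are placeholders:

from mlia.nn.rewrite.core.graph_edit.join import join_models

# Attach the replacement to the tail half, then attach the head half;
# tensors with matching names are connected automatically.
join_models("new.tflite", "end.tflite", "new_end.tflite")
join_models("start.tflite", "new_end.tflite", "joined.tflite")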
# SPDX-License-Identifier: Apache-2.0 +"""Save subgraph data.""" +# pylint: disable=too-many-locals +from __future__ import annotations + import math import os +from pathlib import Path -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" import tensorflow as tf +from rich.progress import track -tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) - -from tqdm import tqdm -from mlia.nn.rewrite.core.utils.numpy_tfrecord import ( - NumpyTFReader, - NumpyTFWriter, - TFLiteModel, - numpytf_count, -) +from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count +from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read +from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFWriter from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) + def record_model( - input_filename, - model_filename, - output_filename, - batch_size=None, - show_progress=False, - num_procs=1, - num_threads=0, -): - """num_procs: 0 => detect real cores on system - num_threads: 0 => TFLite impl. specific setting, usually 3""" + input_filename: str | Path, + model_filename: str | Path, + output_filename: str | Path, + batch_size: int = 0, + show_progress: bool = False, + num_procs: int = 1, + num_threads: int = 0, +) -> None: + """Model recorder. + + num_procs: 0 => detect real cores on system + num_threads: 0 => TFLite impl. specific setting, usually 3 + """ model = ParallelTFLiteModel(model_filename, num_procs, num_threads, batch_size) if not batch_size: batch_size = ( @@ -36,10 +41,10 @@ def record_model( ) # automatically batch to the minimum effective size if not specified total = numpytf_count(input_filename) - dataset = NumpyTFReader(input_filename) + dataset = numpytf_read(input_filename) if batch_size > 1: - # Collapse batch-size 1 items into batch-size n. I regret using batch-size 1 items in tfrecs now. + # Collapse batch-size 1 items into batch-size n. dataset = dataset.map( lambda d: {k: tf.squeeze(v, axis=0) for k, v in d.items()} ) @@ -48,7 +53,7 @@ def record_model( with NumpyTFWriter(output_filename) as writer: for _, named_x in enumerate( - tqdm(dataset.as_numpy_iterator(), total=total, disable=not show_progress) + track(dataset.as_numpy_iterator(), total=total, disable=not show_progress) ): named_y = model(named_x) if batch_size > 1: diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index d4f61c5..ab34b47 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -42,4 +42,5 @@ class Rewriter(Optimizer): return self.model def optimization_config(self) -> str: - """Optimization configirations.""" + """Optimization configurations.""" + return str(self.optimizer_configuration) diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index 096daf4..f837964 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -1,35 +1,41 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. 
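A sketch of record_model() above, which plays a TFRecord of inputs through a TFLite model and records the outputs to another TFRecord; the paths are placeholders, and num_procs=0 requests one worker per real core as described in the docstring:

from mlia.nn.rewrite.core.graph_edit.record import record_model

record_model(
    input_filename="input.tfrec",
    model_filename="replace.tflite",
    output_filename="output.tfrec",
    show_progress=True,
    num_procs=0,  # 0 => detect real cores on the system
)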
# SPDX-License-Identifier: Apache-2.0 +"""Sequential trainer.""" +# pylint: disable=too-many-arguments, too-many-instance-attributes, +# pylint: disable=too-many-locals, too-many-branches, too-many-statements +from __future__ import annotations + +import logging import math import os import tempfile from collections import defaultdict +from pathlib import Path +from typing import Any +from typing import Callable +from typing import cast +from typing import get_args +from typing import Literal -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +import numpy as np import tensorflow as tf +from numpy.random import Generator -tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) - -try: - from tensorflow.keras.optimizers.schedules import CosineDecay -except ImportError: - # In TF 2.4 CosineDecay was still experimental - from tensorflow.keras.experimental import CosineDecay - -import numpy as np -from mlia.nn.rewrite.core.utils.numpy_tfrecord import ( - NumpyTFReader, - NumpyTFWriter, - TFLiteModel, - numpytf_count, -) -from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel -from mlia.nn.rewrite.core.graph_edit.record import record_model -from mlia.nn.rewrite.core.utils.utils import load, save from mlia.nn.rewrite.core.extract import extract -from mlia.nn.rewrite.core.graph_edit.join import join_models from mlia.nn.rewrite.core.graph_edit.diff import diff_stats +from mlia.nn.rewrite.core.graph_edit.join import join_models +from mlia.nn.rewrite.core.graph_edit.record import record_model +from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count +from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read +from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel +from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel +from mlia.nn.rewrite.core.utils.utils import load +from mlia.nn.rewrite.core.utils.utils import save +from mlia.utils.logging import log_action +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +logger = logging.getLogger(__name__) augmentation_presets = { "none": (None, None), @@ -40,31 +46,34 @@ augmentation_presets = { "mix_gaussian_small": (1.6, 0.3), } -learning_rate_schedules = {"cosine", "late", "constant"} +LearningRateSchedule = Literal["cosine", "late", "constant"] +LEARNING_RATE_SCHEDULES = get_args(LearningRateSchedule) def train( - source_model, - unmodified_model, - output_model, - input_tfrec, - replace_fn, - input_tensors, - output_tensors, - augment, - steps, - lr, - batch_size, - verbose, - show_progress, - learning_rate_schedule="cosine", - checkpoint_at=None, - checkpoint_decay_steps=0, - num_procs=1, - num_threads=0, -): + source_model: str, + unmodified_model: Any, + output_model: str, + input_tfrec: str, + replace_fn: Callable, + input_tensors: list, + output_tensors: list, + augment: tuple[float | None, float | None], + steps: int, + learning_rate: float, + batch_size: int, + verbose: bool, + show_progress: bool, + learning_rate_schedule: LearningRateSchedule = "cosine", + checkpoint_at: list | None = None, + num_procs: int = 1, + num_threads: int = 0, +) -> Any: + """Extract and train a model, and return the results.""" if unmodified_model: - unmodified_model_dir = tempfile.TemporaryDirectory() + unmodified_model_dir = ( + tempfile.TemporaryDirectory() # pylint: disable=consider-using-with + ) unmodified_model_dir_path = unmodified_model_dir.name extract( unmodified_model_dir_path, @@ -79,8 +88,6 @@ def train( results = [] with 
tempfile.TemporaryDirectory() as train_dir: - p = lambda file: os.path.join(train_dir, file) - extract( train_dir, source_model, @@ -94,14 +101,13 @@ def train( tflite_filenames = train_in_dir( train_dir, unmodified_model_dir_path, - p("new.tflite"), + Path(train_dir, "new.tflite"), replace_fn, augment, steps, - lr, + learning_rate, batch_size, checkpoint_at=checkpoint_at, - checkpoint_decay_steps=checkpoint_decay_steps, verbose=verbose, show_progress=show_progress, num_procs=num_procs, @@ -114,7 +120,8 @@ def train( if output_model: if i + 1 < len(tflite_filenames): - # Append the same _@STEPS.tflite postfix used by intermediate checkpoints for all but the last output + # Append the same _@STEPS.tflite postfix used by intermediate + # checkpoints for all but the last output postfix = filename.split("_@")[-1] output_filename = output_model.split(".tflite")[0] + postfix else: @@ -122,115 +129,130 @@ def train( join_in_dir(train_dir, filename, output_filename) if unmodified_model_dir: - unmodified_model_dir.cleanup() + cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup() return ( results if checkpoint_at else results[0] ) # only return a list if multiple checkpoints are asked for -def eval_in_dir(dir, new_part, num_procs=1, num_threads=0): - p = lambda file: os.path.join(dir, file) - input = ( - p("input_orig.tfrec") - if os.path.exists(p("input_orig.tfrec")) - else p("input.tfrec") +def eval_in_dir( + target_dir: str, new_part: str, num_procs: int = 1, num_threads: int = 0 +) -> tuple: + """Evaluate a model in a given directory.""" + model_input_path = Path(target_dir, "input_orig.tfrec") + model_output_path = Path(target_dir, "output_orig.tfrec") + model_input = ( + model_input_path + if model_input_path.exists() + else Path(target_dir, "input.tfrec") ) output = ( - p("output_orig.tfrec") - if os.path.exists(p("output_orig.tfrec")) - else p("output.tfrec") + model_output_path + if model_output_path.exists() + else Path(target_dir, "output.tfrec") ) with tempfile.TemporaryDirectory() as tmp_dir: - predict = os.path.join(tmp_dir, "predict.tfrec") + predict = Path(tmp_dir, "predict.tfrec") record_model( - input, new_part, predict, num_procs=num_procs, num_threads=num_threads + str(model_input), + new_part, + str(predict), + num_procs=num_procs, + num_threads=num_threads, ) - mae, nrmse = diff_stats(output, predict) + mae, nrmse = diff_stats(str(output), str(predict)) return mae, nrmse -def join_in_dir(dir, new_part, output_model): +def join_in_dir(model_dir: str, new_part: str, output_model: str) -> None: + """Join two models in a given directory.""" with tempfile.TemporaryDirectory() as tmp_dir: - d = lambda file: os.path.join(dir, file) - new_end = os.path.join(tmp_dir, "new_end.tflite") - join_models(new_part, d("end.tflite"), new_end) - join_models(d("start.tflite"), new_end, output_model) + new_end = Path(tmp_dir, "new_end.tflite") + join_models(new_part, Path(model_dir, "end.tflite"), new_end) + join_models(Path(model_dir, "start.tflite"), new_end, output_model) def train_in_dir( - train_dir, - baseline_dir, - output_filename, - replace_fn, - augmentations, - steps, - lr=1e-3, - batch_size=32, - checkpoint_at=None, - checkpoint_decay_steps=0, - schedule="cosine", - verbose=False, - show_progress=False, - num_procs=None, - num_threads=1, -): - """Train a replacement for replace.tflite using the input.tfrec and output.tfrec in train_dir. - If baseline_dir is provided, train the replacement to match baseline outputs for train_dir inputs. 
-    Result saved as new.tflite in train_dir.
+    train_dir: str,
+    baseline_dir: Any,
+    output_filename: Path,
+    replace_fn: Callable,
+    augmentations: tuple[float | None, float | None],
+    steps: int,
+    learning_rate: float = 1e-3,
+    batch_size: int = 32,
+    checkpoint_at: list | None = None,
+    schedule: str = "cosine",
+    verbose: bool = False,
+    show_progress: bool = False,
+    num_procs: int = 0,
+    num_threads: int = 1,
+) -> list:
+    """Train a replacement for replace.tflite using the input.tfrec \
+    and output.tfrec in train_dir.
+
+    If baseline_dir is provided, train the replacement to match baseline
+    outputs for train_dir inputs. Result saved as new.tflite in train_dir.
     """
     teacher_dir = baseline_dir if baseline_dir else train_dir
     teacher = ParallelTFLiteModel(
-        "%s/replace.tflite" % teacher_dir, num_procs, num_threads, batch_size=batch_size
-    )
-    replace = TFLiteModel("%s/replace.tflite" % train_dir)
-    assert len(teacher.input_tensors()) == 1, (
-        "Can only train replacements with a single input tensor right now, found %s"
-        % teacher.input_tensors()
-    )
-    assert len(teacher.output_tensors()) == 1, (
-        "Can only train replacements with a single output tensor right now, found %s"
-        % teacher.output_tensors()
+        f"{teacher_dir}/replace.tflite", num_procs, num_threads, batch_size=batch_size
     )
+    replace = TFLiteModel(f"{train_dir}/replace.tflite")
+    assert (
+        len(teacher.input_tensors()) == 1
+    ), f"Can only train replacements with a single input tensor right now, \
+        found {teacher.input_tensors()}"
+
+    assert (
+        len(teacher.output_tensors()) == 1
+    ), f"Can only train replacements with a single output tensor right now, \
+        found {teacher.output_tensors()}"
+
     input_name = teacher.input_tensors()[0]
     output_name = teacher.output_tensors()[0]
 
     assert len(teacher.shape_from_name) == len(
         replace.shape_from_name
-    ), "Baseline and train models must have the same number of inputs and outputs. Teacher: {}\nTrain dir: {}".format(
-        teacher.shape_from_name, replace.shape_from_name
-    )
+    ), f"Baseline and train models must have the same number of inputs and outputs. \
+        Teacher: {teacher.shape_from_name}\nTrain dir: {replace.shape_from_name}"
+
     assert all(
         tn == rn and (ts[1:] == rs[1:]).all()
         for (tn, ts), (rn, rs) in zip(
             teacher.shape_from_name.items(), replace.shape_from_name.items()
         )
-    ), "Baseline and train models must have the same input and output shapes for the subgraph being replaced. Teacher: {}\nTrain dir: {}".format(
-        teacher.shape_from_name, replace.shape_from_name
-    )
+    ), f"Baseline and train models must have the same input and output shapes for the \
+        subgraph being replaced.
Teacher: {teacher.shape_from_name}\n \ + Train dir: {replace.shape_from_name}" - input_filename = os.path.join(train_dir, "input.tfrec") - total = numpytf_count(input_filename) - dict_inputs = NumpyTFReader(input_filename) + input_filename = Path(train_dir, "input.tfrec") + total = numpytf_count(str(input_filename)) + dict_inputs = numpytf_read(str(input_filename)) inputs = dict_inputs.map(lambda d: tf.squeeze(d[input_name], axis=0)) if any(augmentations): - # Map the teacher inputs here because the augmentation stage passes these through a TFLite model to get the outputs - teacher_outputs = NumpyTFReader(os.path.join(teacher_dir, "input.tfrec")).map( + # Map the teacher inputs here because the augmentation stage passes these + # through a TFLite model to get the outputs + teacher_outputs = numpytf_read(str(Path(teacher_dir, "input.tfrec"))).map( lambda d: tf.squeeze(d[input_name], axis=0) ) else: - teacher_outputs = NumpyTFReader(os.path.join(teacher_dir, "output.tfrec")).map( + teacher_outputs = numpytf_read(str(Path(teacher_dir, "output.tfrec"))).map( lambda d: tf.squeeze(d[output_name], axis=0) ) steps_per_epoch = math.ceil(total / batch_size) epochs = int(math.ceil(steps / steps_per_epoch)) if verbose: - print( - "Training on %d items for %d steps (%d epochs with batch size %d)" - % (total, epochs * steps_per_epoch, epochs, batch_size) + logger.info( + "Training on %d items for %d steps (%d epochs with batch size %d)", + total, + epochs * steps_per_epoch, + epochs, + batch_size, ) dataset = tf.data.Dataset.zip((inputs, teacher_outputs)) @@ -240,13 +262,21 @@ def train_in_dir( if any(augmentations): augment_train, augment_teacher = augment_fn_twins(dict_inputs, augmentations) - augment_fn = lambda train, teach: ( - augment_train({input_name: train})[input_name], - teacher(augment_teacher({input_name: teach}))[output_name], - ) + + def get_augment_results( + train: Any, teach: Any # pylint: disable=redefined-outer-name + ) -> tuple: + """Return results of train and teach based on augmentations.""" + return ( + augment_train({input_name: train})[input_name], + teacher(augment_teacher({input_name: teach}))[output_name], + ) + dataset = dataset.map( - lambda train, teach: tf.py_function( - augment_fn, inp=[train, teach], Tout=[tf.float32, tf.float32] + lambda augment_train, augment_teach: tf.py_function( + get_augment_results, + inp=[augment_train, augment_teach], + Tout=[tf.float32, tf.float32], ) ) @@ -256,7 +286,7 @@ def train_in_dir( output_shape = teacher.shape_from_name[output_name][1:] model = replace_fn(input_shape, output_shape) - optimizer = tf.keras.optimizers.Nadam(learning_rate=lr) + optimizer = tf.keras.optimizers.Nadam(learning_rate=learning_rate) loss_fn = tf.keras.losses.MeanSquaredError() model.compile(optimizer=optimizer, loss=loss_fn) @@ -265,20 +295,26 @@ def train_in_dir( steps_so_far = 0 - def cosine_decay(epoch_step, logs): - """Cosine decay from lr at start of the run to zero at the end""" + def cosine_decay( + epoch_step: int, logs: Any # pylint: disable=unused-argument + ) -> None: + """Cosine decay from learning rate at start of the run to zero at the end.""" current_step = epoch_step + steps_so_far - learning_rate = lr * (math.cos(math.pi * current_step / steps) + 1) / 2.0 - tf.keras.backend.set_value(optimizer.learning_rate, learning_rate) + cd_learning_rate = ( + learning_rate * (math.cos(math.pi * current_step / steps) + 1) / 2.0 + ) + tf.keras.backend.set_value(optimizer.learning_rate, cd_learning_rate) - def late_decay(epoch_step, logs): - """Constant 
until the last 20% of the run, then linear decay to zero"""
+    def late_decay(
+        epoch_step: int, logs: Any  # pylint: disable=unused-argument
+    ) -> None:
+        """Constant until the last 20% of the run, then linear decay to zero."""
         current_step = epoch_step + steps_so_far
         steps_remaining = steps - current_step
         decay_length = steps // 5
         decay_fraction = min(steps_remaining, decay_length) / decay_length
-        learning_rate = lr * decay_fraction
-        tf.keras.backend.set_value(optimizer.learning_rate, learning_rate)
+        ld_learning_rate = learning_rate * decay_fraction
+        tf.keras.backend.set_value(optimizer.learning_rate, ld_learning_rate)
 
     if schedule == "cosine":
         callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=cosine_decay)]
@@ -287,9 +323,10 @@ def train_in_dir(
     elif schedule == "constant":
         callbacks = []
     else:
-        assert schedule not in learning_rate_schedules
+        assert schedule not in LEARNING_RATE_SCHEDULES
         raise ValueError(
-            f'LR schedule "{schedule}" not implemented - expected one of {learning_rate_schedules}.'
+            f'Learning rate schedule "{schedule}" not implemented - '
+            f"expected one of {LEARNING_RATE_SCHEDULES}."
         )
 
     output_filenames = []
@@ -305,53 +342,66 @@ def train_in_dir(
             verbose=show_progress,
         )
         steps_so_far += steps_to_train
-        print(
-            "lr decayed from %f to %f over %d steps"
-            % (lr_start, optimizer.learning_rate.numpy(), steps_to_train)
+        logger.info(
+            "lr decayed from %f to %f over %d steps",
+            lr_start,
+            optimizer.learning_rate.numpy(),
+            steps_to_train,
         )
 
         if steps_so_far < steps:
-            filename, ext = os.path.splitext(output_filename)
-            checkpoint_filename = filename + ("_@%d" % steps_so_far) + ext
+            filename, ext = os.path.splitext(str(output_filename))
+            checkpoint_filename = f"{filename}_@{steps_so_far}{ext}"
         else:
-            checkpoint_filename = output_filename
-        print("%d/%d: Saved as %s" % (steps_so_far, steps, checkpoint_filename))
-        save_as_tflite(
-            model,
-            checkpoint_filename,
-            input_name,
-            replace.shape_from_name[input_name],
-            output_name,
-            replace.shape_from_name[output_name],
-        )
-        output_filenames.append(checkpoint_filename)
+            checkpoint_filename = str(output_filename)
+        with log_action(f"{steps_so_far}/{steps}: Saved as {checkpoint_filename}"):
+            save_as_tflite(
+                model,
+                checkpoint_filename,
+                input_name,
+                replace.shape_from_name[input_name],
+                output_name,
+                replace.shape_from_name[output_name],
+            )
+            output_filenames.append(checkpoint_filename)
 
     teacher.close()
     return output_filenames
 
 
 def save_as_tflite(
-    keras_model, filename, input_name, input_shape, output_name, output_shape
-):
+    keras_model: tf.keras.Model,
+    filename: str,
+    input_name: str,
+    input_shape: list,
+    output_name: str,
+    output_shape: list,
+) -> None:
+    """Save Keras model as TFLite file."""
     converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
     tflite_model = converter.convert()
 
-    with open(filename, "wb") as f:
-        f.write(tflite_model)
+    with open(filename, "wb") as file:
+        file.write(tflite_model)
 
     # Now fix the shapes and names to match those we expect
flatbuffer = load(filename) + i = flatbuffer.subgraphs[0].inputs[0] + flatbuffer.subgraphs[0].tensors[i].shape = np.array(input_shape, dtype=np.int32) + flatbuffer.subgraphs[0].tensors[i].name = input_name.encode("utf-8") + output = flatbuffer.subgraphs[0].outputs[0] + flatbuffer.subgraphs[0].tensors[output].shape = np.array( + output_shape, dtype=np.int32 + ) + flatbuffer.subgraphs[0].tensors[output].name = output_name.encode("utf-8") + save(flatbuffer, filename) + + +def augment_fn_twins( + inputs: dict, augmentations: tuple[float | None, float | None] +) -> Any: + """Return a pair of twinned augmentation functions with the same sequence \ + of random numbers.""" seed = np.random.randint(2**32 - 1) rng1 = np.random.default_rng(seed) rng2 = np.random.default_rng(seed) @@ -360,52 +410,67 @@ def augment_fn_twins(inputs, augmentations): ) -def augment_fn(inputs, augmentations, rng): +def augment_fn( + inputs: Any, augmentations: tuple[float | None, float | None], rng: Generator +) -> Any: + """Augmentation module.""" mixup_strength, gaussian_strength = augmentations augments = [] if mixup_strength: mixup_range = (0.5 - mixup_strength / 2, 0.5 + mixup_strength / 2) - augment = lambda d: { - k: mixup(rng, v.numpy(), mixup_range) for k, v in d.items() - } - augments.append(augment) + + def mixup_augment(augment_dict: dict) -> dict: + return { + k: mixup(rng, v.numpy(), mixup_range) for k, v in augment_dict.items() + } + + augments.append(mixup_augment) if gaussian_strength: values = defaultdict(list) - for d in inputs.as_numpy_iterator(): - for k, v in d.items(): - values[k].append(v) + for numpy_dict in inputs.as_numpy_iterator(): + for key, value in numpy_dict.items(): + values[key].append(value) noise_scale = { k: np.std(v, axis=0).astype(np.float32) for k, v in values.items() } - augment = lambda d: { - k: v - + rng.standard_normal(v.shape).astype(np.float32) - * gaussian_strength - * noise_scale[k] - for k, v in d.items() - } - augments.append(augment) - if len(augments) == 0: + def gaussian_strength_augment(augment_dict: dict) -> dict: + return { + k: v + + rng.standard_normal(v.shape).astype(np.float32) + * gaussian_strength + * noise_scale[k] + for k, v in augment_dict.items() + } + + augments.append(gaussian_strength_augment) + + if len(augments) == 0: # pylint: disable=no-else-return return lambda x: x elif len(augments) == 1: return augments[0] elif len(augments) == 2: return lambda x: augments[1](augments[0](x)) else: - assert False, "Unexpected number of augmentation functions (%d)" % len(augments) - - -def mixup(rng, batch, beta_range=(0.0, 1.0)): - """Each tensor in the batch becomes a linear combination of it and one other tensor""" - a = batch - b = np.array(batch) - rng.shuffle(b) # randomly pair up tensors in the batch + assert ( + False + ), f"Unexpected number of augmentation \ + functions ({len(augments)})" + + +def mixup(rng: Generator, batch: Any, beta_range: tuple = (0.0, 1.0)) -> Any: + """Each tensor in the batch becomes a linear combination of it \ + and one other tensor.""" + batch_a = batch + batch_b = np.array(batch) + rng.shuffle(batch_b) # randomly pair up tensors in the batch # random mixing coefficient for each pair beta = rng.uniform( low=beta_range[0], high=beta_range[1], size=batch.shape[0] ).astype(np.float32) - return (a.T * beta).T + (b.T * (1.0 - beta)).T # return linear combinations + return (batch_a.T * beta).T + ( + batch_b.T * (1.0 - beta) + ).T # return linear combinations diff --git a/src/mlia/nn/rewrite/core/utils/__init__.py 
b/src/mlia/nn/rewrite/core/utils/__init__.py index 8c1f750..f0b5026 100644 --- a/src/mlia/nn/rewrite/core/utils/__init__.py +++ b/src/mlia/nn/rewrite/core/utils/__init__.py @@ -1,2 +1,3 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 +"""Rewrite core utils module.""" diff --git a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py index 2141003..9229810 100644 --- a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py +++ b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py @@ -1,26 +1,32 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 +"""Numpy TFRecord utils.""" +from __future__ import annotations + import json import os import random import tempfile from collections import defaultdict +from pathlib import Path +from typing import Any +from typing import Callable import numpy as np +import tensorflow as tf +from tensorflow.lite.python import interpreter as interpreter_wrapper from mlia.nn.rewrite.core.utils.utils import load from mlia.nn.rewrite.core.utils.utils import save os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" -import tensorflow as tf - tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) -from tensorflow.lite.python import interpreter as interpreter_wrapper +def make_decode_fn(filename: str) -> Callable: + """Make decode filename.""" -def make_decode_fn(filename): - def decode_fn(record_bytes, type_map): + def decode_fn(record_bytes: Any, type_map: dict) -> dict: parse_dict = { name: tf.io.FixedLenFeature([], tf.string) for name in type_map.keys() } @@ -32,38 +38,48 @@ def make_decode_fn(filename): return features meta_filename = filename + ".meta" - with open(meta_filename) as f: - type_map = json.load(f)["type_map"] + with open(meta_filename, encoding="utf-8") as file: + type_map = json.load(file)["type_map"] return lambda record_bytes: decode_fn(record_bytes, type_map) -def NumpyTFReader(filename): - decode_fn = make_decode_fn(filename) - dataset = tf.data.TFRecordDataset(filename) +def numpytf_read(filename: str | Path) -> Any: + """Read TFRecord dataset.""" + decode_fn = make_decode_fn(str(filename)) + dataset = tf.data.TFRecordDataset(str(filename)) return dataset.map(decode_fn) -def numpytf_count(filename): - meta_filename = filename + ".meta" - with open(meta_filename) as f: - return json.load(f)["count"] +def numpytf_count(filename: str | Path) -> Any: + """Return count from TFRecord file.""" + meta_filename = f"{filename}.meta" + with open(meta_filename, encoding="utf-8") as file: + return json.load(file)["count"] class NumpyTFWriter: - def __init__(self, filename): + """Numpy TF serializer.""" + + def __init__(self, filename: str | Path) -> None: + """Initiate a Numpy TF Serializer.""" self.filename = filename - self.meta_filename = filename + ".meta" - self.writer = tf.io.TFRecordWriter(filename) - self.type_map = {} + self.meta_filename = f"{filename}.meta" + self.writer = tf.io.TFRecordWriter(str(filename)) + self.type_map: dict = {} self.count = 0 - def __enter__(self): + def __enter__(self) -> Any: + """Enter instance.""" return self - def __exit__(self, type, value, traceback): + def __exit__( + self, exception_type: Any, exception_value: Any, exception_traceback: Any + ) -> None: + """Close instance.""" self.close() - def write(self, array_dict): + def write(self, array_dict: dict) -> None: + """Write array dict.""" type_map = {n: str(a.dtype.name) for n, a in 
array_dict.items()}
         self.type_map.update(type_map)
         self.count += 1
@@ -77,31 +93,41 @@ class NumpyTFWriter:
         example = tf.train.Example(features=tf.train.Features(feature=feature))
         self.writer.write(example.SerializeToString())
 
-    def close(self):
-        with open(self.meta_filename, "w") as f:
+    def close(self) -> None:
+        """Close NumpyTFWriter."""
+        with open(self.meta_filename, "w", encoding="utf-8") as file:
             meta = {"type_map": self.type_map, "count": self.count}
-            json.dump(meta, f)
+            json.dump(meta, file)
         self.writer.close()
 
 
 class TFLiteModel:
-    def __init__(self, filename, batch_size=None, num_threads=None):
-        if num_threads == 0:
+    """A representation of a TFLite Model."""
+
+    def __init__(
+        self,
+        filename: str,
+        batch_size: int | None = None,
+        num_threads: int | None = None,
+    ) -> None:
+        """Initiate a TFLite Model."""
+        if not num_threads:
             num_threads = None
-        if batch_size == None:
+        if not batch_size:
             self.interpreter = interpreter_wrapper.Interpreter(
                 model_path=filename, num_threads=num_threads
            )
         else:  # if a batch size is specified, modify the TFLite model to use this size
             with tempfile.TemporaryDirectory() as tmp:
-                fb = load(filename)
-                for sg in fb.subgraphs:
-                    for t in list(sg.inputs) + list(sg.outputs):
-                        sg.tensors[t].shape = np.array(
-                            [batch_size] + list(sg.tensors[t].shape[1:]), dtype=np.int32
+                flatbuffer = load(filename)
+                for subgraph in flatbuffer.subgraphs:
+                    for tensor in list(subgraph.inputs) + list(subgraph.outputs):
+                        subgraph.tensors[tensor].shape = np.array(
+                            [batch_size] + list(subgraph.tensors[tensor].shape[1:]),
+                            dtype=np.int32,
                         )
                 tempname = os.path.join(tmp, "rewrite_tmp.tflite")
-                save(fb, tempname)
+                save(flatbuffer, tempname)
                 self.interpreter = interpreter_wrapper.Interpreter(
                     model_path=tempname, num_threads=num_threads
                 )
@@ -122,8 +148,9 @@ class TFLiteModel:
         self.shape_from_name = {d["name"]: d["shape"] for d in details}
         self.batch_size = next(iter(self.shape_from_name.values()))[0]
 
-    def __call__(self, named_input):
-        """Execute the model on one or a batch of named inputs (a dict of name: numpy array)"""
+    def __call__(self, named_input: dict) -> dict:
+        """Execute the model on one or a batch of named inputs \
+        (a dict of name: numpy array)."""
         input_len = next(iter(named_input.values())).shape[0]
         full_steps = input_len // self.batch_size
         remainder = input_len % self.batch_size
 
         named_ys = defaultdict(list)
         for i in range(full_steps):
             for name, x_batch in named_input.items():
-                x = x_batch[i : i + self.batch_size]
-                self.interpreter.set_tensor(self.handle_from_name[name], x)
+                x_tensor = x_batch[
+                    i * self.batch_size : (i + 1) * self.batch_size  # noqa: E203
+                ]
+                self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
             self.interpreter.invoke()
-            for d in self.output_details:
-                named_ys[d["name"]].append(self.interpreter.get_tensor(d["index"]))
+            for output_detail in self.output_details:
+                named_ys[output_detail["name"]].append(
+                    self.interpreter.get_tensor(output_detail["index"])
+                )
         if remainder:
             for name, x_batch in named_input.items():
-                x = np.zeros(self.shape_from_name[name]).astype(x_batch.dtype)
-                x[:remainder] = x_batch[-remainder:]
-                self.interpreter.set_tensor(self.handle_from_name[name], x)
+                x_tensor = np.zeros(  # pylint: disable=invalid-name
+                    self.shape_from_name[name]
+                ).astype(x_batch.dtype)
+                x_tensor[:remainder] = x_batch[-remainder:]
+                self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
             self.interpreter.invoke()
-            for d in self.output_details:
-                named_ys[d["name"]].append(
-
self.interpreter.get_tensor(d["index"])[:remainder] + for output_detail in self.output_details: + named_ys[output_detail["name"]].append( + self.interpreter.get_tensor(output_detail["index"])[:remainder] ) return {k: np.concatenate(v) for k, v in named_ys.items()} - def input_tensors(self): + def input_tensors(self) -> list: + """Return name from input details.""" return [d["name"] for d in self.input_details] - def output_tensors(self): + def output_tensors(self) -> list: + """Return name from output details.""" return [d["name"] for d in self.output_details] -def sample_tfrec(input_file, k, output_file): +def sample_tfrec(input_file: str, k: int, output_file: str) -> None: + """Count, read and write TFRecord input and output data.""" total = numpytf_count(input_file) - next = sorted(random.sample(range(total), k=k), reverse=True) + next_sample = sorted(random.sample(range(total), k=k), reverse=True) - reader = NumpyTFReader(input_file) + reader = numpytf_read(input_file) with NumpyTFWriter(output_file) as writer: for i, data in enumerate(reader): - if i == next[-1]: - next.pop() + if i == next_sample[-1]: + next_sample.pop() writer.write(data) - if not next: + if not next_sample: break diff --git a/src/mlia/nn/rewrite/core/utils/parallel.py b/src/mlia/nn/rewrite/core/utils/parallel.py index b1a2914..d930a1e 100644 --- a/src/mlia/nn/rewrite/core/utils/parallel.py +++ b/src/mlia/nn/rewrite/core/utils/parallel.py @@ -1,28 +1,45 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 +"""Parallelize a TFLiteModel.""" +from __future__ import annotations + +import logging import math import os from collections import defaultdict from multiprocessing import cpu_count from multiprocessing import Pool +from pathlib import Path +from typing import Any import numpy as np - -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" import tensorflow as tf -tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) - from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +logger = logging.getLogger(__name__) + class ParallelTFLiteModel(TFLiteModel): - def __init__(self, filename, num_procs=1, num_threads=0, batch_size=None): - """num_procs: 0 => detect real cores on system - num_threads: 0 => TFLite impl. specific setting, usually 3 - batch_size: None => automatic (num_procs or file-determined) - """ + """A parallel version of a TFLiteModel. + + num_procs: 0 => detect real cores on system + num_threads: 0 => TFLite impl. 
diff --git a/src/mlia/nn/rewrite/core/utils/utils.py b/src/mlia/nn/rewrite/core/utils/utils.py
index d1ed322..ddf0cc2 100644
--- a/src/mlia/nn/rewrite/core/utils/utils.py
+++ b/src/mlia/nn/rewrite/core/utils/utils.py
@@ -1,22 +1,28 @@
 # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
 # SPDX-License-Identifier: Apache-2.0
-import os
+"""Model and file system utilities."""
+from __future__ import annotations
+
+from pathlib import Path
 
 import flatbuffers
-from tensorflow.lite.python import schema_py_generated as schema_fb
+from tensorflow.lite.python.schema_py_generated import Model
+from tensorflow.lite.python.schema_py_generated import ModelT
 
 
-def load(input_tflite_file):
-    if not os.path.exists(input_tflite_file):
-        raise FileNotFoundError("TFLite file not found at %r\n" % input_tflite_file)
+def load(input_tflite_file: str | Path) -> ModelT:
+    """Load a flatbuffer model from file."""
+    if not Path(input_tflite_file).exists():
+        raise FileNotFoundError(f"TFLite file not found at {input_tflite_file}\n")
     with open(input_tflite_file, "rb") as file_handle:
         file_data = bytearray(file_handle.read())
-    model_obj = schema_fb.Model.GetRootAsModel(file_data, 0)
-    model = schema_fb.ModelT.InitFromObj(model_obj)
+    model_obj = Model.GetRootAsModel(file_data, 0)
+    model = ModelT.InitFromObj(model_obj)
     return model
 
 
-def save(model, output_tflite_file):
+def save(model: ModelT, output_tflite_file: str | Path) -> None:
+    """Save a flatbuffer model to a given file."""
     builder = flatbuffers.Builder(1024)  # Initial size of the buffer, which
     # will grow automatically if needed
     model_offset = model.Pack(builder)
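The load()/save() pair above is what the batch-size rewrite in TFLiteModel.__init__ is built on; a round-trip sketch mirroring that code (paths are placeholders):

    import numpy as np

    from mlia.nn.rewrite.core.utils.utils import load
    from mlia.nn.rewrite.core.utils.utils import save

    flatbuffer = load("model.tflite")  # FileNotFoundError if the path is missing
    # ModelT is a mutable object tree, so shapes can be edited in place,
    # e.g. forcing batch size 1 on every subgraph input and output:
    for subgraph in flatbuffer.subgraphs:
        for tensor in list(subgraph.inputs) + list(subgraph.outputs):
            subgraph.tensors[tensor].shape = np.array(
                [1] + list(subgraph.tensors[tensor].shape[1:]), dtype=np.int32
            )
    save(flatbuffer, "model_batch1.tflite")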
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py
index 7a25e47..5e223fa 100644
--- a/src/mlia/nn/select.py
+++ b/src/mlia/nn/select.py
@@ -6,6 +6,8 @@ from __future__ import annotations
 import math
 from pathlib import Path
 from typing import Any
+from typing import cast
+from typing import List
 from typing import NamedTuple
 
 import tensorflow as tf
@@ -129,12 +131,12 @@ def get_optimizer(
         return Clusterer(model, config)
 
     if isinstance(config, RewriteConfiguration):
-        return Rewriter(model, config)  # type: ignore
+        return Rewriter(model, config)
 
-    if isinstance(config, OptimizationSettings) or is_list_of(
-        config, OptimizationSettings
-    ):
-        return _get_optimizer(model, config)  # type: ignore
+    if isinstance(config, OptimizationSettings):
+        return _get_optimizer(model, cast(OptimizationSettings, config))
+    if is_list_of(config, OptimizationSettings):
+        return _get_optimizer(model, cast(List[OptimizationSettings], config))
 
     raise ConfigurationError(f"Unknown optimization configuration {config}")
 
 
@@ -186,7 +188,7 @@ def _get_optimizer_configuration(
 
     if opt_type == "rewrite":
         if isinstance(optimization_target, str):
-            return RewriteConfiguration(  # type: ignore
+            return RewriteConfiguration(
                 str(optimization_target), layers_to_optimize, dataset
             )
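The select.py change swaps blanket "# type: ignore" comments for explicit typing.cast calls: the isinstance()/is_list_of() checks already guarantee the narrowing at runtime, and cast() records that fact for mypy instead of silencing the whole line. The same pattern in isolation (a generic sketch, not code from this patch):

    from typing import cast
    from typing import List


    def first_setting(config: object) -> int:
        # The runtime check proves the type; cast() merely tells the type
        # checker what was proven, unlike "# type: ignore" which hides errors.
        if isinstance(config, list) and all(isinstance(c, int) for c in config):
            return cast(List[int], config)[0]
        raise TypeError(f"Unknown optimization configuration {config}")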
diff --git a/tests/test_nn_rewrite_core_graph_edit_cut.py b/tests/test_nn_rewrite_core_graph_edit_cut.py
index 914fdfd..7d267ed 100644
--- a/tests/test_nn_rewrite_core_graph_edit_cut.py
+++ b/tests/test_nn_rewrite_core_graph_edit_cut.py
@@ -13,11 +13,11 @@ def test_cut_model(test_tflite_model: Path, tmp_path: Path) -> None:
     """Test the function cut_model()."""
     output_file = tmp_path / "out.tflite"
     cut_model(
-        model_file=test_tflite_model,
+        model_file=str(test_tflite_model),
         input_names=["serving_default_input:0"],
         output_names=["sequential/flatten/Reshape"],
         subgraph_index=0,
-        output_file=output_file,
+        output_file=str(output_file),
     )
 
     assert output_file.is_file()
diff --git a/tests/test_nn_rewrite_core_graph_edit_record.py b/tests/test_nn_rewrite_core_graph_edit_record.py
index 39aeef5..cd728af 100644
--- a/tests/test_nn_rewrite_core_graph_edit_record.py
+++ b/tests/test_nn_rewrite_core_graph_edit_record.py
@@ -7,7 +7,7 @@ import pytest
 import tensorflow as tf
 
 from mlia.nn.rewrite.core.graph_edit.record import record_model
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFReader
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
 
 
 @pytest.mark.parametrize("batch_size", (None, 1, 2))
@@ -46,7 +46,7 @@ def test_record_model(
     # any of the model outputs
     interpreter = tf.lite.Interpreter(str(test_tflite_model))
     model_outputs = interpreter.get_output_details()
-    dataset = NumpyTFReader(str(output_file))
+    dataset = numpytf_read(str(output_file))
     for data in dataset:
         for name, tensor in data.items():
             assert data_matches_outputs(name, tensor, model_outputs)
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index d2bc1e0..3c2ef3e 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -6,17 +6,21 @@ from __future__ import annotations
 
 from pathlib import Path
 from tempfile import TemporaryDirectory
+from typing import Any
 
 import numpy as np
 import pytest
 import tensorflow as tf
 
 from mlia.nn.rewrite.core.train import augmentation_presets
+from mlia.nn.rewrite.core.train import LearningRateSchedule
 from mlia.nn.rewrite.core.train import mixup
 from mlia.nn.rewrite.core.train import train
 
 
-def replace_fully_connected_with_conv(input_shape, output_shape) -> tf.keras.Model:
+def replace_fully_connected_with_conv(
+    input_shape: Any, output_shape: Any
+) -> tf.keras.Model:
     """Get a replacement model for the fully connected layer."""
     for name, shape in {
         "Input": input_shape,
@@ -43,7 +47,7 @@ def check_train(
     augmentation_preset: tuple[float | None, float | None] = augmentation_presets[
         "none"
     ],
-    lr_schedule: str = "cosine",
+    lr_schedule: LearningRateSchedule = "cosine",
     use_unmodified_model: bool = False,
     num_procs: int = 1,
 ) -> None:
@@ -60,7 +64,7 @@ def check_train(
             output_tensors=["StatefulPartitionedCall:0"],
             augment=augmentation_preset,
             steps=32,
-            lr=1e-3,
+            learning_rate=1e-3,
             batch_size=batch_size,
             verbose=verbose,
             show_progress=show_progress,
@@ -104,7 +108,7 @@ def test_train(
     verbose: bool,
     show_progress: bool,
     augmentation_preset: tuple[float | None, float | None],
-    lr_schedule: str,
+    lr_schedule: LearningRateSchedule,
     use_unmodified_model: bool,
     num_procs: int,
 ) -> None:
@@ -131,7 +135,7 @@ def test_train_invalid_schedule(
     check_train(
         tflite_model=test_tflite_model_fp32,
         tfrecord=test_tfrecord_fp32,
-        lr_schedule="unknown_schedule",
+        lr_schedule="unknown_schedule",  # type: ignore
     )
-- 
cgit v1.2.1
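Callers of train() need two adjustments after this change: the lr keyword is now learning_rate, and lr_schedule is typed as LearningRateSchedule rather than a bare str. LearningRateSchedule is presumably a typing.Literal of the supported schedule names (the exact members live in train.py), which is why the negative test above needs "# type: ignore" to pass an invalid name. A call-site sketch showing only the keywords visible in this patch; all other train() arguments are elided:

    from mlia.nn.rewrite.core.train import augmentation_presets
    from mlia.nn.rewrite.core.train import LearningRateSchedule
    from mlia.nn.rewrite.core.train import train

    schedule: LearningRateSchedule = "cosine"  # checked statically, not a free str
    train(
        # ...model/data/rewrite arguments elided...
        output_tensors=["StatefulPartitionedCall:0"],
        augment=augmentation_presets["none"],
        steps=32,
        learning_rate=1e-3,  # was: lr=1e-3
        batch_size=32,
        verbose=False,
        show_progress=False,
    )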