diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit')
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/__init__.py | 1 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/cut.py | 139 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/diff.py | 102 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/join.py | 111 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/record.py | 51 |
5 files changed, 237 insertions, 167 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/__init__.py b/src/mlia/nn/rewrite/core/graph_edit/__init__.py index 8c1f750..273f4a4 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/__init__.py +++ b/src/mlia/nn/rewrite/core/graph_edit/__init__.py @@ -1,2 +1,3 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 +"""Rewrite core graph edit module.""" diff --git a/src/mlia/nn/rewrite/core/graph_edit/cut.py b/src/mlia/nn/rewrite/core/graph_edit/cut.py index a323b7b..2707eb1 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/cut.py +++ b/src/mlia/nn/rewrite/core/graph_edit/cut.py @@ -1,128 +1,143 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 +"""Cut module.""" import os from collections import defaultdict +from typing import Optional -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" import tensorflow as tf +from tensorflow.lite.python.schema_py_generated import ModelT +from tensorflow.lite.python.schema_py_generated import SubGraphT -tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +from mlia.nn.rewrite.core.utils.utils import load +from mlia.nn.rewrite.core.utils.utils import save -from mlia.nn.rewrite.core.utils.utils import load, save +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) -def cut_subgraph(subgraph, input_tensor_names, output_tensor_names): - """Change the global inputs and outputs of a graph to the provided named tensors""" +def tensors_by_name(subgraph: SubGraphT, names: list) -> list: + """Seek out tensors from a subgraph and return the result.""" + seek = frozenset([name.encode("utf-8") for name in names]) + tensors = [i for i, tensor in enumerate(subgraph.tensors) if tensor.name in seek] + return tensors - def tensors_by_name(names): - seek = frozenset([name.encode("utf-8") for name in names]) - tensors = [ - i for i, tensor in enumerate(subgraph.tensors) if tensor.name in seek - ] - return tensors +def cut_subgraph( + subgraph: SubGraphT, + input_tensor_names: Optional[list], + output_tensor_names: Optional[list], +) -> None: + """Change the global inputs and outputs of a graph to the provided named tensors.""" if input_tensor_names is not None: - subgraph.inputs = tensors_by_name(input_tensor_names) + subgraph.inputs = tensors_by_name(subgraph, input_tensor_names) assert len(subgraph.inputs) == len( input_tensor_names - ), "Expected %d input tensors: %s\nFound: %s" % ( - len(subgraph.inputs), - ", ".join(input_tensor_names), - ", ".join(subgraph.tensors[i].name for i in subgraph.inputs), - ) - + ), f"Expected {len(subgraph.inputs)} input tensors: \ + {', '.join(input_tensor_names)}\nFound: \ + {', '.join(subgraph.tensors[i].name for i in subgraph.inputs)}" if output_tensor_names is not None: - subgraph.outputs = tensors_by_name(output_tensor_names) + subgraph.outputs = tensors_by_name(subgraph, output_tensor_names) assert len(subgraph.outputs) == len( output_tensor_names - ), "Expected %d output tensors: %s\nFound: %s" % ( - len(subgraph.outputs), - ", ".join(output_tensor_names), - ", ".join(subgraph.tensors[i].name for i in subgraph.outputs), - ) + ), f"Expected {len(subgraph.outputs)} output tensors: \ + {', '.join(output_tensor_names)}\nFound: \ + {', '.join(subgraph.tensors[i].name for i in subgraph.outputs)}" -def simplify(model): - """Remove any unused operators, tensors and buffers from a model""" - for s in model.subgraphs: - simplify_subgraph(s) +def simplify(model: ModelT) -> None: + """Remove any unused operators, tensors and buffers from a model.""" + for subgraph in model.subgraphs: + simplify_subgraph(subgraph) - used_buffers = {t.buffer for t in s.tensors for s in model.subgraphs} - used_buffers = used_buffers.union({m.buffer for m in model.metadata}) + used_buffers = { + tensor.buffer for tensor in subgraph.tensors for subgraph in model.subgraphs + } + used_buffers = used_buffers.union({metadata.buffer for metadata in model.metadata}) used_buffers.add( 0 - ) # Buffer zero is always expected to be a zero-sized nullptr buffer by the TFLite runtime + ) # Buffer zero is always expected to be a zero-sized nullptr buffer by the + # TFLite runtime model.buffers, buf_relabel = filter_relabel(model.buffers, used_buffers) - for s in model.subgraphs: - for t in s.tensors: - t.buffer = buf_relabel[t.buffer] + for subgraph in model.subgraphs: + for tensor in subgraph.tensors: + tensor.buffer = buf_relabel[tensor.buffer] - for m in model.metadata: - m.buffer = buf_relabel[m.buffer] + for metadata in model.metadata: + metadata.buffer = buf_relabel[metadata.buffer] -def simplify_subgraph(subgraph): +def simplify_subgraph(subgraph: SubGraphT) -> None: + """Simplify a subgraph given its operators.""" requires = defaultdict(set) - for o, operator in enumerate(subgraph.operators): - for t in operator.outputs: - if not t in subgraph.inputs: - requires[t].add(o) + for output, operator in enumerate(subgraph.operators): + for tensor in operator.outputs: + if tensor not in subgraph.inputs: + requires[tensor].add(output) op_set, ten_set = find_required(subgraph, requires, subgraph.outputs) - subgraph.operators, op_relabel = filter_relabel(subgraph.operators, op_set) + subgraph.operators, _ = filter_relabel(subgraph.operators, op_set) subgraph.tensors, ten_relabel = filter_relabel(subgraph.tensors, ten_set) ten_relabel[-1] = -1 # Some files have ops with -1 input tensors; leave unchanged - for op in subgraph.operators: - op.inputs = [ten_relabel[t] for t in op.inputs] - op.outputs = [ten_relabel[t] for t in op.outputs] + for operator in subgraph.operators: + operator.inputs = [ten_relabel[tensor] for tensor in operator.inputs] + operator.outputs = [ten_relabel[tensor] for tensor in operator.outputs] - subgraph.inputs = [ten_relabel[t] for t in subgraph.inputs] - subgraph.outputs = [ten_relabel[t] for t in subgraph.outputs] + subgraph.inputs = [ten_relabel[tensor] for tensor in subgraph.inputs] + subgraph.outputs = [ten_relabel[tensors] for tensors in subgraph.outputs] -def find_required(subgraph, requires, tensors): - visited_operators = set() +def find_required(subgraph: SubGraphT, requires: dict, tensors: dict) -> tuple: + """Find required operators in a given subgraph.""" + visited_operators: set = set() visited_tensors = set(tensors) stop_tensors = set(subgraph.inputs) - changed = True next_tensors = visited_tensors while next_tensors: loop_tensors = next_tensors next_tensors = set() - for t in loop_tensors: - candidate_operators = set(requires[t]) + for tensor in loop_tensors: + candidate_operators = set(requires[tensor]) new_operators = candidate_operators - visited_operators visited_operators = visited_operators.union(new_operators) - for op in new_operators: - candidate_tensors = set(subgraph.operators[op].inputs) + for operator in new_operators: + candidate_tensors = set(subgraph.operators[operator].inputs) new_tensors = candidate_tensors - (visited_tensors.union(next_tensors)) next_tensors = next_tensors.union(new_tensors) visited_tensors = visited_tensors.union(candidate_tensors) visited_tensors = visited_tensors.union( - subgraph.operators[op].outputs + subgraph.operators[operator].outputs ) # include stub outputs but do not traverse them next_tensors = next_tensors - stop_tensors return visited_operators, visited_tensors -def filter_relabel(src, filter): - relabel = {} - output = [] - for i, x in enumerate(src): - if i in filter: +def filter_relabel(src_subgraph: SubGraphT, relabel_filter: set) -> tuple: + """Relabel tensors in a subgraph based on a filter.""" + relabel: dict = {} + output: list = [] + for i, out in enumerate(src_subgraph): + if i in relabel_filter: relabel[i] = len(output) - output.append(x) + output.append(out) return output, relabel -def cut_model(model_file, input_names, output_names, subgraph_index, output_file): +def cut_model( + model_file: str, + input_names: Optional[list], + output_names: Optional[list], + subgraph_index: int, + output_file: str, +) -> None: + """Cut subgraphs and simplify a given model.""" model = load(model_file) subgraph = model.subgraphs[subgraph_index] cut_subgraph(subgraph, input_names, output_names) diff --git a/src/mlia/nn/rewrite/core/graph_edit/diff.py b/src/mlia/nn/rewrite/core/graph_edit/diff.py index 0829f0a..198e47e 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/diff.py +++ b/src/mlia/nn/rewrite/core/graph_edit/diff.py @@ -1,63 +1,79 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 +"""Diff module: compare subgraph outputs.""" +# pylint: disable=too-many-locals +from __future__ import annotations + import os from collections import defaultdict +from pathlib import Path +from typing import Any -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +import numpy as np import tensorflow as tf +from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) -import numpy as np -from tensorflow.lite.python import interpreter as interpreter_wrapper -from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFReader, NumpyTFWriter +def dict_mean(mean_dict: dict) -> Any: + """Return the mean of values in a given dict.""" + return np.mean(list(mean_dict.values())) -def diff_stats(file1, file2, per_tensor_and_channel=False): - dataset1 = NumpyTFReader(file1) - dataset2 = NumpyTFReader(file2) - totals = defaultdict(dict) +def add_total(name: str, key: str, values: list, totals: dict) -> None: + """Append values to dict totals.""" + if key not in totals[name]: + totals[name][key] = values + else: + totals[name][key] += values + - def add_total(name, key, values): - if not key in totals[name]: - totals[name][key] = values - else: - totals[name][key] += values +def diff_stats( + file1: str | Path, file2: str | Path, per_tensor_and_channel: bool = False +) -> tuple: + """Compare the statistics of outputs between two subgraphs.""" + dataset1 = numpytf_read(file1) + dataset2 = numpytf_read(file2) - # First iterate through dataset1 and calculate per-channel total for each tensor + totals: dict = defaultdict(dict) + + # First iterate through dataset and calculate per-channel total for each tensor count = 0 - for d in dataset1: + for data in dataset1: count += 1 - for k, v in d.items(): - value = v.numpy().astype(np.double) - add_total("dataset1_total", k, value) + for key, val in data.items(): + value = val.numpy().astype(np.double) + add_total("dataset1_total", key, value, totals) # Use this to calculate per-channel mean for each tensor - per_tensor_mean = lambda name: { - k: total / count for k, total in totals[name].items() - } + def per_tensor_mean(name: str) -> dict: + return {k: total / count for k, total in totals[name].items()} + dataset1_mean = per_tensor_mean("dataset1_total") - # Next iterate through both datasets and calculate per-channel total squared error - # between them for each tensor and dataset1 variance for each tensor using the mean from above - for i, (x1, x2) in enumerate(zip(dataset1, dataset2)): - assert x1.keys() == x2.keys(), ( - "At input %d the files have different sets of tensors.\n%s: %s\n%s: %s\n" - % ( - i, - file1, - ", ".join(x1.keys()), - file2, - ", ".join(x2.keys()), - ) + # Next iterate through both datasets and calculate per-channel total squared + # error between them for each tensor and dataset1 variance for each tensor + # using the mean from above + for i, (ds1, ds2) in enumerate(zip(dataset1, dataset2)): + assert ds1.keys() == ds2.keys(), ( + f"At input {i} the files have different sets of tensors.\n" + f"{file1}: {', '.join(ds1.keys())}\n" + f"{file2}: {', '.join(ds2.keys())}\n" ) - for k in x1.keys(): - v1 = x1[k].numpy().astype(np.double) - v2 = x2[k].numpy().astype(np.double) - add_total("ae", k, abs(v1 - v2)) - add_total("se", k, (v1 - v2) ** 2) - add_total("dataset1_variance", k, (v1 - dataset1_mean[k]) ** 2) + for key in ds1.keys(): + tensor1 = ds1[key].numpy().astype(np.double) + tensor2 = ds2[key].numpy().astype(np.double) + add_total("ae", key, abs(tensor1 - tensor2), totals) + add_total("se", key, (tensor1 - tensor2) ** 2, totals) + add_total( + "dataset1_variance", + key, + (tensor1 - dataset1_mean[key]) ** 2, + totals, + ) # Finally average over number of inputs to get the rmse and the dataset1 variance mae = per_tensor_mean("ae") @@ -66,7 +82,8 @@ def diff_stats(file1, file2, per_tensor_and_channel=False): dataset1_var = per_tensor_mean("dataset1_variance") is_nonzero = {k: dataset1_var[k] > 0 for k in dataset1_var} - # Divide by target standard deviation to get the per-channel nrmse for each tensor where possible + # Divide by target standard deviation to get the per-channel nrmse for each + # tensor where possible nrmse = { k: v[is_nonzero[k]] / np.sqrt(dataset1_var[k][is_nonzero[k]]) for k, v in rmse.items() @@ -74,6 +91,5 @@ def diff_stats(file1, file2, per_tensor_and_channel=False): if per_tensor_and_channel: return mae, nrmse - else: - dict_mean = lambda d: np.mean(list(d.values())) - return dict_mean(mae), dict_mean(nrmse) + + return dict_mean(mae), dict_mean(nrmse) diff --git a/src/mlia/nn/rewrite/core/graph_edit/join.py b/src/mlia/nn/rewrite/core/graph_edit/join.py index 758f4cf..14a7347 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/join.py +++ b/src/mlia/nn/rewrite/core/graph_edit/join.py @@ -1,16 +1,31 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 +"""Join module.""" +from __future__ import annotations + import os +from pathlib import Path -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" import tensorflow as tf +from tensorflow.lite.python.schema_py_generated import ModelT +from tensorflow.lite.python.schema_py_generated import OperatorCodeT +from tensorflow.lite.python.schema_py_generated import SubGraphT -tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +from mlia.nn.rewrite.core.utils.utils import load +from mlia.nn.rewrite.core.utils.utils import save -from mlia.nn.rewrite.core.utils.utils import load, save +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) -def join_models(input_src, input_dst, output_file, subgraph_src=0, subgraph_dst=0): +def join_models( + input_src: str | Path, + input_dst: str | Path, + output_file: str | Path, + subgraph_src: SubGraphT = 0, + subgraph_dst: SubGraphT = 0, +) -> None: + """Join two models and save the result into a given model file path.""" src_model = load(input_src) dst_model = load(input_dst) src_subgraph = src_model.subgraphs[subgraph_src] @@ -19,8 +34,13 @@ def join_models(input_src, input_dst, output_file, subgraph_src=0, subgraph_dst= save(dst_model, output_file) -def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph): - """Copy subgraph src into subgraph dst from model, connecting tensors with the same names""" +def join_subgraphs( + src_model: ModelT, + src_subgraph: SubGraphT, + dst_model: ModelT, + dst_subgraph: SubGraphT, +) -> None: + """Join two subgraphs, connecting tensors with the same names.""" # Find inputs that match outputs in the other graph and vice versa dst_to_src = { i: o @@ -36,10 +56,11 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph): if dst_subgraph.tensors[i].name == src_subgraph.tensors[o].name } - assert not (src_to_dst and dst_to_src), ( - "Source and destination subgraphs appear to connect in a loop: %d tensors from src to dst, %d tensors from dst to src" - % (len(src_to_dst), len(dst_to_src)) - ) + assert not ( + src_to_dst and dst_to_src + ), f"Source and destination subgraphs appear to connect in a loop: \ + {len(src_to_dst)} tensors from src to dst, {len(dst_to_src)} \ + tensors from dst to src" # Relabel matched input/output tensors between graphs tensor_relabel = src_to_dst if src_to_dst else dst_to_src @@ -47,36 +68,46 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph): # Remove matched inputs/outputs as these will now become internal tensors if src_to_dst: src_subgraph.outputs = [ - o for o in src_subgraph.outputs if not o in tensor_relabel.keys() + output + for output in src_subgraph.outputs + if output not in tensor_relabel.keys() ] dst_subgraph.inputs = [ - i for i in dst_subgraph.inputs if not i in tensor_relabel.values() + input + for input in dst_subgraph.inputs + if input not in tensor_relabel.values() ] else: src_subgraph.inputs = [ - i for i in src_subgraph.inputs if not i in tensor_relabel.keys() + input for input in src_subgraph.inputs if input not in tensor_relabel.keys() ] dst_subgraph.outputs = [ - o for o in dst_subgraph.outputs if not o in tensor_relabel.values() + output + for output in dst_subgraph.outputs + if output not in tensor_relabel.values() ] buffer_relabel = { - src_subgraph.tensors[i].buffer: dst_subgraph.tensors[o].buffer - for i, o in tensor_relabel.items() + src_subgraph.tensors[input].buffer: dst_subgraph.tensors[output].buffer + for input, output in tensor_relabel.items() } used_tensors = [ - t for i, t in enumerate(src_subgraph.tensors) if not i in tensor_relabel + tensor + for i, tensor in enumerate(src_subgraph.tensors) + if i not in tensor_relabel ] - used_buffer_ids = [t.buffer for t in used_tensors] + used_buffer_ids = [tensor.buffer for tensor in used_tensors] + + def opcode_data(code: OperatorCodeT) -> tuple: + return ( + code.builtinCode, + code.deprecatedBuiltinCode, + code.customCode, + code.version, + ) - opcode_data = lambda c: ( - c.builtinCode, - c.deprecatedBuiltinCode, - c.customCode, - c.version, - ) opcode_relabel = { s: d for s in range(len(src_model.operatorCodes)) @@ -85,7 +116,8 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph): == opcode_data(dst_model.operatorCodes[d]) } - # operator order defines execution schedule so must reflect the inputs/outputs dependencies + # operator order defines execution schedule so must reflect + # the inputs/outputs dependencies if dst_to_src: dst_subgraph.operators += src_subgraph.operators else: @@ -99,17 +131,17 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph): ] = -1 # Some files have ops with -1 input tensors; leave unchanged for i in used_buffer_ids: - if not i in buffer_relabel: + if i not in buffer_relabel: buffer_relabel[i] = len(dst_model.buffers) dst_model.buffers.append(src_model.buffers[i]) - for o in src_subgraph.operators: - o.inputs = [tensor_relabel[t] for t in o.inputs] - o.outputs = [tensor_relabel[t] for t in o.outputs] - o.opcodeIndex = opcode_relabel[o.opcodeIndex] + for operator in src_subgraph.operators: + operator.inputs = [tensor_relabel[tensor] for tensor in operator.inputs] + operator.outputs = [tensor_relabel[tensor] for tensor in operator.outputs] + operator.opcodeIndex = opcode_relabel[operator.opcodeIndex] - for t in used_tensors: - t.buffer = buffer_relabel[t.buffer] + for tensor in used_tensors: + tensor.buffer = buffer_relabel[tensor.buffer] src_subgraph.inputs = [tensor_relabel[t] for t in src_subgraph.inputs] src_subgraph.outputs = [tensor_relabel[t] for t in src_subgraph.outputs] @@ -118,11 +150,12 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph): dst_subgraph.outputs = list(set(src_subgraph.outputs).union(dst_subgraph.outputs)) -def append_relabel(src, dst, map=None): - if map is None: - map = {} - for i, x in enumerate(src): - if not i in map: - map[i] = len(dst) +def append_relabel(src: list, dst: list, operator_map: dict | None = None) -> dict: + """Return a map over relabeled tensors in a subgraph.""" + if not operator_map: + operator_map = {} + for i, x in enumerate(src): # pylint: disable=invalid-name + if i not in operator_map: + operator_map[i] = len(dst) dst.append(x) - return map + return operator_map diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py index ae13313..90f3db8 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/record.py +++ b/src/mlia/nn/rewrite/core/graph_edit/record.py @@ -1,34 +1,39 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 +"""Save subgraph data.""" +# pylint: disable=too-many-locals +from __future__ import annotations + import math import os +from pathlib import Path -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" import tensorflow as tf +from rich.progress import track -tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) - -from tqdm import tqdm -from mlia.nn.rewrite.core.utils.numpy_tfrecord import ( - NumpyTFReader, - NumpyTFWriter, - TFLiteModel, - numpytf_count, -) +from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count +from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read +from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFWriter from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) + def record_model( - input_filename, - model_filename, - output_filename, - batch_size=None, - show_progress=False, - num_procs=1, - num_threads=0, -): - """num_procs: 0 => detect real cores on system - num_threads: 0 => TFLite impl. specific setting, usually 3""" + input_filename: str | Path, + model_filename: str | Path, + output_filename: str | Path, + batch_size: int = 0, + show_progress: bool = False, + num_procs: int = 1, + num_threads: int = 0, +) -> None: + """Model recorder. + + num_procs: 0 => detect real cores on system + num_threads: 0 => TFLite impl. specific setting, usually 3 + """ model = ParallelTFLiteModel(model_filename, num_procs, num_threads, batch_size) if not batch_size: batch_size = ( @@ -36,10 +41,10 @@ def record_model( ) # automatically batch to the minimum effective size if not specified total = numpytf_count(input_filename) - dataset = NumpyTFReader(input_filename) + dataset = numpytf_read(input_filename) if batch_size > 1: - # Collapse batch-size 1 items into batch-size n. I regret using batch-size 1 items in tfrecs now. + # Collapse batch-size 1 items into batch-size n. dataset = dataset.map( lambda d: {k: tf.squeeze(v, axis=0) for k, v in d.items()} ) @@ -48,7 +53,7 @@ def record_model( with NumpyTFWriter(output_filename) as writer: for _, named_x in enumerate( - tqdm(dataset.as_numpy_iterator(), total=total, disable=not show_progress) + track(dataset.as_numpy_iterator(), total=total, disable=not show_progress) ): named_y = model(named_x) if batch_size > 1: |