path: root/src/mlia/nn/rewrite/core/graph_edit
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit')
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/__init__.py    1
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/cut.py       139
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/diff.py      102
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/join.py      111
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/record.py     51
5 files changed, 237 insertions, 167 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/__init__.py b/src/mlia/nn/rewrite/core/graph_edit/__init__.py
index 8c1f750..273f4a4 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/__init__.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/__init__.py
@@ -1,2 +1,3 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Rewrite core graph edit module."""
diff --git a/src/mlia/nn/rewrite/core/graph_edit/cut.py b/src/mlia/nn/rewrite/core/graph_edit/cut.py
index a323b7b..2707eb1 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/cut.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/cut.py
@@ -1,128 +1,143 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Cut module."""
import os
from collections import defaultdict
+from typing import Optional
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
+from tensorflow.lite.python.schema_py_generated import ModelT
+from tensorflow.lite.python.schema_py_generated import SubGraphT
-tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+from mlia.nn.rewrite.core.utils.utils import load
+from mlia.nn.rewrite.core.utils.utils import save
-from mlia.nn.rewrite.core.utils.utils import load, save
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-def cut_subgraph(subgraph, input_tensor_names, output_tensor_names):
- """Change the global inputs and outputs of a graph to the provided named tensors"""
+def tensors_by_name(subgraph: SubGraphT, names: list) -> list:
+ """Seek out tensors from a subgraph and return the result."""
+ seek = frozenset([name.encode("utf-8") for name in names])
+ tensors = [i for i, tensor in enumerate(subgraph.tensors) if tensor.name in seek]
+ return tensors
- def tensors_by_name(names):
- seek = frozenset([name.encode("utf-8") for name in names])
- tensors = [
- i for i, tensor in enumerate(subgraph.tensors) if tensor.name in seek
- ]
- return tensors
+def cut_subgraph(
+ subgraph: SubGraphT,
+ input_tensor_names: Optional[list],
+ output_tensor_names: Optional[list],
+) -> None:
+ """Change the global inputs and outputs of a graph to the provided named tensors."""
if input_tensor_names is not None:
- subgraph.inputs = tensors_by_name(input_tensor_names)
+ subgraph.inputs = tensors_by_name(subgraph, input_tensor_names)
assert len(subgraph.inputs) == len(
input_tensor_names
- ), "Expected %d input tensors: %s\nFound: %s" % (
- len(subgraph.inputs),
- ", ".join(input_tensor_names),
- ", ".join(subgraph.tensors[i].name for i in subgraph.inputs),
- )
-
+ ), f"Expected {len(subgraph.inputs)} input tensors: \
+ {', '.join(input_tensor_names)}\nFound: \
+ {', '.join(subgraph.tensors[i].name for i in subgraph.inputs)}"
if output_tensor_names is not None:
- subgraph.outputs = tensors_by_name(output_tensor_names)
+ subgraph.outputs = tensors_by_name(subgraph, output_tensor_names)
assert len(subgraph.outputs) == len(
output_tensor_names
- ), "Expected %d output tensors: %s\nFound: %s" % (
- len(subgraph.outputs),
- ", ".join(output_tensor_names),
- ", ".join(subgraph.tensors[i].name for i in subgraph.outputs),
- )
+ ), f"Expected {len(subgraph.outputs)} output tensors: \
+ {', '.join(output_tensor_names)}\nFound: \
+ {', '.join(subgraph.tensors[i].name for i in subgraph.outputs)}"
-def simplify(model):
- """Remove any unused operators, tensors and buffers from a model"""
- for s in model.subgraphs:
- simplify_subgraph(s)
+def simplify(model: ModelT) -> None:
+ """Remove any unused operators, tensors and buffers from a model."""
+ for subgraph in model.subgraphs:
+ simplify_subgraph(subgraph)
- used_buffers = {t.buffer for t in s.tensors for s in model.subgraphs}
- used_buffers = used_buffers.union({m.buffer for m in model.metadata})
+ used_buffers = {
+ tensor.buffer for subgraph in model.subgraphs for tensor in subgraph.tensors
+ }
+ used_buffers = used_buffers.union({metadata.buffer for metadata in model.metadata})
used_buffers.add(
0
- ) # Buffer zero is always expected to be a zero-sized nullptr buffer by the TFLite runtime
+ ) # Buffer zero is always expected to be a zero-sized nullptr buffer by the
+ # TFLite runtime
model.buffers, buf_relabel = filter_relabel(model.buffers, used_buffers)
- for s in model.subgraphs:
- for t in s.tensors:
- t.buffer = buf_relabel[t.buffer]
+ for subgraph in model.subgraphs:
+ for tensor in subgraph.tensors:
+ tensor.buffer = buf_relabel[tensor.buffer]
- for m in model.metadata:
- m.buffer = buf_relabel[m.buffer]
+ for metadata in model.metadata:
+ metadata.buffer = buf_relabel[metadata.buffer]
-def simplify_subgraph(subgraph):
+def simplify_subgraph(subgraph: SubGraphT) -> None:
+ """Simplify a subgraph given its operators."""
requires = defaultdict(set)
- for o, operator in enumerate(subgraph.operators):
- for t in operator.outputs:
- if not t in subgraph.inputs:
- requires[t].add(o)
+ for op_index, operator in enumerate(subgraph.operators):
+ for tensor in operator.outputs:
+ if tensor not in subgraph.inputs:
+ requires[tensor].add(op_index)
op_set, ten_set = find_required(subgraph, requires, subgraph.outputs)
- subgraph.operators, op_relabel = filter_relabel(subgraph.operators, op_set)
+ subgraph.operators, _ = filter_relabel(subgraph.operators, op_set)
subgraph.tensors, ten_relabel = filter_relabel(subgraph.tensors, ten_set)
ten_relabel[-1] = -1 # Some files have ops with -1 input tensors; leave unchanged
- for op in subgraph.operators:
- op.inputs = [ten_relabel[t] for t in op.inputs]
- op.outputs = [ten_relabel[t] for t in op.outputs]
+ for operator in subgraph.operators:
+ operator.inputs = [ten_relabel[tensor] for tensor in operator.inputs]
+ operator.outputs = [ten_relabel[tensor] for tensor in operator.outputs]
- subgraph.inputs = [ten_relabel[t] for t in subgraph.inputs]
- subgraph.outputs = [ten_relabel[t] for t in subgraph.outputs]
+ subgraph.inputs = [ten_relabel[tensor] for tensor in subgraph.inputs]
+ subgraph.outputs = [ten_relabel[tensor] for tensor in subgraph.outputs]
-def find_required(subgraph, requires, tensors):
- visited_operators = set()
+def find_required(subgraph: SubGraphT, requires: dict, tensors: list) -> tuple:
+ """Find the operators and tensors required to compute the given tensors."""
+ visited_operators: set = set()
visited_tensors = set(tensors)
stop_tensors = set(subgraph.inputs)
- changed = True
next_tensors = visited_tensors
while next_tensors:
loop_tensors = next_tensors
next_tensors = set()
- for t in loop_tensors:
- candidate_operators = set(requires[t])
+ for tensor in loop_tensors:
+ candidate_operators = set(requires[tensor])
new_operators = candidate_operators - visited_operators
visited_operators = visited_operators.union(new_operators)
- for op in new_operators:
- candidate_tensors = set(subgraph.operators[op].inputs)
+ for operator in new_operators:
+ candidate_tensors = set(subgraph.operators[operator].inputs)
new_tensors = candidate_tensors - (visited_tensors.union(next_tensors))
next_tensors = next_tensors.union(new_tensors)
visited_tensors = visited_tensors.union(candidate_tensors)
visited_tensors = visited_tensors.union(
- subgraph.operators[op].outputs
+ subgraph.operators[operator].outputs
) # include stub outputs but do not traverse them
next_tensors = next_tensors - stop_tensors
return visited_operators, visited_tensors
-def filter_relabel(src, filter):
- relabel = {}
- output = []
- for i, x in enumerate(src):
- if i in filter:
+def filter_relabel(src: list, relabel_filter: set) -> tuple:
+ """Filter a list by index, relabelling kept items with their new indices."""
+ relabel: dict = {}
+ output: list = []
+ for i, out in enumerate(src):
+ if i in relabel_filter:
relabel[i] = len(output)
- output.append(x)
+ output.append(out)
return output, relabel
-def cut_model(model_file, input_names, output_names, subgraph_index, output_file):
+def cut_model(
+ model_file: str,
+ input_names: Optional[list],
+ output_names: Optional[list],
+ subgraph_index: int,
+ output_file: str,
+) -> None:
+ """Cut subgraphs and simplify a given model."""
model = load(model_file)
subgraph = model.subgraphs[subgraph_index]
 cut_subgraph(subgraph, input_names, output_names)
 simplify(model)
 save(model, output_file)
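
The cut logic hinges on the filter/relabel idiom above: keep only the items
whose indices pass the filter, and record an old-index-to-new-index map so
callers can rewrite operator and tensor references afterwards. A minimal,
self-contained sketch of that contract, using plain strings in place of real
TFLite buffer objects (toy data, for illustration only):

    def filter_relabel(src, relabel_filter):
        # Mirror of cut.filter_relabel above: filter src by index and
        # relabel the survivors with their new positions.
        relabel = {}
        output = []
        for i, item in enumerate(src):
            if i in relabel_filter:
                relabel[i] = len(output)
                output.append(item)
        return output, relabel

    buffers = ["buf0", "buf1", "buf2", "buf3"]
    kept, relabel = filter_relabel(buffers, {0, 2, 3})
    assert kept == ["buf0", "buf2", "buf3"]
    assert relabel == {0: 0, 2: 1, 3: 2}  # old index -> new index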
diff --git a/src/mlia/nn/rewrite/core/graph_edit/diff.py b/src/mlia/nn/rewrite/core/graph_edit/diff.py
index 0829f0a..198e47e 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/diff.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/diff.py
@@ -1,63 +1,79 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Diff module: compare subgraph outputs."""
+# pylint: disable=too-many-locals
+from __future__ import annotations
+
import os
from collections import defaultdict
+from pathlib import Path
+from typing import Any
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import numpy as np
import tensorflow as tf
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-import numpy as np
-from tensorflow.lite.python import interpreter as interpreter_wrapper
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFReader, NumpyTFWriter
+def dict_mean(mean_dict: dict) -> Any:
+ """Return the mean of values in a given dict."""
+ return np.mean(list(mean_dict.values()))
-def diff_stats(file1, file2, per_tensor_and_channel=False):
- dataset1 = NumpyTFReader(file1)
- dataset2 = NumpyTFReader(file2)
- totals = defaultdict(dict)
+def add_total(name: str, key: str, values: list, totals: dict) -> None:
+ """Append values to dict totals."""
+ if key not in totals[name]:
+ totals[name][key] = values
+ else:
+ totals[name][key] += values
+
- def add_total(name, key, values):
- if not key in totals[name]:
- totals[name][key] = values
- else:
- totals[name][key] += values
+def diff_stats(
+ file1: str | Path, file2: str | Path, per_tensor_and_channel: bool = False
+) -> tuple:
+ """Compare the statistics of outputs between two subgraphs."""
+ dataset1 = numpytf_read(file1)
+ dataset2 = numpytf_read(file2)
- # First iterate through dataset1 and calculate per-channel total for each tensor
+ totals: dict = defaultdict(dict)
+
+ # First iterate through dataset1 and calculate per-channel total for each tensor
count = 0
- for d in dataset1:
+ for data in dataset1:
count += 1
- for k, v in d.items():
- value = v.numpy().astype(np.double)
- add_total("dataset1_total", k, value)
+ for key, val in data.items():
+ value = val.numpy().astype(np.double)
+ add_total("dataset1_total", key, value, totals)
# Use this to calculate per-channel mean for each tensor
- per_tensor_mean = lambda name: {
- k: total / count for k, total in totals[name].items()
- }
+ def per_tensor_mean(name: str) -> dict:
+ return {k: total / count for k, total in totals[name].items()}
+
dataset1_mean = per_tensor_mean("dataset1_total")
- # Next iterate through both datasets and calculate per-channel total squared error
- # between them for each tensor and dataset1 variance for each tensor using the mean from above
- for i, (x1, x2) in enumerate(zip(dataset1, dataset2)):
- assert x1.keys() == x2.keys(), (
- "At input %d the files have different sets of tensors.\n%s: %s\n%s: %s\n"
- % (
- i,
- file1,
- ", ".join(x1.keys()),
- file2,
- ", ".join(x2.keys()),
- )
+ # Next iterate through both datasets and calculate per-channel total squared
+ # error between them for each tensor and dataset1 variance for each tensor
+ # using the mean from above
+ for i, (ds1, ds2) in enumerate(zip(dataset1, dataset2)):
+ assert ds1.keys() == ds2.keys(), (
+ f"At input {i} the files have different sets of tensors.\n"
+ f"{file1}: {', '.join(ds1.keys())}\n"
+ f"{file2}: {', '.join(ds2.keys())}\n"
)
- for k in x1.keys():
- v1 = x1[k].numpy().astype(np.double)
- v2 = x2[k].numpy().astype(np.double)
- add_total("ae", k, abs(v1 - v2))
- add_total("se", k, (v1 - v2) ** 2)
- add_total("dataset1_variance", k, (v1 - dataset1_mean[k]) ** 2)
+ for key in ds1.keys():
+ tensor1 = ds1[key].numpy().astype(np.double)
+ tensor2 = ds2[key].numpy().astype(np.double)
+ add_total("ae", key, abs(tensor1 - tensor2), totals)
+ add_total("se", key, (tensor1 - tensor2) ** 2, totals)
+ add_total(
+ "dataset1_variance",
+ key,
+ (tensor1 - dataset1_mean[key]) ** 2,
+ totals,
+ )
# Finally average over number of inputs to get the rmse and the dataset1 variance
mae = per_tensor_mean("ae")
@@ -66,7 +82,8 @@ def diff_stats(file1, file2, per_tensor_and_channel=False):
dataset1_var = per_tensor_mean("dataset1_variance")
is_nonzero = {k: dataset1_var[k] > 0 for k in dataset1_var}
- # Divide by target standard deviation to get the per-channel nrmse for each tensor where possible
+ # Divide by target standard deviation to get the per-channel nrmse for each
+ # tensor where possible
nrmse = {
k: v[is_nonzero[k]] / np.sqrt(dataset1_var[k][is_nonzero[k]])
for k, v in rmse.items()
@@ -74,6 +91,5 @@ def diff_stats(file1, file2, per_tensor_and_channel=False):
if per_tensor_and_channel:
return mae, nrmse
- else:
- dict_mean = lambda d: np.mean(list(d.values()))
- return dict_mean(mae), dict_mean(nrmse)
+
+ return dict_mean(mae), dict_mean(nrmse)
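
The metrics in diff.py reduce to per-channel MAE and an NRMSE that is the
RMSE divided by the standard deviation of the reference data, computed only
where that deviation is non-zero. A standalone sketch of the same arithmetic
on synthetic numpy arrays (not the TFRecord pipeline above):

    import numpy as np

    rng = np.random.default_rng(0)
    reference = rng.normal(size=(100, 8))  # plays the role of dataset1
    candidate = reference + rng.normal(scale=0.1, size=(100, 8))

    mae = np.mean(np.abs(reference - candidate), axis=0)
    rmse = np.sqrt(np.mean((reference - candidate) ** 2, axis=0))
    variance = np.mean((reference - reference.mean(axis=0)) ** 2, axis=0)

    # NRMSE is only defined where the reference signal actually varies.
    nonzero = variance > 0
    nrmse = rmse[nonzero] / np.sqrt(variance[nonzero])
    print(float(mae.mean()), float(nrmse.mean()))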
diff --git a/src/mlia/nn/rewrite/core/graph_edit/join.py b/src/mlia/nn/rewrite/core/graph_edit/join.py
index 758f4cf..14a7347 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/join.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/join.py
@@ -1,16 +1,31 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Join module."""
+from __future__ import annotations
+
import os
+from pathlib import Path
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
+from tensorflow.lite.python.schema_py_generated import ModelT
+from tensorflow.lite.python.schema_py_generated import OperatorCodeT
+from tensorflow.lite.python.schema_py_generated import SubGraphT
-tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+from mlia.nn.rewrite.core.utils.utils import load
+from mlia.nn.rewrite.core.utils.utils import save
-from mlia.nn.rewrite.core.utils.utils import load, save
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-def join_models(input_src, input_dst, output_file, subgraph_src=0, subgraph_dst=0):
+def join_models(
+ input_src: str | Path,
+ input_dst: str | Path,
+ output_file: str | Path,
+ subgraph_src: int = 0,
+ subgraph_dst: int = 0,
+) -> None:
+ """Join two models and save the result into a given model file path."""
src_model = load(input_src)
dst_model = load(input_dst)
src_subgraph = src_model.subgraphs[subgraph_src]
@@ -19,8 +34,13 @@ def join_models(input_src, input_dst, output_file, subgraph_src=0, subgraph_dst=
save(dst_model, output_file)
-def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph):
- """Copy subgraph src into subgraph dst from model, connecting tensors with the same names"""
+def join_subgraphs(
+ src_model: ModelT,
+ src_subgraph: SubGraphT,
+ dst_model: ModelT,
+ dst_subgraph: SubGraphT,
+) -> None:
+ """Join two subgraphs, connecting tensors with the same names."""
# Find inputs that match outputs in the other graph and vice versa
dst_to_src = {
i: o
@@ -36,10 +56,11 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph):
if dst_subgraph.tensors[i].name == src_subgraph.tensors[o].name
}
- assert not (src_to_dst and dst_to_src), (
- "Source and destination subgraphs appear to connect in a loop: %d tensors from src to dst, %d tensors from dst to src"
- % (len(src_to_dst), len(dst_to_src))
- )
+ assert not (src_to_dst and dst_to_src), (
+ "Source and destination subgraphs appear to connect in a loop: "
+ f"{len(src_to_dst)} tensors from src to dst, "
+ f"{len(dst_to_src)} tensors from dst to src"
+ )
# Relabel matched input/output tensors between graphs
tensor_relabel = src_to_dst if src_to_dst else dst_to_src
@@ -47,36 +68,46 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph):
# Remove matched inputs/outputs as these will now become internal tensors
if src_to_dst:
src_subgraph.outputs = [
- o for o in src_subgraph.outputs if not o in tensor_relabel.keys()
+ output
+ for output in src_subgraph.outputs
+ if output not in tensor_relabel.keys()
]
dst_subgraph.inputs = [
- i for i in dst_subgraph.inputs if not i in tensor_relabel.values()
+ inp
+ for inp in dst_subgraph.inputs
+ if inp not in tensor_relabel.values()
]
else:
src_subgraph.inputs = [
- i for i in src_subgraph.inputs if not i in tensor_relabel.keys()
+ inp for inp in src_subgraph.inputs if inp not in tensor_relabel.keys()
]
dst_subgraph.outputs = [
- o for o in dst_subgraph.outputs if not o in tensor_relabel.values()
+ output
+ for output in dst_subgraph.outputs
+ if output not in tensor_relabel.values()
]
buffer_relabel = {
- src_subgraph.tensors[i].buffer: dst_subgraph.tensors[o].buffer
- for i, o in tensor_relabel.items()
+ src_subgraph.tensors[inp].buffer: dst_subgraph.tensors[out].buffer
+ for inp, out in tensor_relabel.items()
}
used_tensors = [
- t for i, t in enumerate(src_subgraph.tensors) if not i in tensor_relabel
+ tensor
+ for i, tensor in enumerate(src_subgraph.tensors)
+ if i not in tensor_relabel
]
- used_buffer_ids = [t.buffer for t in used_tensors]
+ used_buffer_ids = [tensor.buffer for tensor in used_tensors]
+
+ def opcode_data(code: OperatorCodeT) -> tuple:
+ return (
+ code.builtinCode,
+ code.deprecatedBuiltinCode,
+ code.customCode,
+ code.version,
+ )
- opcode_data = lambda c: (
- c.builtinCode,
- c.deprecatedBuiltinCode,
- c.customCode,
- c.version,
- )
opcode_relabel = {
s: d
for s in range(len(src_model.operatorCodes))
@@ -85,7 +116,8 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph):
== opcode_data(dst_model.operatorCodes[d])
}
- # operator order defines execution schedule so must reflect the inputs/outputs dependencies
+ # operator order defines execution schedule so must reflect
+ # the inputs/outputs dependencies
if dst_to_src:
dst_subgraph.operators += src_subgraph.operators
else:
@@ -99,17 +131,17 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph):
] = -1 # Some files have ops with -1 input tensors; leave unchanged
for i in used_buffer_ids:
- if not i in buffer_relabel:
+ if i not in buffer_relabel:
buffer_relabel[i] = len(dst_model.buffers)
dst_model.buffers.append(src_model.buffers[i])
- for o in src_subgraph.operators:
- o.inputs = [tensor_relabel[t] for t in o.inputs]
- o.outputs = [tensor_relabel[t] for t in o.outputs]
- o.opcodeIndex = opcode_relabel[o.opcodeIndex]
+ for operator in src_subgraph.operators:
+ operator.inputs = [tensor_relabel[tensor] for tensor in operator.inputs]
+ operator.outputs = [tensor_relabel[tensor] for tensor in operator.outputs]
+ operator.opcodeIndex = opcode_relabel[operator.opcodeIndex]
- for t in used_tensors:
- t.buffer = buffer_relabel[t.buffer]
+ for tensor in used_tensors:
+ tensor.buffer = buffer_relabel[tensor.buffer]
src_subgraph.inputs = [tensor_relabel[t] for t in src_subgraph.inputs]
src_subgraph.outputs = [tensor_relabel[t] for t in src_subgraph.outputs]
@@ -118,11 +150,12 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph):
dst_subgraph.outputs = list(set(src_subgraph.outputs).union(dst_subgraph.outputs))
-def append_relabel(src, dst, map=None):
- if map is None:
- map = {}
- for i, x in enumerate(src):
- if not i in map:
- map[i] = len(dst)
+def append_relabel(src: list, dst: list, operator_map: dict | None = None) -> dict:
+ """Return a map over relabeled tensors in a subgraph."""
+ if not operator_map:
+ operator_map = {}
+ for i, x in enumerate(src): # pylint: disable=invalid-name
+ if i not in operator_map:
+ operator_map[i] = len(dst)
dst.append(x)
- return map
+ return operator_map
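
append_relabel is the counterpart of filter_relabel: instead of dropping
items, it appends the source items onto the destination list and records
where each source index ended up, honouring any mappings supplied up front.
A toy sketch of that behaviour (plain strings, for illustration only):

    def append_relabel(src, dst, operator_map=None):
        # Mirror of join.append_relabel above, with the is-None guard.
        if operator_map is None:
            operator_map = {}
        for i, item in enumerate(src):
            if i not in operator_map:
                operator_map[i] = len(dst)
                dst.append(item)
        return operator_map

    dst = ["d0", "d1"]
    mapping = append_relabel(["s0", "s1"], dst, {0: 1})  # src[0] pre-mapped to dst[1]
    assert dst == ["d0", "d1", "s1"]  # only the unmapped item is appended
    assert mapping == {0: 1, 1: 2}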
diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py
index ae13313..90f3db8 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/record.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/record.py
@@ -1,34 +1,39 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Save subgraph data."""
+# pylint: disable=too-many-locals
+from __future__ import annotations
+
import math
import os
+from pathlib import Path
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
+from rich.progress import track
-tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-
-from tqdm import tqdm
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import (
- NumpyTFReader,
- NumpyTFWriter,
- TFLiteModel,
- numpytf_count,
-)
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFWriter
from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
def record_model(
- input_filename,
- model_filename,
- output_filename,
- batch_size=None,
- show_progress=False,
- num_procs=1,
- num_threads=0,
-):
- """num_procs: 0 => detect real cores on system
- num_threads: 0 => TFLite impl. specific setting, usually 3"""
+ input_filename: str | Path,
+ model_filename: str | Path,
+ output_filename: str | Path,
+ batch_size: int = 0,
+ show_progress: bool = False,
+ num_procs: int = 1,
+ num_threads: int = 0,
+) -> None:
+ """Model recorder.
+
+ num_procs: 0 => detect real cores on system
+ num_threads: 0 => TFLite impl. specific setting, usually 3
+ """
model = ParallelTFLiteModel(model_filename, num_procs, num_threads, batch_size)
if not batch_size:
batch_size = (
@@ -36,10 +41,10 @@ def record_model(
) # automatically batch to the minimum effective size if not specified
total = numpytf_count(input_filename)
- dataset = NumpyTFReader(input_filename)
+ dataset = numpytf_read(input_filename)
if batch_size > 1:
- # Collapse batch-size 1 items into batch-size n. I regret using batch-size 1 items in tfrecs now.
+ # Collapse batch-size 1 items into batch-size n.
dataset = dataset.map(
lambda d: {k: tf.squeeze(v, axis=0) for k, v in d.items()}
)
@@ -48,7 +53,7 @@ def record_model(
with NumpyTFWriter(output_filename) as writer:
for _, named_x in enumerate(
- tqdm(dataset.as_numpy_iterator(), total=total, disable=not show_progress)
+ track(dataset.as_numpy_iterator(), total=total, disable=not show_progress)
):
named_y = model(named_x)
if batch_size > 1:
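
Taken together, record.py and diff.py support a record-then-compare flow: run
a model over a recorded input set, then measure how far its outputs drift
from a reference. A hypothetical usage sketch (the file names are
placeholders, not paths from this change):

    from mlia.nn.rewrite.core.graph_edit.diff import diff_stats
    from mlia.nn.rewrite.core.graph_edit.record import record_model

    record_model(
        input_filename="inputs.tfrec",
        model_filename="model.tflite",
        output_filename="outputs.tfrec",
        show_progress=True,
        num_procs=0,    # 0 => detect real cores on the system
        num_threads=0,  # 0 => TFLite implementation-specific default
    )
    mae, nrmse = diff_stats("reference_outputs.tfrec", "outputs.tfrec")
    print(f"MAE: {mae}, NRMSE: {nrmse}")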