author     Annie Tallund <annie.tallund@arm.com>          2023-03-15 11:27:08 +0100
committer  Benjamin Klimczak <benjamin.klimczak@arm.com>  2023-10-11 15:43:14 +0100
commit     867f37d643e66c0223457c28f5345f2f21db97f2 (patch)
tree       4e3c55896760e24a8b5eadc5176ce7f5586552e1
parent     62768232c5fe4ed6b87136c336b65e13d030e9d4 (diff)
download   mlia-867f37d643e66c0223457c28f5345f2f21db97f2.tar.gz
Adapt rewrite module to MLIA coding standards
- Fix imports
- Update variable names
- Refactor helper functions
- Add licence headers
- Add docstrings
- Use f-strings rather than % notation
- Create type annotations in rewrite module
- Migrate from tqdm to rich progress bar
- Use logging module in rewrite module: all print statements are replaced
  with the logging module

Resolves: MLIA-831, MLIA-842, MLIA-844, MLIA-846

Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Change-Id: Idee37538d72b9f01128a894281a8d10155f7c17c
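The style changes listed above roughly correspond to the pattern below. This is an illustrative sketch only, not code from this patch; the process_records helper and its arguments are hypothetical.

# SPDX-License-Identifier: Apache-2.0
"""Sketch of the conventions applied by this patch (hypothetical helper)."""
from __future__ import annotations

import logging

from rich.progress import track  # rich progress bar replaces tqdm

logger = logging.getLogger(__name__)


def process_records(records: list, show_progress: bool = False) -> int:
    """Count records, following the f-string, logging and rich conventions."""
    # f-string rather than % notation for the assertion message
    assert records, f"Expected at least one record, found {len(records)}"
    count = 0
    # rich.progress.track replaces tqdm.tqdm for the progress bar
    for _ in track(records, total=len(records), disable=not show_progress):
        count += 1
    # logging module replaces bare print() calls
    logger.info("Processed %d records", count)
    return count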
-rw-r--r--  setup.cfg | 1
-rw-r--r--  src/mlia/nn/rewrite/__init__.py | 1
-rw-r--r--  src/mlia/nn/rewrite/core/__init__.py | 1
-rw-r--r--  src/mlia/nn/rewrite/core/extract.py | 35
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/__init__.py | 1
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/cut.py | 139
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/diff.py | 102
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/join.py | 111
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/record.py | 51
-rw-r--r--  src/mlia/nn/rewrite/core/rewrite.py | 3
-rw-r--r--  src/mlia/nn/rewrite/core/train.py | 433
-rw-r--r--  src/mlia/nn/rewrite/core/utils/__init__.py | 1
-rw-r--r--  src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py | 140
-rw-r--r--  src/mlia/nn/rewrite/core/utils/parallel.py | 90
-rw-r--r--  src/mlia/nn/rewrite/core/utils/utils.py | 22
-rw-r--r--  src/mlia/nn/select.py | 14
-rw-r--r--  tests/test_nn_rewrite_core_graph_edit_cut.py | 4
-rw-r--r--  tests/test_nn_rewrite_core_graph_edit_record.py | 4
-rw-r--r--  tests/test_nn_rewrite_core_train.py | 14
19 files changed, 691 insertions, 476 deletions
diff --git a/setup.cfg b/setup.cfg
index 7cdd3c5..5a68b6b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -35,7 +35,6 @@ install_requires =
requests~=2.31.0
rich~=13.5.2
tomli~=2.0.1 ; python_version<"3.11"
- tqdm~=4.65.0
[options.packages.find]
where = src
diff --git a/src/mlia/nn/rewrite/__init__.py b/src/mlia/nn/rewrite/__init__.py
index 8c1f750..74298f6 100644
--- a/src/mlia/nn/rewrite/__init__.py
+++ b/src/mlia/nn/rewrite/__init__.py
@@ -1,2 +1,3 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Rewrite module."""
diff --git a/src/mlia/nn/rewrite/core/__init__.py b/src/mlia/nn/rewrite/core/__init__.py
index 8c1f750..8816dc1 100644
--- a/src/mlia/nn/rewrite/core/__init__.py
+++ b/src/mlia/nn/rewrite/core/__init__.py
@@ -1,2 +1,3 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Rewrite core module."""
diff --git a/src/mlia/nn/rewrite/core/extract.py b/src/mlia/nn/rewrite/core/extract.py
index 5fcd348..f609955 100644
--- a/src/mlia/nn/rewrite/core/extract.py
+++ b/src/mlia/nn/rewrite/core/extract.py
@@ -1,28 +1,33 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Extract module."""
+# pylint: disable=too-many-arguments, too-many-locals
import os
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
-
-tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+from tensorflow.lite.python.schema_py_generated import SubGraphT
from mlia.nn.rewrite.core.graph_edit.cut import cut_model
from mlia.nn.rewrite.core.graph_edit.record import record_model
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+
def extract(
- output_path,
- model_file,
- input_data,
- input_names,
- output_names,
- subgraph=0,
- skip_outputs=False,
- show_progress=False,
- num_procs=1,
- num_threads=0,
-):
+ output_path: str,
+ model_file: str,
+ input_filename: str,
+ input_names: list,
+ output_names: list,
+ subgraph: SubGraphT = 0,
+ skip_outputs: bool = False,
+ show_progress: bool = False,
+ num_procs: int = 1,
+ num_threads: int = 0,
+) -> None:
+ """Extract a model after cut and record."""
try:
os.mkdir(output_path)
except FileExistsError:
@@ -39,7 +44,7 @@ def extract(
input_tfrec = os.path.join(output_path, "input.tfrec")
record_model(
- input_data,
+ input_filename,
start_file,
input_tfrec,
show_progress=show_progress,
diff --git a/src/mlia/nn/rewrite/core/graph_edit/__init__.py b/src/mlia/nn/rewrite/core/graph_edit/__init__.py
index 8c1f750..273f4a4 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/__init__.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/__init__.py
@@ -1,2 +1,3 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Rewrite core graph edit module."""
diff --git a/src/mlia/nn/rewrite/core/graph_edit/cut.py b/src/mlia/nn/rewrite/core/graph_edit/cut.py
index a323b7b..2707eb1 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/cut.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/cut.py
@@ -1,128 +1,143 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Cut module."""
import os
from collections import defaultdict
+from typing import Optional
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
+from tensorflow.lite.python.schema_py_generated import ModelT
+from tensorflow.lite.python.schema_py_generated import SubGraphT
-tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+from mlia.nn.rewrite.core.utils.utils import load
+from mlia.nn.rewrite.core.utils.utils import save
-from mlia.nn.rewrite.core.utils.utils import load, save
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-def cut_subgraph(subgraph, input_tensor_names, output_tensor_names):
- """Change the global inputs and outputs of a graph to the provided named tensors"""
+def tensors_by_name(subgraph: SubGraphT, names: list) -> list:
+ """Seek out tensors from a subgraph and return the result."""
+ seek = frozenset([name.encode("utf-8") for name in names])
+ tensors = [i for i, tensor in enumerate(subgraph.tensors) if tensor.name in seek]
+ return tensors
- def tensors_by_name(names):
- seek = frozenset([name.encode("utf-8") for name in names])
- tensors = [
- i for i, tensor in enumerate(subgraph.tensors) if tensor.name in seek
- ]
- return tensors
+def cut_subgraph(
+ subgraph: SubGraphT,
+ input_tensor_names: Optional[list],
+ output_tensor_names: Optional[list],
+) -> None:
+ """Change the global inputs and outputs of a graph to the provided named tensors."""
if input_tensor_names is not None:
- subgraph.inputs = tensors_by_name(input_tensor_names)
+ subgraph.inputs = tensors_by_name(subgraph, input_tensor_names)
assert len(subgraph.inputs) == len(
input_tensor_names
- ), "Expected %d input tensors: %s\nFound: %s" % (
- len(subgraph.inputs),
- ", ".join(input_tensor_names),
- ", ".join(subgraph.tensors[i].name for i in subgraph.inputs),
- )
-
+ ), f"Expected {len(subgraph.inputs)} input tensors: \
+ {', '.join(input_tensor_names)}\nFound: \
+ {', '.join(subgraph.tensors[i].name for i in subgraph.inputs)}"
if output_tensor_names is not None:
- subgraph.outputs = tensors_by_name(output_tensor_names)
+ subgraph.outputs = tensors_by_name(subgraph, output_tensor_names)
assert len(subgraph.outputs) == len(
output_tensor_names
- ), "Expected %d output tensors: %s\nFound: %s" % (
- len(subgraph.outputs),
- ", ".join(output_tensor_names),
- ", ".join(subgraph.tensors[i].name for i in subgraph.outputs),
- )
+ ), f"Expected {len(subgraph.outputs)} output tensors: \
+ {', '.join(output_tensor_names)}\nFound: \
+ {', '.join(subgraph.tensors[i].name for i in subgraph.outputs)}"
-def simplify(model):
- """Remove any unused operators, tensors and buffers from a model"""
- for s in model.subgraphs:
- simplify_subgraph(s)
+def simplify(model: ModelT) -> None:
+ """Remove any unused operators, tensors and buffers from a model."""
+ for subgraph in model.subgraphs:
+ simplify_subgraph(subgraph)
- used_buffers = {t.buffer for t in s.tensors for s in model.subgraphs}
- used_buffers = used_buffers.union({m.buffer for m in model.metadata})
+ used_buffers = {
+ tensor.buffer for tensor in subgraph.tensors for subgraph in model.subgraphs
+ }
+ used_buffers = used_buffers.union({metadata.buffer for metadata in model.metadata})
used_buffers.add(
0
- ) # Buffer zero is always expected to be a zero-sized nullptr buffer by the TFLite runtime
+ ) # Buffer zero is always expected to be a zero-sized nullptr buffer by the
+ # TFLite runtime
model.buffers, buf_relabel = filter_relabel(model.buffers, used_buffers)
- for s in model.subgraphs:
- for t in s.tensors:
- t.buffer = buf_relabel[t.buffer]
+ for subgraph in model.subgraphs:
+ for tensor in subgraph.tensors:
+ tensor.buffer = buf_relabel[tensor.buffer]
- for m in model.metadata:
- m.buffer = buf_relabel[m.buffer]
+ for metadata in model.metadata:
+ metadata.buffer = buf_relabel[metadata.buffer]
-def simplify_subgraph(subgraph):
+def simplify_subgraph(subgraph: SubGraphT) -> None:
+ """Simplify a subgraph given its operators."""
requires = defaultdict(set)
- for o, operator in enumerate(subgraph.operators):
- for t in operator.outputs:
- if not t in subgraph.inputs:
- requires[t].add(o)
+ for output, operator in enumerate(subgraph.operators):
+ for tensor in operator.outputs:
+ if tensor not in subgraph.inputs:
+ requires[tensor].add(output)
op_set, ten_set = find_required(subgraph, requires, subgraph.outputs)
- subgraph.operators, op_relabel = filter_relabel(subgraph.operators, op_set)
+ subgraph.operators, _ = filter_relabel(subgraph.operators, op_set)
subgraph.tensors, ten_relabel = filter_relabel(subgraph.tensors, ten_set)
ten_relabel[-1] = -1 # Some files have ops with -1 input tensors; leave unchanged
- for op in subgraph.operators:
- op.inputs = [ten_relabel[t] for t in op.inputs]
- op.outputs = [ten_relabel[t] for t in op.outputs]
+ for operator in subgraph.operators:
+ operator.inputs = [ten_relabel[tensor] for tensor in operator.inputs]
+ operator.outputs = [ten_relabel[tensor] for tensor in operator.outputs]
- subgraph.inputs = [ten_relabel[t] for t in subgraph.inputs]
- subgraph.outputs = [ten_relabel[t] for t in subgraph.outputs]
+ subgraph.inputs = [ten_relabel[tensor] for tensor in subgraph.inputs]
+ subgraph.outputs = [ten_relabel[tensors] for tensors in subgraph.outputs]
-def find_required(subgraph, requires, tensors):
- visited_operators = set()
+def find_required(subgraph: SubGraphT, requires: dict, tensors: dict) -> tuple:
+ """Find required operators in a given subgraph."""
+ visited_operators: set = set()
visited_tensors = set(tensors)
stop_tensors = set(subgraph.inputs)
- changed = True
next_tensors = visited_tensors
while next_tensors:
loop_tensors = next_tensors
next_tensors = set()
- for t in loop_tensors:
- candidate_operators = set(requires[t])
+ for tensor in loop_tensors:
+ candidate_operators = set(requires[tensor])
new_operators = candidate_operators - visited_operators
visited_operators = visited_operators.union(new_operators)
- for op in new_operators:
- candidate_tensors = set(subgraph.operators[op].inputs)
+ for operator in new_operators:
+ candidate_tensors = set(subgraph.operators[operator].inputs)
new_tensors = candidate_tensors - (visited_tensors.union(next_tensors))
next_tensors = next_tensors.union(new_tensors)
visited_tensors = visited_tensors.union(candidate_tensors)
visited_tensors = visited_tensors.union(
- subgraph.operators[op].outputs
+ subgraph.operators[operator].outputs
) # include stub outputs but do not traverse them
next_tensors = next_tensors - stop_tensors
return visited_operators, visited_tensors
-def filter_relabel(src, filter):
- relabel = {}
- output = []
- for i, x in enumerate(src):
- if i in filter:
+def filter_relabel(src_subgraph: SubGraphT, relabel_filter: set) -> tuple:
+ """Relabel tensors in a subgraph based on a filter."""
+ relabel: dict = {}
+ output: list = []
+ for i, out in enumerate(src_subgraph):
+ if i in relabel_filter:
relabel[i] = len(output)
- output.append(x)
+ output.append(out)
return output, relabel
-def cut_model(model_file, input_names, output_names, subgraph_index, output_file):
+def cut_model(
+ model_file: str,
+ input_names: Optional[list],
+ output_names: Optional[list],
+ subgraph_index: int,
+ output_file: str,
+) -> None:
+ """Cut subgraphs and simplify a given model."""
model = load(model_file)
subgraph = model.subgraphs[subgraph_index]
cut_subgraph(subgraph, input_names, output_names)
diff --git a/src/mlia/nn/rewrite/core/graph_edit/diff.py b/src/mlia/nn/rewrite/core/graph_edit/diff.py
index 0829f0a..198e47e 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/diff.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/diff.py
@@ -1,63 +1,79 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Diff module: compare subgraph outputs."""
+# pylint: disable=too-many-locals
+from __future__ import annotations
+
import os
from collections import defaultdict
+from pathlib import Path
+from typing import Any
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import numpy as np
import tensorflow as tf
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-import numpy as np
-from tensorflow.lite.python import interpreter as interpreter_wrapper
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFReader, NumpyTFWriter
+def dict_mean(mean_dict: dict) -> Any:
+ """Return the mean of values in a given dict."""
+ return np.mean(list(mean_dict.values()))
-def diff_stats(file1, file2, per_tensor_and_channel=False):
- dataset1 = NumpyTFReader(file1)
- dataset2 = NumpyTFReader(file2)
- totals = defaultdict(dict)
+def add_total(name: str, key: str, values: list, totals: dict) -> None:
+ """Append values to dict totals."""
+ if key not in totals[name]:
+ totals[name][key] = values
+ else:
+ totals[name][key] += values
+
- def add_total(name, key, values):
- if not key in totals[name]:
- totals[name][key] = values
- else:
- totals[name][key] += values
+def diff_stats(
+ file1: str | Path, file2: str | Path, per_tensor_and_channel: bool = False
+) -> tuple:
+ """Compare the statistics of outputs between two subgraphs."""
+ dataset1 = numpytf_read(file1)
+ dataset2 = numpytf_read(file2)
- # First iterate through dataset1 and calculate per-channel total for each tensor
+ totals: dict = defaultdict(dict)
+
+ # First iterate through dataset and calculate per-channel total for each tensor
count = 0
- for d in dataset1:
+ for data in dataset1:
count += 1
- for k, v in d.items():
- value = v.numpy().astype(np.double)
- add_total("dataset1_total", k, value)
+ for key, val in data.items():
+ value = val.numpy().astype(np.double)
+ add_total("dataset1_total", key, value, totals)
# Use this to calculate per-channel mean for each tensor
- per_tensor_mean = lambda name: {
- k: total / count for k, total in totals[name].items()
- }
+ def per_tensor_mean(name: str) -> dict:
+ return {k: total / count for k, total in totals[name].items()}
+
dataset1_mean = per_tensor_mean("dataset1_total")
- # Next iterate through both datasets and calculate per-channel total squared error
- # between them for each tensor and dataset1 variance for each tensor using the mean from above
- for i, (x1, x2) in enumerate(zip(dataset1, dataset2)):
- assert x1.keys() == x2.keys(), (
- "At input %d the files have different sets of tensors.\n%s: %s\n%s: %s\n"
- % (
- i,
- file1,
- ", ".join(x1.keys()),
- file2,
- ", ".join(x2.keys()),
- )
+ # Next iterate through both datasets and calculate per-channel total squared
+ # error between them for each tensor and dataset1 variance for each tensor
+ # using the mean from above
+ for i, (ds1, ds2) in enumerate(zip(dataset1, dataset2)):
+ assert ds1.keys() == ds2.keys(), (
+ f"At input {i} the files have different sets of tensors.\n"
+ f"{file1}: {', '.join(ds1.keys())}\n"
+ f"{file2}: {', '.join(ds2.keys())}\n"
)
- for k in x1.keys():
- v1 = x1[k].numpy().astype(np.double)
- v2 = x2[k].numpy().astype(np.double)
- add_total("ae", k, abs(v1 - v2))
- add_total("se", k, (v1 - v2) ** 2)
- add_total("dataset1_variance", k, (v1 - dataset1_mean[k]) ** 2)
+ for key in ds1.keys():
+ tensor1 = ds1[key].numpy().astype(np.double)
+ tensor2 = ds2[key].numpy().astype(np.double)
+ add_total("ae", key, abs(tensor1 - tensor2), totals)
+ add_total("se", key, (tensor1 - tensor2) ** 2, totals)
+ add_total(
+ "dataset1_variance",
+ key,
+ (tensor1 - dataset1_mean[key]) ** 2,
+ totals,
+ )
# Finally average over number of inputs to get the rmse and the dataset1 variance
mae = per_tensor_mean("ae")
@@ -66,7 +82,8 @@ def diff_stats(file1, file2, per_tensor_and_channel=False):
dataset1_var = per_tensor_mean("dataset1_variance")
is_nonzero = {k: dataset1_var[k] > 0 for k in dataset1_var}
- # Divide by target standard deviation to get the per-channel nrmse for each tensor where possible
+ # Divide by target standard deviation to get the per-channel nrmse for each
+ # tensor where possible
nrmse = {
k: v[is_nonzero[k]] / np.sqrt(dataset1_var[k][is_nonzero[k]])
for k, v in rmse.items()
@@ -74,6 +91,5 @@ def diff_stats(file1, file2, per_tensor_and_channel=False):
if per_tensor_and_channel:
return mae, nrmse
- else:
- dict_mean = lambda d: np.mean(list(d.values()))
- return dict_mean(mae), dict_mean(nrmse)
+
+ return dict_mean(mae), dict_mean(nrmse)
diff --git a/src/mlia/nn/rewrite/core/graph_edit/join.py b/src/mlia/nn/rewrite/core/graph_edit/join.py
index 758f4cf..14a7347 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/join.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/join.py
@@ -1,16 +1,31 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Join module."""
+from __future__ import annotations
+
import os
+from pathlib import Path
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
+from tensorflow.lite.python.schema_py_generated import ModelT
+from tensorflow.lite.python.schema_py_generated import OperatorCodeT
+from tensorflow.lite.python.schema_py_generated import SubGraphT
-tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+from mlia.nn.rewrite.core.utils.utils import load
+from mlia.nn.rewrite.core.utils.utils import save
-from mlia.nn.rewrite.core.utils.utils import load, save
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-def join_models(input_src, input_dst, output_file, subgraph_src=0, subgraph_dst=0):
+def join_models(
+ input_src: str | Path,
+ input_dst: str | Path,
+ output_file: str | Path,
+ subgraph_src: SubGraphT = 0,
+ subgraph_dst: SubGraphT = 0,
+) -> None:
+ """Join two models and save the result into a given model file path."""
src_model = load(input_src)
dst_model = load(input_dst)
src_subgraph = src_model.subgraphs[subgraph_src]
@@ -19,8 +34,13 @@ def join_models(input_src, input_dst, output_file, subgraph_src=0, subgraph_dst=
save(dst_model, output_file)
-def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph):
- """Copy subgraph src into subgraph dst from model, connecting tensors with the same names"""
+def join_subgraphs(
+ src_model: ModelT,
+ src_subgraph: SubGraphT,
+ dst_model: ModelT,
+ dst_subgraph: SubGraphT,
+) -> None:
+ """Join two subgraphs, connecting tensors with the same names."""
# Find inputs that match outputs in the other graph and vice versa
dst_to_src = {
i: o
@@ -36,10 +56,11 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph):
if dst_subgraph.tensors[i].name == src_subgraph.tensors[o].name
}
- assert not (src_to_dst and dst_to_src), (
- "Source and destination subgraphs appear to connect in a loop: %d tensors from src to dst, %d tensors from dst to src"
- % (len(src_to_dst), len(dst_to_src))
- )
+ assert not (
+ src_to_dst and dst_to_src
+ ), f"Source and destination subgraphs appear to connect in a loop: \
+ {len(src_to_dst)} tensors from src to dst, {len(dst_to_src)} \
+ tensors from dst to src"
# Relabel matched input/output tensors between graphs
tensor_relabel = src_to_dst if src_to_dst else dst_to_src
@@ -47,36 +68,46 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph):
# Remove matched inputs/outputs as these will now become internal tensors
if src_to_dst:
src_subgraph.outputs = [
- o for o in src_subgraph.outputs if not o in tensor_relabel.keys()
+ output
+ for output in src_subgraph.outputs
+ if output not in tensor_relabel.keys()
]
dst_subgraph.inputs = [
- i for i in dst_subgraph.inputs if not i in tensor_relabel.values()
+ input
+ for input in dst_subgraph.inputs
+ if input not in tensor_relabel.values()
]
else:
src_subgraph.inputs = [
- i for i in src_subgraph.inputs if not i in tensor_relabel.keys()
+ input for input in src_subgraph.inputs if input not in tensor_relabel.keys()
]
dst_subgraph.outputs = [
- o for o in dst_subgraph.outputs if not o in tensor_relabel.values()
+ output
+ for output in dst_subgraph.outputs
+ if output not in tensor_relabel.values()
]
buffer_relabel = {
- src_subgraph.tensors[i].buffer: dst_subgraph.tensors[o].buffer
- for i, o in tensor_relabel.items()
+ src_subgraph.tensors[input].buffer: dst_subgraph.tensors[output].buffer
+ for input, output in tensor_relabel.items()
}
used_tensors = [
- t for i, t in enumerate(src_subgraph.tensors) if not i in tensor_relabel
+ tensor
+ for i, tensor in enumerate(src_subgraph.tensors)
+ if i not in tensor_relabel
]
- used_buffer_ids = [t.buffer for t in used_tensors]
+ used_buffer_ids = [tensor.buffer for tensor in used_tensors]
+
+ def opcode_data(code: OperatorCodeT) -> tuple:
+ return (
+ code.builtinCode,
+ code.deprecatedBuiltinCode,
+ code.customCode,
+ code.version,
+ )
- opcode_data = lambda c: (
- c.builtinCode,
- c.deprecatedBuiltinCode,
- c.customCode,
- c.version,
- )
opcode_relabel = {
s: d
for s in range(len(src_model.operatorCodes))
@@ -85,7 +116,8 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph):
== opcode_data(dst_model.operatorCodes[d])
}
- # operator order defines execution schedule so must reflect the inputs/outputs dependencies
+ # operator order defines execution schedule so must reflect
+ # the inputs/outputs dependencies
if dst_to_src:
dst_subgraph.operators += src_subgraph.operators
else:
@@ -99,17 +131,17 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph):
] = -1 # Some files have ops with -1 input tensors; leave unchanged
for i in used_buffer_ids:
- if not i in buffer_relabel:
+ if i not in buffer_relabel:
buffer_relabel[i] = len(dst_model.buffers)
dst_model.buffers.append(src_model.buffers[i])
- for o in src_subgraph.operators:
- o.inputs = [tensor_relabel[t] for t in o.inputs]
- o.outputs = [tensor_relabel[t] for t in o.outputs]
- o.opcodeIndex = opcode_relabel[o.opcodeIndex]
+ for operator in src_subgraph.operators:
+ operator.inputs = [tensor_relabel[tensor] for tensor in operator.inputs]
+ operator.outputs = [tensor_relabel[tensor] for tensor in operator.outputs]
+ operator.opcodeIndex = opcode_relabel[operator.opcodeIndex]
- for t in used_tensors:
- t.buffer = buffer_relabel[t.buffer]
+ for tensor in used_tensors:
+ tensor.buffer = buffer_relabel[tensor.buffer]
src_subgraph.inputs = [tensor_relabel[t] for t in src_subgraph.inputs]
src_subgraph.outputs = [tensor_relabel[t] for t in src_subgraph.outputs]
@@ -118,11 +150,12 @@ def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph):
dst_subgraph.outputs = list(set(src_subgraph.outputs).union(dst_subgraph.outputs))
-def append_relabel(src, dst, map=None):
- if map is None:
- map = {}
- for i, x in enumerate(src):
- if not i in map:
- map[i] = len(dst)
+def append_relabel(src: list, dst: list, operator_map: dict | None = None) -> dict:
+ """Return a map over relabeled tensors in a subgraph."""
+ if not operator_map:
+ operator_map = {}
+ for i, x in enumerate(src): # pylint: disable=invalid-name
+ if i not in operator_map:
+ operator_map[i] = len(dst)
dst.append(x)
- return map
+ return operator_map
diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py
index ae13313..90f3db8 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/record.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/record.py
@@ -1,34 +1,39 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Save subgraph data."""
+# pylint: disable=too-many-locals
+from __future__ import annotations
+
import math
import os
+from pathlib import Path
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
+from rich.progress import track
-tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-
-from tqdm import tqdm
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import (
- NumpyTFReader,
- NumpyTFWriter,
- TFLiteModel,
- numpytf_count,
-)
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFWriter
from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
def record_model(
- input_filename,
- model_filename,
- output_filename,
- batch_size=None,
- show_progress=False,
- num_procs=1,
- num_threads=0,
-):
- """num_procs: 0 => detect real cores on system
- num_threads: 0 => TFLite impl. specific setting, usually 3"""
+ input_filename: str | Path,
+ model_filename: str | Path,
+ output_filename: str | Path,
+ batch_size: int = 0,
+ show_progress: bool = False,
+ num_procs: int = 1,
+ num_threads: int = 0,
+) -> None:
+ """Model recorder.
+
+ num_procs: 0 => detect real cores on system
+ num_threads: 0 => TFLite impl. specific setting, usually 3
+ """
model = ParallelTFLiteModel(model_filename, num_procs, num_threads, batch_size)
if not batch_size:
batch_size = (
@@ -36,10 +41,10 @@ def record_model(
) # automatically batch to the minimum effective size if not specified
total = numpytf_count(input_filename)
- dataset = NumpyTFReader(input_filename)
+ dataset = numpytf_read(input_filename)
if batch_size > 1:
- # Collapse batch-size 1 items into batch-size n. I regret using batch-size 1 items in tfrecs now.
+ # Collapse batch-size 1 items into batch-size n.
dataset = dataset.map(
lambda d: {k: tf.squeeze(v, axis=0) for k, v in d.items()}
)
@@ -48,7 +53,7 @@ def record_model(
with NumpyTFWriter(output_filename) as writer:
for _, named_x in enumerate(
- tqdm(dataset.as_numpy_iterator(), total=total, disable=not show_progress)
+ track(dataset.as_numpy_iterator(), total=total, disable=not show_progress)
):
named_y = model(named_x)
if batch_size > 1:
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index d4f61c5..ab34b47 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -42,4 +42,5 @@ class Rewriter(Optimizer):
return self.model
def optimization_config(self) -> str:
- """Optimization configirations."""
+ """Optimization configurations."""
+ return str(self.optimizer_configuration)
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 096daf4..f837964 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -1,35 +1,41 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Sequential trainer."""
+# pylint: disable=too-many-arguments, too-many-instance-attributes,
+# pylint: disable=too-many-locals, too-many-branches, too-many-statements
+from __future__ import annotations
+
+import logging
import math
import os
import tempfile
from collections import defaultdict
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import get_args
+from typing import Literal
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import numpy as np
import tensorflow as tf
+from numpy.random import Generator
-tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-
-try:
- from tensorflow.keras.optimizers.schedules import CosineDecay
-except ImportError:
- # In TF 2.4 CosineDecay was still experimental
- from tensorflow.keras.experimental import CosineDecay
-
-import numpy as np
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import (
- NumpyTFReader,
- NumpyTFWriter,
- TFLiteModel,
- numpytf_count,
-)
-from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
-from mlia.nn.rewrite.core.graph_edit.record import record_model
-from mlia.nn.rewrite.core.utils.utils import load, save
from mlia.nn.rewrite.core.extract import extract
-from mlia.nn.rewrite.core.graph_edit.join import join_models
from mlia.nn.rewrite.core.graph_edit.diff import diff_stats
+from mlia.nn.rewrite.core.graph_edit.join import join_models
+from mlia.nn.rewrite.core.graph_edit.record import record_model
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel
+from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
+from mlia.nn.rewrite.core.utils.utils import load
+from mlia.nn.rewrite.core.utils.utils import save
+from mlia.utils.logging import log_action
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+logger = logging.getLogger(__name__)
augmentation_presets = {
"none": (None, None),
@@ -40,31 +46,34 @@ augmentation_presets = {
"mix_gaussian_small": (1.6, 0.3),
}
-learning_rate_schedules = {"cosine", "late", "constant"}
+LearningRateSchedule = Literal["cosine", "late", "constant"]
+LEARNING_RATE_SCHEDULES = get_args(LearningRateSchedule)
def train(
- source_model,
- unmodified_model,
- output_model,
- input_tfrec,
- replace_fn,
- input_tensors,
- output_tensors,
- augment,
- steps,
- lr,
- batch_size,
- verbose,
- show_progress,
- learning_rate_schedule="cosine",
- checkpoint_at=None,
- checkpoint_decay_steps=0,
- num_procs=1,
- num_threads=0,
-):
+ source_model: str,
+ unmodified_model: Any,
+ output_model: str,
+ input_tfrec: str,
+ replace_fn: Callable,
+ input_tensors: list,
+ output_tensors: list,
+ augment: tuple[float | None, float | None],
+ steps: int,
+ learning_rate: float,
+ batch_size: int,
+ verbose: bool,
+ show_progress: bool,
+ learning_rate_schedule: LearningRateSchedule = "cosine",
+ checkpoint_at: list | None = None,
+ num_procs: int = 1,
+ num_threads: int = 0,
+) -> Any:
+ """Extract and train a model, and return the results."""
if unmodified_model:
- unmodified_model_dir = tempfile.TemporaryDirectory()
+ unmodified_model_dir = (
+ tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
+ )
unmodified_model_dir_path = unmodified_model_dir.name
extract(
unmodified_model_dir_path,
@@ -79,8 +88,6 @@ def train(
results = []
with tempfile.TemporaryDirectory() as train_dir:
- p = lambda file: os.path.join(train_dir, file)
-
extract(
train_dir,
source_model,
@@ -94,14 +101,13 @@ def train(
tflite_filenames = train_in_dir(
train_dir,
unmodified_model_dir_path,
- p("new.tflite"),
+ Path(train_dir, "new.tflite"),
replace_fn,
augment,
steps,
- lr,
+ learning_rate,
batch_size,
checkpoint_at=checkpoint_at,
- checkpoint_decay_steps=checkpoint_decay_steps,
verbose=verbose,
show_progress=show_progress,
num_procs=num_procs,
@@ -114,7 +120,8 @@ def train(
if output_model:
if i + 1 < len(tflite_filenames):
- # Append the same _@STEPS.tflite postfix used by intermediate checkpoints for all but the last output
+ # Append the same _@STEPS.tflite postfix used by intermediate
+ # checkpoints for all but the last output
postfix = filename.split("_@")[-1]
output_filename = output_model.split(".tflite")[0] + postfix
else:
@@ -122,115 +129,130 @@ def train(
join_in_dir(train_dir, filename, output_filename)
if unmodified_model_dir:
- unmodified_model_dir.cleanup()
+ cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup()
return (
results if checkpoint_at else results[0]
) # only return a list if multiple checkpoints are asked for
-def eval_in_dir(dir, new_part, num_procs=1, num_threads=0):
- p = lambda file: os.path.join(dir, file)
- input = (
- p("input_orig.tfrec")
- if os.path.exists(p("input_orig.tfrec"))
- else p("input.tfrec")
+def eval_in_dir(
+ target_dir: str, new_part: str, num_procs: int = 1, num_threads: int = 0
+) -> tuple:
+ """Evaluate a model in a given directory."""
+ model_input_path = Path(target_dir, "input_orig.tfrec")
+ model_output_path = Path(target_dir, "output_orig.tfrec")
+ model_input = (
+ model_input_path
+ if model_input_path.exists()
+ else Path(target_dir, "input.tfrec")
)
output = (
- p("output_orig.tfrec")
- if os.path.exists(p("output_orig.tfrec"))
- else p("output.tfrec")
+ model_output_path
+ if model_output_path.exists()
+ else Path(target_dir, "output.tfrec")
)
with tempfile.TemporaryDirectory() as tmp_dir:
- predict = os.path.join(tmp_dir, "predict.tfrec")
+ predict = Path(tmp_dir, "predict.tfrec")
record_model(
- input, new_part, predict, num_procs=num_procs, num_threads=num_threads
+ str(model_input),
+ new_part,
+ str(predict),
+ num_procs=num_procs,
+ num_threads=num_threads,
)
- mae, nrmse = diff_stats(output, predict)
+ mae, nrmse = diff_stats(str(output), str(predict))
return mae, nrmse
-def join_in_dir(dir, new_part, output_model):
+def join_in_dir(model_dir: str, new_part: str, output_model: str) -> None:
+ """Join two models in a given directory."""
with tempfile.TemporaryDirectory() as tmp_dir:
- d = lambda file: os.path.join(dir, file)
- new_end = os.path.join(tmp_dir, "new_end.tflite")
- join_models(new_part, d("end.tflite"), new_end)
- join_models(d("start.tflite"), new_end, output_model)
+ new_end = Path(tmp_dir, "new_end.tflite")
+ join_models(new_part, Path(model_dir, "end.tflite"), new_end)
+ join_models(Path(model_dir, "start.tflite"), new_end, output_model)
def train_in_dir(
- train_dir,
- baseline_dir,
- output_filename,
- replace_fn,
- augmentations,
- steps,
- lr=1e-3,
- batch_size=32,
- checkpoint_at=None,
- checkpoint_decay_steps=0,
- schedule="cosine",
- verbose=False,
- show_progress=False,
- num_procs=None,
- num_threads=1,
-):
- """Train a replacement for replace.tflite using the input.tfrec and output.tfrec in train_dir.
- If baseline_dir is provided, train the replacement to match baseline outputs for train_dir inputs.
- Result saved as new.tflite in train_dir.
+ train_dir: str,
+ baseline_dir: Any,
+ output_filename: Path,
+ replace_fn: Callable,
+ augmentations: tuple[float | None, float | None],
+ steps: int,
+ learning_rate: float = 1e-3,
+ batch_size: int = 32,
+ checkpoint_at: list | None = None,
+ schedule: str = "cosine",
+ verbose: bool = False,
+ show_progress: bool = False,
+ num_procs: int = 0,
+ num_threads: int = 1,
+) -> list:
+ """Train a replacement for replace.tflite using the input.tfrec \
+ and output.tfrec in train_dir.
+
+ If baseline_dir is provided, train the replacement to match baseline
+ outputs for train_dir inputs. Result saved as new.tflite in train_dir.
"""
teacher_dir = baseline_dir if baseline_dir else train_dir
teacher = ParallelTFLiteModel(
- "%s/replace.tflite" % teacher_dir, num_procs, num_threads, batch_size=batch_size
- )
- replace = TFLiteModel("%s/replace.tflite" % train_dir)
- assert len(teacher.input_tensors()) == 1, (
- "Can only train replacements with a single input tensor right now, found %s"
- % teacher.input_tensors()
- )
- assert len(teacher.output_tensors()) == 1, (
- "Can only train replacements with a single output tensor right now, found %s"
- % teacher.output_tensors()
+ f"{teacher_dir}/replace.tflite", num_procs, num_threads, batch_size=batch_size
)
+ replace = TFLiteModel(f"{train_dir}/replace.tflite")
+ assert (
+ len(teacher.input_tensors()) == 1
+ ), f"Can only train replacements with a single input tensor right now, \
+ found {teacher.input_tensors()}"
+
+ assert (
+ len(teacher.output_tensors()) == 1
+ ), f"Can only train replacements with a single output tensor right now, \
+ found {teacher.output_tensors()}"
+
input_name = teacher.input_tensors()[0]
output_name = teacher.output_tensors()[0]
assert len(teacher.shape_from_name) == len(
replace.shape_from_name
- ), "Baseline and train models must have the same number of inputs and outputs. Teacher: {}\nTrain dir: {}".format(
- teacher.shape_from_name, replace.shape_from_name
- )
+ ), f"Baseline and train models must have the same number of inputs and outputs. \
+ Teacher: {teacher.shape_from_name}\nTrain dir: {replace.shape_from_name}"
+
assert all(
tn == rn and (ts[1:] == rs[1:]).all()
for (tn, ts), (rn, rs) in zip(
teacher.shape_from_name.items(), replace.shape_from_name.items()
)
- ), "Baseline and train models must have the same input and output shapes for the subgraph being replaced. Teacher: {}\nTrain dir: {}".format(
- teacher.shape_from_name, replace.shape_from_name
- )
+ ), "Baseline and train models must have the same input and output shapes for the \
+ subgraph being replaced. Teacher: {teacher.shape_from_name}\n \
+ Train dir: {replace.shape_from_name}"
- input_filename = os.path.join(train_dir, "input.tfrec")
- total = numpytf_count(input_filename)
- dict_inputs = NumpyTFReader(input_filename)
+ input_filename = Path(train_dir, "input.tfrec")
+ total = numpytf_count(str(input_filename))
+ dict_inputs = numpytf_read(str(input_filename))
inputs = dict_inputs.map(lambda d: tf.squeeze(d[input_name], axis=0))
if any(augmentations):
- # Map the teacher inputs here because the augmentation stage passes these through a TFLite model to get the outputs
- teacher_outputs = NumpyTFReader(os.path.join(teacher_dir, "input.tfrec")).map(
+ # Map the teacher inputs here because the augmentation stage passes these
+ # through a TFLite model to get the outputs
+ teacher_outputs = numpytf_read(str(Path(teacher_dir, "input.tfrec"))).map(
lambda d: tf.squeeze(d[input_name], axis=0)
)
else:
- teacher_outputs = NumpyTFReader(os.path.join(teacher_dir, "output.tfrec")).map(
+ teacher_outputs = numpytf_read(str(Path(teacher_dir, "output.tfrec"))).map(
lambda d: tf.squeeze(d[output_name], axis=0)
)
steps_per_epoch = math.ceil(total / batch_size)
epochs = int(math.ceil(steps / steps_per_epoch))
if verbose:
- print(
- "Training on %d items for %d steps (%d epochs with batch size %d)"
- % (total, epochs * steps_per_epoch, epochs, batch_size)
+ logger.info(
+ "Training on %d items for %d steps (%d epochs with batch size %d)",
+ total,
+ epochs * steps_per_epoch,
+ epochs,
+ batch_size,
)
dataset = tf.data.Dataset.zip((inputs, teacher_outputs))
@@ -240,13 +262,21 @@ def train_in_dir(
if any(augmentations):
augment_train, augment_teacher = augment_fn_twins(dict_inputs, augmentations)
- augment_fn = lambda train, teach: (
- augment_train({input_name: train})[input_name],
- teacher(augment_teacher({input_name: teach}))[output_name],
- )
+
+ def get_augment_results(
+ train: Any, teach: Any # pylint: disable=redefined-outer-name
+ ) -> tuple:
+ """Return results of train and teach based on augmentations."""
+ return (
+ augment_train({input_name: train})[input_name],
+ teacher(augment_teacher({input_name: teach}))[output_name],
+ )
+
dataset = dataset.map(
- lambda train, teach: tf.py_function(
- augment_fn, inp=[train, teach], Tout=[tf.float32, tf.float32]
+ lambda augment_train, augment_teach: tf.py_function(
+ get_augment_results,
+ inp=[augment_train, augment_teach],
+ Tout=[tf.float32, tf.float32],
)
)
@@ -256,7 +286,7 @@ def train_in_dir(
output_shape = teacher.shape_from_name[output_name][1:]
model = replace_fn(input_shape, output_shape)
- optimizer = tf.keras.optimizers.Nadam(learning_rate=lr)
+ optimizer = tf.keras.optimizers.Nadam(learning_rate=learning_rate)
loss_fn = tf.keras.losses.MeanSquaredError()
model.compile(optimizer=optimizer, loss=loss_fn)
@@ -265,20 +295,26 @@ def train_in_dir(
steps_so_far = 0
- def cosine_decay(epoch_step, logs):
- """Cosine decay from lr at start of the run to zero at the end"""
+ def cosine_decay(
+ epoch_step: int, logs: Any # pylint: disable=unused-argument
+ ) -> None:
+ """Cosine decay from learning rate at start of the run to zero at the end."""
current_step = epoch_step + steps_so_far
- learning_rate = lr * (math.cos(math.pi * current_step / steps) + 1) / 2.0
- tf.keras.backend.set_value(optimizer.learning_rate, learning_rate)
+ cd_learning_rate = (
+ learning_rate * (math.cos(math.pi * current_step / steps) + 1) / 2.0
+ )
+ tf.keras.backend.set_value(optimizer.learning_rate, cd_learning_rate)
- def late_decay(epoch_step, logs):
- """Constant until the last 20% of the run, then linear decay to zero"""
+ def late_decay(
+ epoch_step: int, logs: Any # pylint: disable=unused-argument
+ ) -> None:
+ """Constant until the last 20% of the run, then linear decay to zero."""
current_step = epoch_step + steps_so_far
steps_remaining = steps - current_step
decay_length = steps // 5
decay_fraction = min(steps_remaining, decay_length) / decay_length
- learning_rate = lr * decay_fraction
- tf.keras.backend.set_value(optimizer.learning_rate, learning_rate)
+ ld_learning_rate = learning_rate * decay_fraction
+ tf.keras.backend.set_value(optimizer.learning_rate, ld_learning_rate)
if schedule == "cosine":
callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=cosine_decay)]
@@ -287,9 +323,10 @@ def train_in_dir(
elif schedule == "constant":
callbacks = []
else:
- assert schedule not in learning_rate_schedules
+ assert schedule not in LEARNING_RATE_SCHEDULES
raise ValueError(
- f'LR schedule "{schedule}" not implemented - expected one of {learning_rate_schedules}.'
+ f'Learning rate schedule "{schedule}" not implemented - '
+ f"expected one of {LEARNING_RATE_SCHEDULES}."
)
output_filenames = []
@@ -305,53 +342,66 @@ def train_in_dir(
verbose=show_progress,
)
steps_so_far += steps_to_train
- print(
- "lr decayed from %f to %f over %d steps"
- % (lr_start, optimizer.learning_rate.numpy(), steps_to_train)
+ logger.info(
+ "lr decayed from %f to %f over %d steps",
+ lr_start,
+ optimizer.learning_rate.numpy(),
+ steps_to_train,
)
if steps_so_far < steps:
- filename, ext = os.path.splitext(output_filename)
- checkpoint_filename = filename + ("_@%d" % steps_so_far) + ext
+ filename, ext = Path(output_filename).parts[1:]
+ checkpoint_filename = filename + (f"_@{steps_so_far}") + ext
else:
- checkpoint_filename = output_filename
- print("%d/%d: Saved as %s" % (steps_so_far, steps, checkpoint_filename))
- save_as_tflite(
- model,
- checkpoint_filename,
- input_name,
- replace.shape_from_name[input_name],
- output_name,
- replace.shape_from_name[output_name],
- )
- output_filenames.append(checkpoint_filename)
+ checkpoint_filename = str(output_filename)
+ with log_action(f"{steps_so_far}/{steps}: Saved as {checkpoint_filename}"):
+ save_as_tflite(
+ model,
+ checkpoint_filename,
+ input_name,
+ replace.shape_from_name[input_name],
+ output_name,
+ replace.shape_from_name[output_name],
+ )
+ output_filenames.append(checkpoint_filename)
teacher.close()
return output_filenames
def save_as_tflite(
- keras_model, filename, input_name, input_shape, output_name, output_shape
-):
+ keras_model: tf.keras.Model,
+ filename: str,
+ input_name: str,
+ input_shape: list,
+ output_name: str,
+ output_shape: list,
+) -> None:
+ """Save Keras model as TFLite file."""
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
tflite_model = converter.convert()
- with open(filename, "wb") as f:
- f.write(tflite_model)
+ with open(filename, "wb") as file:
+ file.write(tflite_model)
# Now fix the shapes and names to match those we expect
- fb = load(filename)
- i = fb.subgraphs[0].inputs[0]
- fb.subgraphs[0].tensors[i].shape = np.array(input_shape, dtype=np.int32)
- fb.subgraphs[0].tensors[i].name = input_name.encode("utf-8")
- o = fb.subgraphs[0].outputs[0]
- fb.subgraphs[0].tensors[o].shape = np.array(output_shape, dtype=np.int32)
- fb.subgraphs[0].tensors[o].name = output_name.encode("utf-8")
- save(fb, filename)
-
-
-def augment_fn_twins(inputs, augmentations):
- """Return a pair of twinned augmentation functions with the same sequence of random numbers"""
+ flatbuffer = load(filename)
+ i = flatbuffer.subgraphs[0].inputs[0]
+ flatbuffer.subgraphs[0].tensors[i].shape = np.array(input_shape, dtype=np.int32)
+ flatbuffer.subgraphs[0].tensors[i].name = input_name.encode("utf-8")
+ output = flatbuffer.subgraphs[0].outputs[0]
+ flatbuffer.subgraphs[0].tensors[output].shape = np.array(
+ output_shape, dtype=np.int32
+ )
+ flatbuffer.subgraphs[0].tensors[output].name = output_name.encode("utf-8")
+ save(flatbuffer, filename)
+
+
+def augment_fn_twins(
+ inputs: dict, augmentations: tuple[float | None, float | None]
+) -> Any:
+ """Return a pair of twinned augmentation functions with the same sequence \
+ of random numbers."""
seed = np.random.randint(2**32 - 1)
rng1 = np.random.default_rng(seed)
rng2 = np.random.default_rng(seed)
@@ -360,52 +410,67 @@ def augment_fn_twins(inputs, augmentations):
)
-def augment_fn(inputs, augmentations, rng):
+def augment_fn(
+ inputs: Any, augmentations: tuple[float | None, float | None], rng: Generator
+) -> Any:
+ """Augmentation module."""
mixup_strength, gaussian_strength = augmentations
augments = []
if mixup_strength:
mixup_range = (0.5 - mixup_strength / 2, 0.5 + mixup_strength / 2)
- augment = lambda d: {
- k: mixup(rng, v.numpy(), mixup_range) for k, v in d.items()
- }
- augments.append(augment)
+
+ def mixup_augment(augment_dict: dict) -> dict:
+ return {
+ k: mixup(rng, v.numpy(), mixup_range) for k, v in augment_dict.items()
+ }
+
+ augments.append(mixup_augment)
if gaussian_strength:
values = defaultdict(list)
- for d in inputs.as_numpy_iterator():
- for k, v in d.items():
- values[k].append(v)
+ for numpy_dict in inputs.as_numpy_iterator():
+ for key, value in numpy_dict.items():
+ values[key].append(value)
noise_scale = {
k: np.std(v, axis=0).astype(np.float32) for k, v in values.items()
}
- augment = lambda d: {
- k: v
- + rng.standard_normal(v.shape).astype(np.float32)
- * gaussian_strength
- * noise_scale[k]
- for k, v in d.items()
- }
- augments.append(augment)
- if len(augments) == 0:
+ def gaussian_strength_augment(augment_dict: dict) -> dict:
+ return {
+ k: v
+ + rng.standard_normal(v.shape).astype(np.float32)
+ * gaussian_strength
+ * noise_scale[k]
+ for k, v in augment_dict.items()
+ }
+
+ augments.append(gaussian_strength_augment)
+
+ if len(augments) == 0: # pylint: disable=no-else-return
return lambda x: x
elif len(augments) == 1:
return augments[0]
elif len(augments) == 2:
return lambda x: augments[1](augments[0](x))
else:
- assert False, "Unexpected number of augmentation functions (%d)" % len(augments)
-
-
-def mixup(rng, batch, beta_range=(0.0, 1.0)):
- """Each tensor in the batch becomes a linear combination of it and one other tensor"""
- a = batch
- b = np.array(batch)
- rng.shuffle(b) # randomly pair up tensors in the batch
+ assert (
+ False
+ ), f"Unexpected number of augmentation \
+ functions ({len(augments)})"
+
+
+def mixup(rng: Generator, batch: Any, beta_range: tuple = (0.0, 1.0)) -> Any:
+ """Each tensor in the batch becomes a linear combination of it \
+ and one other tensor."""
+ batch_a = batch
+ batch_b = np.array(batch)
+ rng.shuffle(batch_b) # randomly pair up tensors in the batch
# random mixing coefficient for each pair
beta = rng.uniform(
low=beta_range[0], high=beta_range[1], size=batch.shape[0]
).astype(np.float32)
- return (a.T * beta).T + (b.T * (1.0 - beta)).T # return linear combinations
+ return (batch_a.T * beta).T + (
+ batch_b.T * (1.0 - beta)
+ ).T # return linear combinations
diff --git a/src/mlia/nn/rewrite/core/utils/__init__.py b/src/mlia/nn/rewrite/core/utils/__init__.py
index 8c1f750..f0b5026 100644
--- a/src/mlia/nn/rewrite/core/utils/__init__.py
+++ b/src/mlia/nn/rewrite/core/utils/__init__.py
@@ -1,2 +1,3 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Rewrite core utils module."""
diff --git a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
index 2141003..9229810 100644
--- a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
+++ b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
@@ -1,26 +1,32 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Numpy TFRecord utils."""
+from __future__ import annotations
+
import json
import os
import random
import tempfile
from collections import defaultdict
+from pathlib import Path
+from typing import Any
+from typing import Callable
import numpy as np
+import tensorflow as tf
+from tensorflow.lite.python import interpreter as interpreter_wrapper
from mlia.nn.rewrite.core.utils.utils import load
from mlia.nn.rewrite.core.utils.utils import save
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-import tensorflow as tf
-
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-from tensorflow.lite.python import interpreter as interpreter_wrapper
+def make_decode_fn(filename: str) -> Callable:
+ """Make decode filename."""
-def make_decode_fn(filename):
- def decode_fn(record_bytes, type_map):
+ def decode_fn(record_bytes: Any, type_map: dict) -> dict:
parse_dict = {
name: tf.io.FixedLenFeature([], tf.string) for name in type_map.keys()
}
@@ -32,38 +38,48 @@ def make_decode_fn(filename):
return features
meta_filename = filename + ".meta"
- with open(meta_filename) as f:
- type_map = json.load(f)["type_map"]
+ with open(meta_filename, encoding="utf-8") as file:
+ type_map = json.load(file)["type_map"]
return lambda record_bytes: decode_fn(record_bytes, type_map)
-def NumpyTFReader(filename):
- decode_fn = make_decode_fn(filename)
- dataset = tf.data.TFRecordDataset(filename)
+def numpytf_read(filename: str | Path) -> Any:
+ """Read TFRecord dataset."""
+ decode_fn = make_decode_fn(str(filename))
+ dataset = tf.data.TFRecordDataset(str(filename))
return dataset.map(decode_fn)
-def numpytf_count(filename):
- meta_filename = filename + ".meta"
- with open(meta_filename) as f:
- return json.load(f)["count"]
+def numpytf_count(filename: str | Path) -> Any:
+ """Return count from TFRecord file."""
+ meta_filename = f"{filename}.meta"
+ with open(meta_filename, encoding="utf-8") as file:
+ return json.load(file)["count"]
class NumpyTFWriter:
- def __init__(self, filename):
+ """Numpy TF serializer."""
+
+ def __init__(self, filename: str | Path) -> None:
+ """Initiate a Numpy TF Serializer."""
self.filename = filename
- self.meta_filename = filename + ".meta"
- self.writer = tf.io.TFRecordWriter(filename)
- self.type_map = {}
+ self.meta_filename = f"{filename}.meta"
+ self.writer = tf.io.TFRecordWriter(str(filename))
+ self.type_map: dict = {}
self.count = 0
- def __enter__(self):
+ def __enter__(self) -> Any:
+ """Enter instance."""
return self
- def __exit__(self, type, value, traceback):
+ def __exit__(
+ self, exception_type: Any, exception_value: Any, exception_traceback: Any
+ ) -> None:
+ """Close instance."""
self.close()
- def write(self, array_dict):
+ def write(self, array_dict: dict) -> None:
+ """Write array dict."""
type_map = {n: str(a.dtype.name) for n, a in array_dict.items()}
self.type_map.update(type_map)
self.count += 1
@@ -77,31 +93,41 @@ class NumpyTFWriter:
example = tf.train.Example(features=tf.train.Features(feature=feature))
self.writer.write(example.SerializeToString())
- def close(self):
- with open(self.meta_filename, "w") as f:
+ def close(self) -> None:
+ """Close NumpyTFWriter."""
+ with open(self.meta_filename, "w", encoding="utf-8") as file:
meta = {"type_map": self.type_map, "count": self.count}
- json.dump(meta, f)
+ json.dump(meta, file)
self.writer.close()
class TFLiteModel:
- def __init__(self, filename, batch_size=None, num_threads=None):
- if num_threads == 0:
+ """A representation of a TFLite Model."""
+
+ def __init__(
+ self,
+ filename: str,
+ batch_size: int | None = None,
+ num_threads: int | None = None,
+ ) -> None:
+ """Initiate a TFLite Model."""
+ if not num_threads:
num_threads = None
- if batch_size == None:
+ if not batch_size:
self.interpreter = interpreter_wrapper.Interpreter(
model_path=filename, num_threads=num_threads
)
else: # if a batch size is specified, modify the TFLite model to use this size
with tempfile.TemporaryDirectory() as tmp:
- fb = load(filename)
- for sg in fb.subgraphs:
- for t in list(sg.inputs) + list(sg.outputs):
- sg.tensors[t].shape = np.array(
- [batch_size] + list(sg.tensors[t].shape[1:]), dtype=np.int32
+ flatbuffer = load(filename)
+ for subgraph in flatbuffer.subgraphs:
+ for tensor in list(subgraph.inputs) + list(subgraph.outputs):
+ subgraph.tensors[tensor].shape = np.array(
+ [batch_size] + list(subgraph.tensors[tensor].shape[1:]),
+ dtype=np.int32,
)
tempname = os.path.join(tmp, "rewrite_tmp.tflite")
- save(fb, tempname)
+ save(flatbuffer, tempname)
self.interpreter = interpreter_wrapper.Interpreter(
model_path=tempname, num_threads=num_threads
)
@@ -122,8 +148,9 @@ class TFLiteModel:
self.shape_from_name = {d["name"]: d["shape"] for d in details}
self.batch_size = next(iter(self.shape_from_name.values()))[0]
- def __call__(self, named_input):
- """Execute the model on one or a batch of named inputs (a dict of name: numpy array)"""
+ def __call__(self, named_input: dict) -> dict:
+ """Execute the model on one or a batch of named inputs \
+ (a dict of name: numpy array)."""
input_len = next(iter(named_input.values())).shape[0]
full_steps = input_len // self.batch_size
remainder = input_len % self.batch_size
@@ -131,39 +158,46 @@ class TFLiteModel:
named_ys = defaultdict(list)
for i in range(full_steps):
for name, x_batch in named_input.items():
- x = x_batch[i : i + self.batch_size]
- self.interpreter.set_tensor(self.handle_from_name[name], x)
+ x_tensor = x_batch[i : i + self.batch_size] # noqa: E203
+ self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
self.interpreter.invoke()
- for d in self.output_details:
- named_ys[d["name"]].append(self.interpreter.get_tensor(d["index"]))
+ for output_detail in self.output_details:
+ named_ys[output_detail["name"]].append(
+ self.interpreter.get_tensor(output_detail["index"])
+ )
if remainder:
for name, x_batch in named_input.items():
- x = np.zeros(self.shape_from_name[name]).astype(x_batch.dtype)
- x[:remainder] = x_batch[-remainder:]
- self.interpreter.set_tensor(self.handle_from_name[name], x)
+ x_tensor = np.zeros( # pylint: disable=invalid-name
+ self.shape_from_name[name]
+ ).astype(x_batch.dtype)
+ x_tensor[:remainder] = x_batch[-remainder:]
+ self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
self.interpreter.invoke()
- for d in self.output_details:
- named_ys[d["name"]].append(
- self.interpreter.get_tensor(d["index"])[:remainder]
+ for output_detail in self.output_details:
+ named_ys[output_detail["name"]].append(
+ self.interpreter.get_tensor(output_detail["index"])[:remainder]
)
return {k: np.concatenate(v) for k, v in named_ys.items()}
- def input_tensors(self):
+ def input_tensors(self) -> list:
+ """Return name from input details."""
return [d["name"] for d in self.input_details]
- def output_tensors(self):
+ def output_tensors(self) -> list:
+ """Return name from output details."""
return [d["name"] for d in self.output_details]
-def sample_tfrec(input_file, k, output_file):
+def sample_tfrec(input_file: str, k: int, output_file: str) -> None:
+ """Count, read and write TFRecord input and output data."""
total = numpytf_count(input_file)
- next = sorted(random.sample(range(total), k=k), reverse=True)
+ next_sample = sorted(random.sample(range(total), k=k), reverse=True)
- reader = NumpyTFReader(input_file)
+ reader = numpytf_read(input_file)
with NumpyTFWriter(output_file) as writer:
for i, data in enumerate(reader):
- if i == next[-1]:
- next.pop()
+ if i == next_sample[-1]:
+ next_sample.pop()
writer.write(data)
- if not next:
+ if not next_sample:
break
diff --git a/src/mlia/nn/rewrite/core/utils/parallel.py b/src/mlia/nn/rewrite/core/utils/parallel.py
index b1a2914..d930a1e 100644
--- a/src/mlia/nn/rewrite/core/utils/parallel.py
+++ b/src/mlia/nn/rewrite/core/utils/parallel.py
@@ -1,28 +1,45 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Parallelize a TFLiteModel."""
+from __future__ import annotations
+
+import logging
import math
import os
from collections import defaultdict
from multiprocessing import cpu_count
from multiprocessing import Pool
+from pathlib import Path
+from typing import Any
import numpy as np
-
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
-tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-
from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+logger = logging.getLogger(__name__)
+
class ParallelTFLiteModel(TFLiteModel):
- def __init__(self, filename, num_procs=1, num_threads=0, batch_size=None):
- """num_procs: 0 => detect real cores on system
- num_threads: 0 => TFLite impl. specific setting, usually 3
- batch_size: None => automatic (num_procs or file-determined)
- """
+ """A parallel version of a TFLiteModel.
+
+ num_procs: 0 => detect real cores on system
+ num_threads: 0 => TFLite impl. specific setting, usually 3
+ batch_size: None => automatic (num_procs or file-determined)
+ """
+
+ def __init__(
+ self,
+ filename: str | Path,
+ num_procs: int = 1,
+ num_threads: int = 0,
+ batch_size: int | None = None,
+ ) -> None:
+ """Initiate a Parallel TFLite Model."""
self.pool = None
+ filename = str(filename)
self.filename = filename
if not num_procs:
self.num_procs = cpu_count()
@@ -37,7 +54,7 @@ class ParallelTFLiteModel(TFLiteModel):
local_batch_size = int(math.ceil(batch_size / self.num_procs))
super().__init__(filename, batch_size=local_batch_size)
del self.interpreter
- self.pool = Pool(
+ self.pool = Pool( # pylint: disable=consider-using-with
processes=self.num_procs,
initializer=_pool_create_worker,
initargs=[filename, self.batch_size, self.num_threads],
@@ -51,15 +68,18 @@ class ParallelTFLiteModel(TFLiteModel):
self.partial_batches = 0
self.warned = False
- def close(self):
+ def close(self) -> None:
+ """Close and terminate pool."""
if self.pool:
self.pool.close()
self.pool.terminate()
- def __del__(self):
+ def __del__(self) -> None:
+ """Close instance."""
self.close()
- def __call__(self, named_input):
+ def __call__(self, named_input: dict) -> Any:
+ """Call instance."""
if self.pool:
global_batch_size = next(iter(named_input.values())).shape[0]
# Note: self.batch_size comes from superclass and is local batch size
@@ -72,19 +92,21 @@ class ParallelTFLiteModel(TFLiteModel):
and self.total_batches > 10
and self.partial_batches / self.total_batches >= 0.5
):
- print(
- "ParallelTFLiteModel(%s): warning - %.1f%% of batches do not use all %d processes, set batch size to a multiple of this"
- % (
- self.filename,
- 100 * self.partial_batches / self.total_batches,
- self.num_procs,
- )
+ logger.warning(
+ "ParallelTFLiteModel(%s): warning - %.1f of batches "
+ "do not use all %d processes, set batch size to "
+ "a multiple of this.",
+ self.filename,
+ 100 * self.partial_batches / self.total_batches,
+ self.num_procs,
)
self.warned = True
local_batches = [
{
- key: values[i * self.batch_size : (i + 1) * self.batch_size]
+ key: values[
+ i * self.batch_size : (i + 1) * self.batch_size # noqa: E203
+ ]
for key, values in named_input.items()
}
for i in range(chunks)
@@ -92,22 +114,26 @@ class ParallelTFLiteModel(TFLiteModel):
chunk_results = self.pool.map(_pool_run, local_batches)
named_ys = defaultdict(list)
for chunk in chunk_results:
- for k, v in chunk.items():
- named_ys[k].append(v)
- return {k: np.concatenate(v) for k, v in named_ys.items()}
- else:
- return super().__call__(named_input)
+ for key, value in chunk.items():
+ named_ys[key].append(value)
+ return {key: np.concatenate(value) for key, value in named_ys.items()}
+
+ return super().__call__(named_input)
-_local_model = None
+_LOCAL_MODEL = None
-def _pool_create_worker(filename, local_batch_size=None, num_threads=None):
- global _local_model
- _local_model = TFLiteModel(
+def _pool_create_worker(
+ filename: str, local_batch_size: int = 0, num_threads: int = 0
+) -> None:
+ global _LOCAL_MODEL # pylint: disable=global-statement
+ _LOCAL_MODEL = TFLiteModel(
filename, batch_size=local_batch_size, num_threads=num_threads
)
-def _pool_run(named_inputs):
- return _local_model(named_inputs)
+def _pool_run(named_inputs: dict) -> Any:
+ if _LOCAL_MODEL:
+ return _LOCAL_MODEL(named_inputs)
+ raise ValueError("TFLiteModel is not initiated")
diff --git a/src/mlia/nn/rewrite/core/utils/utils.py b/src/mlia/nn/rewrite/core/utils/utils.py
index d1ed322..ddf0cc2 100644
--- a/src/mlia/nn/rewrite/core/utils/utils.py
+++ b/src/mlia/nn/rewrite/core/utils/utils.py
@@ -1,22 +1,28 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
-import os
+"""Model and file system utilites."""
+from __future__ import annotations
+
+from pathlib import Path
import flatbuffers
-from tensorflow.lite.python import schema_py_generated as schema_fb
+from tensorflow.lite.python.schema_py_generated import Model
+from tensorflow.lite.python.schema_py_generated import ModelT
-def load(input_tflite_file):
- if not os.path.exists(input_tflite_file):
- raise FileNotFoundError("TFLite file not found at %r\n" % input_tflite_file)
+def load(input_tflite_file: str | Path) -> ModelT:
+ """Load a flatbuffer model from file."""
+ if not Path(input_tflite_file).exists():
+ raise FileNotFoundError(f"TFLite file not found at {input_tflite_file}\n")
with open(input_tflite_file, "rb") as file_handle:
file_data = bytearray(file_handle.read())
- model_obj = schema_fb.Model.GetRootAsModel(file_data, 0)
- model = schema_fb.ModelT.InitFromObj(model_obj)
+ model_obj = Model.GetRootAsModel(file_data, 0)
+ model = ModelT.InitFromObj(model_obj)
return model
-def save(model, output_tflite_file):
+def save(model: ModelT, output_tflite_file: str | Path) -> None:
+ """Save a flatbuffer model to a given file."""
builder = flatbuffers.Builder(1024) # Initial size of the buffer, which
# will grow automatically if needed
model_offset = model.Pack(builder)
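
load() and save() round-trip a .tflite flatbuffer through the mutable object API (ModelT), which the graph-edit helpers can then modify in place. A minimal sketch with hypothetical file names:

    from mlia.nn.rewrite.core.utils.utils import load
    from mlia.nn.rewrite.core.utils.utils import save

    model = load("model.tflite")      # mutable ModelT object
    print(len(model.subgraphs))       # inspect before editing
    save(model, "model_copy.tflite")  # re-serialize (unchanged here)
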
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py
index 7a25e47..5e223fa 100644
--- a/src/mlia/nn/select.py
+++ b/src/mlia/nn/select.py
@@ -6,6 +6,8 @@ from __future__ import annotations
import math
from pathlib import Path
from typing import Any
+from typing import cast
+from typing import List
from typing import NamedTuple
import tensorflow as tf
@@ -129,12 +131,12 @@ def get_optimizer(
return Clusterer(model, config)
if isinstance(config, RewriteConfiguration):
- return Rewriter(model, config) # type: ignore
+ return Rewriter(model, config)
- if isinstance(config, OptimizationSettings) or is_list_of(
- config, OptimizationSettings
- ):
- return _get_optimizer(model, config) # type: ignore
+ if isinstance(config, OptimizationSettings):
+ return _get_optimizer(model, cast(OptimizationSettings, config))
+ if is_list_of(config, OptimizationSettings):
+ return _get_optimizer(model, cast(List[OptimizationSettings], config))
raise ConfigurationError(f"Unknown optimization configuration {config}")
@@ -186,7 +188,7 @@ def _get_optimizer_configuration(
if opt_type == "rewrite":
if isinstance(optimization_target, str):
- return RewriteConfiguration( # type: ignore
+ return RewriteConfiguration(
str(optimization_target), layers_to_optimize, dataset
)
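
The select.py hunk above replaces blanket "# type: ignore" comments with explicit typing.cast calls, one per dispatch branch. A self-contained sketch of that pattern with stand-in names (none of them come from mlia):

    from typing import List
    from typing import Union
    from typing import cast

    class Setting:
        """Stand-in for a configuration class such as OptimizationSettings."""

        def __init__(self, name: str) -> None:
            self.name = name

    def is_list_of(value: object, cls: type) -> bool:
        """Return True if value is a list whose items are all cls instances."""
        return isinstance(value, list) and all(isinstance(v, cls) for v in value)

    def describe(config: Union[Setting, List[Setting]]) -> str:
        """Dispatch on the config type, narrowing it for the type checker."""
        if isinstance(config, Setting):
            return cast(Setting, config).name
        if is_list_of(config, Setting):
            return ", ".join(s.name for s in cast(List[Setting], config))
        raise ValueError(f"Unknown configuration {config}")

    print(describe([Setting("pruning"), Setting("clustering")]))
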
diff --git a/tests/test_nn_rewrite_core_graph_edit_cut.py b/tests/test_nn_rewrite_core_graph_edit_cut.py
index 914fdfd..7d267ed 100644
--- a/tests/test_nn_rewrite_core_graph_edit_cut.py
+++ b/tests/test_nn_rewrite_core_graph_edit_cut.py
@@ -13,11 +13,11 @@ def test_cut_model(test_tflite_model: Path, tmp_path: Path) -> None:
"""Test the function cut_model()."""
output_file = tmp_path / "out.tflite"
cut_model(
- model_file=test_tflite_model,
+ model_file=str(test_tflite_model),
input_names=["serving_default_input:0"],
output_names=["sequential/flatten/Reshape"],
subgraph_index=0,
- output_file=output_file,
+ output_file=str(output_file),
)
assert output_file.is_file()
diff --git a/tests/test_nn_rewrite_core_graph_edit_record.py b/tests/test_nn_rewrite_core_graph_edit_record.py
index 39aeef5..cd728af 100644
--- a/tests/test_nn_rewrite_core_graph_edit_record.py
+++ b/tests/test_nn_rewrite_core_graph_edit_record.py
@@ -7,7 +7,7 @@ import pytest
import tensorflow as tf
from mlia.nn.rewrite.core.graph_edit.record import record_model
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFReader
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
@pytest.mark.parametrize("batch_size", (None, 1, 2))
@@ -46,7 +46,7 @@ def test_record_model(
# any of the model outputs
interpreter = tf.lite.Interpreter(str(test_tflite_model))
model_outputs = interpreter.get_output_details()
- dataset = NumpyTFReader(str(output_file))
+ dataset = numpytf_read(str(output_file))
for data in dataset:
for name, tensor in data.items():
assert data_matches_outputs(name, tensor, model_outputs)
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index d2bc1e0..3c2ef3e 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -6,17 +6,21 @@ from __future__ import annotations
from pathlib import Path
from tempfile import TemporaryDirectory
+from typing import Any
import numpy as np
import pytest
import tensorflow as tf
from mlia.nn.rewrite.core.train import augmentation_presets
+from mlia.nn.rewrite.core.train import LearningRateSchedule
from mlia.nn.rewrite.core.train import mixup
from mlia.nn.rewrite.core.train import train
-def replace_fully_connected_with_conv(input_shape, output_shape) -> tf.keras.Model:
+def replace_fully_connected_with_conv(
+ input_shape: Any, output_shape: Any
+) -> tf.keras.Model:
"""Get a replacement model for the fully connected layer."""
for name, shape in {
"Input": input_shape,
@@ -43,7 +47,7 @@ def check_train(
augmentation_preset: tuple[float | None, float | None] = augmentation_presets[
"none"
],
- lr_schedule: str = "cosine",
+ lr_schedule: LearningRateSchedule = "cosine",
use_unmodified_model: bool = False,
num_procs: int = 1,
) -> None:
@@ -60,7 +64,7 @@ def check_train(
output_tensors=["StatefulPartitionedCall:0"],
augment=augmentation_preset,
steps=32,
- lr=1e-3,
+ learning_rate=1e-3,
batch_size=batch_size,
verbose=verbose,
show_progress=show_progress,
@@ -104,7 +108,7 @@ def test_train(
verbose: bool,
show_progress: bool,
augmentation_preset: tuple[float | None, float | None],
- lr_schedule: str,
+ lr_schedule: LearningRateSchedule,
use_unmodified_model: bool,
num_procs: int,
) -> None:
@@ -131,7 +135,7 @@ def test_train_invalid_schedule(
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- lr_schedule="unknown_schedule",
+ lr_schedule="unknown_schedule", # type: ignore
)
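
For reference, LearningRateSchedule in train.py constrains the accepted schedule names, which is why the invalid-schedule test needs a "# type: ignore" while still exercising the runtime check. The alias below is only an assumed shape used to illustrate the pattern, not the actual definition:

    from typing import Literal

    LearningRateSchedule = Literal["cosine", "late", "constant"]  # assumed values

    def check_schedule(lr_schedule: LearningRateSchedule) -> None:
        """Raise on a schedule name outside the accepted set."""
        if lr_schedule not in ("cosine", "late", "constant"):
            raise ValueError(f"Unknown learning rate schedule: {lr_schedule}")

    check_schedule("cosine")  # accepted by mypy and at runtime
    check_schedule("unknown_schedule")  # type: ignore  # mypy error; raises ValueError
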