aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2023-03-20 18:07:54 +0000
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-10-11 15:42:55 +0100
commit62768232c5fe4ed6b87136c336b65e13d030e9d4 (patch)
tree847c36a2f7e092982bc1d7a66d0bf601447c8d20
parent446c379c92e15ad8f24ed0db853dd0fc9c271151 (diff)
downloadmlia-62768232c5fe4ed6b87136c336b65e13d030e9d4.tar.gz
MLIA-843 Add unit tests for module mlia.nn.rewrite
Note: The unit tests mostly call the main functions from the respective modules only. Change-Id: Ib2ce5c53d0c3eb222b8b8be42fba33ac8e007574 Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
-rw-r--r--src/mlia/nn/rewrite/__init__.py2
-rw-r--r--src/mlia/nn/rewrite/core/__init__.py2
-rw-r--r--src/mlia/nn/rewrite/core/graph_edit/__init__.py2
-rw-r--r--src/mlia/nn/rewrite/core/graph_edit/diff.py30
-rw-r--r--src/mlia/nn/rewrite/core/graph_edit/record.py33
-rw-r--r--src/mlia/nn/rewrite/core/train.py88
-rw-r--r--src/mlia/nn/rewrite/core/utils/__init__.py2
-rw-r--r--src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py3
-rw-r--r--src/mlia/nn/rewrite/core/utils/parallel.py4
-rw-r--r--src/mlia/nn/rewrite/core/utils/utils.py2
-rw-r--r--tests/conftest.py95
-rw-r--r--tests/test_backend_vela_compat.py3
-rw-r--r--tests/test_nn_rewrite_core_graph_edit_cut.py29
-rw-r--r--tests/test_nn_rewrite_core_graph_edit_join.py50
-rw-r--r--tests/test_nn_rewrite_core_graph_edit_record.py52
-rw-r--r--tests/test_nn_rewrite_core_train.py157
-rw-r--r--tests/test_nn_rewrite_core_utils.py33
-rw-r--r--tests/test_nn_rewrite_core_utils_numpy_tfrecord.py18
-rw-r--r--tests/utils/rewrite.py27
19 files changed, 484 insertions, 148 deletions
diff --git a/src/mlia/nn/rewrite/__init__.py b/src/mlia/nn/rewrite/__init__.py
index 48b1622..8c1f750 100644
--- a/src/mlia/nn/rewrite/__init__.py
+++ b/src/mlia/nn/rewrite/__init__.py
@@ -1,2 +1,2 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file
+# SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/nn/rewrite/core/__init__.py b/src/mlia/nn/rewrite/core/__init__.py
index 48b1622..8c1f750 100644
--- a/src/mlia/nn/rewrite/core/__init__.py
+++ b/src/mlia/nn/rewrite/core/__init__.py
@@ -1,2 +1,2 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file
+# SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/nn/rewrite/core/graph_edit/__init__.py b/src/mlia/nn/rewrite/core/graph_edit/__init__.py
index 48b1622..8c1f750 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/__init__.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/__init__.py
@@ -1,2 +1,2 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file
+# SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/nn/rewrite/core/graph_edit/diff.py b/src/mlia/nn/rewrite/core/graph_edit/diff.py
index b6c9616..0829f0a 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/diff.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/diff.py
@@ -13,36 +13,6 @@ from tensorflow.lite.python import interpreter as interpreter_wrapper
from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFReader, NumpyTFWriter
-def diff(file1, file2):
- results = []
-
- dataset1 = NumpyTFReader(file1)
- dataset2 = NumpyTFReader(file2)
-
- for i, (x1, x2) in enumerate(zip(dataset1, dataset2)):
- assert x1.keys() == x2.keys(), (
- "At input %d the files have different sets of tensors.\n%s: %s\n%s: %s\n"
- % (
- i,
- file1,
- ", ".join(x1.keys()),
- file2,
- ", ".join(x2.keys()),
- )
- )
- results.append({})
- for k in x1.keys():
- v1 = x1[k].numpy().astype(np.double)
- v2 = x2[k].numpy().astype(np.double)
- mae = abs(v1 - v2).mean()
- results[-1][k] = mae
-
- total = sum(sum(x.values()) for x in results)
- count = sum(len(x.values()) for x in results)
- mean = total / count
- return results, mean
-
-
def diff_stats(file1, file2, per_tensor_and_channel=False):
dataset1 = NumpyTFReader(file1)
dataset2 = NumpyTFReader(file2)
diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py
index 03cd3f9..ae13313 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/record.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/record.py
@@ -37,7 +37,6 @@ def record_model(
total = numpytf_count(input_filename)
dataset = NumpyTFReader(input_filename)
- writer = NumpyTFWriter(output_filename)
if batch_size > 1:
# Collapse batch-size 1 items into batch-size n. I regret using batch-size 1 items in tfrecs now.
@@ -47,16 +46,22 @@ def record_model(
dataset = dataset.batch(batch_size, drop_remainder=False)
total = int(math.ceil(total / batch_size))
- for j, named_x in enumerate(
- tqdm(dataset.as_numpy_iterator(), total=total, disable=not show_progress)
- ):
- named_y = model(named_x)
- if batch_size > 1:
- for i in range(batch_size):
- # Expand the batches and recreate each dict as a batch-size 1 item for the tfrec output
- d = {k: v[i : i + 1] for k, v in named_y.items() if i < v.shape[0]}
- if d:
- writer.write(d)
- else:
- writer.write(named_y)
- model.close()
+ with NumpyTFWriter(output_filename) as writer:
+ for _, named_x in enumerate(
+ tqdm(dataset.as_numpy_iterator(), total=total, disable=not show_progress)
+ ):
+ named_y = model(named_x)
+ if batch_size > 1:
+ for i in range(batch_size):
+ # Expand the batches and recreate each dict as a
+ # batch-size 1 item for the tfrec output
+ recreated_dict = {
+ k: v[i : i + 1] # noqa: E203
+ for k, v in named_y.items()
+ if i < v.shape[0]
+ }
+ if recreated_dict:
+ writer.write(recreated_dict)
+ else:
+ writer.write(named_y)
+ model.close()
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index a929b14..096daf4 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -40,85 +40,7 @@ augmentation_presets = {
"mix_gaussian_small": (1.6, 0.3),
}
-
-class SequentialTrainer:
- def __init__(
- self,
- source_model,
- output_model,
- input_tfrec,
- augment="gaussian",
- steps=6000,
- lr=1e-3,
- batch_size=32,
- show_progress=True,
- eval_fn=None,
- num_procs=1,
- num_threads=0,
- ):
- self.source_model = source_model
- self.output_model = output_model
- self.input_tfrec = input_tfrec
- self.default_augment = augment
- self.default_steps = steps
- self.default_lr = lr
- self.default_batch_size = batch_size
- self.show_progress = show_progress
- self.num_procs = num_procs
- self.num_threads = num_threads
- self.first_replace = True
- self.eval_fn = eval_fn
-
- def replace(
- self,
- model_fn,
- input_tensors,
- output_tensors,
- augment=None,
- steps=None,
- lr=None,
- batch_size=None,
- ):
- augment = self.default_augment if augment is None else augment
- steps = self.default_steps if steps is None else steps
- lr = self.default_lr if lr is None else lr
- batch_size = self.default_batch_size if batch_size is None else batch_size
-
- if isinstance(augment, str):
- augment = augmentation_presets[augment]
-
- if self.first_replace:
- source_model = self.source_model
- unmodified_model = None
- else:
- source_model = self.output_model
- unmodified_model = self.source_model
-
- mae, nrmse = train(
- source_model,
- unmodified_model,
- self.output_model,
- self.input_tfrec,
- model_fn,
- input_tensors,
- output_tensors,
- augment,
- steps,
- lr,
- batch_size,
- False,
- self.show_progress,
- None,
- 0,
- self.num_procs,
- self.num_threads,
- )
-
- self.first_replace = False
- if self.eval_fn:
- return self.eval_fn(mae, nrmse, self.output_model)
- else:
- return mae, nrmse
+learning_rate_schedules = {"cosine", "late", "constant"}
def train(
@@ -135,6 +57,7 @@ def train(
batch_size,
verbose,
show_progress,
+ learning_rate_schedule="cosine",
checkpoint_at=None,
checkpoint_decay_steps=0,
num_procs=1,
@@ -183,6 +106,7 @@ def train(
show_progress=show_progress,
num_procs=num_procs,
num_threads=num_threads,
+ schedule=learning_rate_schedule,
)
for i, filename in enumerate(tflite_filenames):
@@ -363,9 +287,9 @@ def train_in_dir(
elif schedule == "constant":
callbacks = []
else:
- assert False, (
- 'LR schedule "%s" not implemented - expected "cosine", "constant" or "late"'
- % schedule
+ assert schedule not in learning_rate_schedules
+ raise ValueError(
+ f'LR schedule "{schedule}" not implemented - expected one of {learning_rate_schedules}.'
)
output_filenames = []
diff --git a/src/mlia/nn/rewrite/core/utils/__init__.py b/src/mlia/nn/rewrite/core/utils/__init__.py
index 48b1622..8c1f750 100644
--- a/src/mlia/nn/rewrite/core/utils/__init__.py
+++ b/src/mlia/nn/rewrite/core/utils/__init__.py
@@ -1,2 +1,2 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file
+# SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
index ac3e875..2141003 100644
--- a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
+++ b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
@@ -63,9 +63,6 @@ class NumpyTFWriter:
def __exit__(self, type, value, traceback):
self.close()
- def __del__(self):
- self.close()
-
def write(self, array_dict):
type_map = {n: str(a.dtype.name) for n, a in array_dict.items()}
self.type_map.update(type_map)
diff --git a/src/mlia/nn/rewrite/core/utils/parallel.py b/src/mlia/nn/rewrite/core/utils/parallel.py
index 5affc03..b1a2914 100644
--- a/src/mlia/nn/rewrite/core/utils/parallel.py
+++ b/src/mlia/nn/rewrite/core/utils/parallel.py
@@ -3,10 +3,10 @@
import math
import os
from collections import defaultdict
+from multiprocessing import cpu_count
from multiprocessing import Pool
import numpy as np
-from psutil import cpu_count
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
@@ -25,7 +25,7 @@ class ParallelTFLiteModel(TFLiteModel):
self.pool = None
self.filename = filename
if not num_procs:
- self.num_procs = cpu_count(logical=False)
+ self.num_procs = cpu_count()
else:
self.num_procs = int(num_procs)
diff --git a/src/mlia/nn/rewrite/core/utils/utils.py b/src/mlia/nn/rewrite/core/utils/utils.py
index ed6c81d..d1ed322 100644
--- a/src/mlia/nn/rewrite/core/utils/utils.py
+++ b/src/mlia/nn/rewrite/core/utils/utils.py
@@ -8,7 +8,7 @@ from tensorflow.lite.python import schema_py_generated as schema_fb
def load(input_tflite_file):
if not os.path.exists(input_tflite_file):
- raise RuntimeError("TFLite file not found at %r\n" % input_tflite_file)
+ raise FileNotFoundError("TFLite file not found at %r\n" % input_tflite_file)
with open(input_tflite_file, "rb") as file_handle:
file_data = bytearray(file_handle.read())
model_obj = schema_fb.Model.GetRootAsModel(file_data, 0)
diff --git a/tests/conftest.py b/tests/conftest.py
index 30889ca..c42b8cb 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -3,13 +3,16 @@
"""Pytest conf module."""
import shutil
from pathlib import Path
+from typing import Callable
from typing import Generator
+import numpy as np
import pytest
import tensorflow as tf
from mlia.backend.vela.compiler import optimize_model
from mlia.core.context import ExecutionContext
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFWriter
from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import save_keras_model
from mlia.nn.tensorflow.utils import save_tflite_model
@@ -68,6 +71,14 @@ def get_test_keras_model() -> tf.keras.Model:
return model
+TEST_MODEL_KERAS_FILE = "test_model.h5"
+TEST_MODEL_TFLITE_FP32_FILE = "test_model_fp32.tflite"
+TEST_MODEL_TFLITE_INT8_FILE = "test_model_int8.tflite"
+TEST_MODEL_TFLITE_VELA_FILE = "test_model_vela.tflite"
+TEST_MODEL_TF_SAVED_MODEL_FILE = "tf_model_test_model"
+TEST_MODEL_INVALID_FILE = "invalid.tflite"
+
+
@pytest.fixture(scope="session", name="test_models_path")
def fixture_test_models_path(
tmp_path_factory: pytest.TempPathFactory,
@@ -75,15 +86,23 @@ def fixture_test_models_path(
"""Provide path to the test models."""
tmp_path = tmp_path_factory.mktemp("models")
+ # Keras Model
keras_model = get_test_keras_model()
- save_keras_model(keras_model, tmp_path / "test_model.h5")
+ save_keras_model(keras_model, tmp_path / TEST_MODEL_KERAS_FILE)
+
+ # Un-quantized TensorFlow Lite model (fp32)
+ save_tflite_model(
+ convert_to_tflite(keras_model, quantized=False),
+ tmp_path / TEST_MODEL_TFLITE_FP32_FILE,
+ )
+ # Quantized TensorFlow Lite model (int8)
tflite_model = convert_to_tflite(keras_model, quantized=True)
- tflite_model_path = tmp_path / "test_model.tflite"
+ tflite_model_path = tmp_path / TEST_MODEL_TFLITE_INT8_FILE
save_tflite_model(tflite_model, tflite_model_path)
- tflite_vela_model = tmp_path / "test_model_vela.tflite"
-
+ # Vela-optimized TensorFlow Lite model (int8)
+ tflite_vela_model = tmp_path / TEST_MODEL_TFLITE_VELA_FILE
target_config = EthosUConfiguration.load_profile("ethos-u55-256")
optimize_model(
tflite_model_path,
@@ -91,9 +110,9 @@ def fixture_test_models_path(
tflite_vela_model,
)
- tf.saved_model.save(keras_model, str(tmp_path / "tf_model_test_model"))
+ tf.saved_model.save(keras_model, str(tmp_path / TEST_MODEL_TF_SAVED_MODEL_FILE))
- invalid_tflite_model = tmp_path / "invalid.tflite"
+ invalid_tflite_model = tmp_path / TEST_MODEL_INVALID_FILE
invalid_tflite_model.touch()
yield tmp_path
@@ -104,28 +123,82 @@ def fixture_test_models_path(
@pytest.fixture(scope="session", name="test_keras_model")
def fixture_test_keras_model(test_models_path: Path) -> Path:
"""Return test Keras model."""
- return test_models_path / "test_model.h5"
+ return test_models_path / TEST_MODEL_KERAS_FILE
@pytest.fixture(scope="session", name="test_tflite_model")
def fixture_test_tflite_model(test_models_path: Path) -> Path:
"""Return test TensorFlow Lite model."""
- return test_models_path / "test_model.tflite"
+ return test_models_path / TEST_MODEL_TFLITE_INT8_FILE
+
+
+@pytest.fixture(scope="session", name="test_tflite_model_fp32")
+def fixture_test_tflite_model_fp32(test_models_path: Path) -> Path:
+ """Return test TensorFlow Lite model."""
+ return test_models_path / TEST_MODEL_TFLITE_FP32_FILE
@pytest.fixture(scope="session", name="test_tflite_vela_model")
def fixture_test_tflite_vela_model(test_models_path: Path) -> Path:
"""Return test Vela-optimized TensorFlow Lite model."""
- return test_models_path / "test_model_vela.tflite"
+ return test_models_path / TEST_MODEL_TFLITE_VELA_FILE
@pytest.fixture(scope="session", name="test_tf_model")
def fixture_test_tf_model(test_models_path: Path) -> Path:
"""Return test TensorFlow Lite model."""
- return test_models_path / "tf_model_test_model"
+ return test_models_path / TEST_MODEL_TF_SAVED_MODEL_FILE
@pytest.fixture(scope="session", name="test_tflite_invalid_model")
def fixture_test_tflite_invalid_model(test_models_path: Path) -> Path:
"""Return test invalid TensorFlow Lite model."""
- return test_models_path / "invalid.tflite"
+ return test_models_path / TEST_MODEL_INVALID_FILE
+
+
+def _write_tfrecord(
+ tfrecord_file: Path,
+ data_generator: Callable,
+ input_name: str = "serving_default_input:0",
+ num_records: int = 3,
+) -> None:
+ """Write data to a tfrecord."""
+ with NumpyTFWriter(str(tfrecord_file)) as writer:
+ for _ in range(num_records):
+ writer.write({input_name: data_generator()})
+
+
+@pytest.fixture(scope="session", name="test_tfrecord")
+def fixture_test_tfrecord(
+ tmp_path_factory: pytest.TempPathFactory,
+) -> Generator[Path, None, None]:
+ """Create a tfrecord with random data matching fixture 'test_tflite_model'."""
+ tmp_path = tmp_path_factory.mktemp("tfrecords")
+ tfrecord_file = tmp_path / "test_int8.tfrecord"
+
+ def random_data() -> np.ndarray:
+ return np.random.randint(low=-127, high=128, size=(1, 28, 28, 1), dtype=np.int8)
+
+ _write_tfrecord(tfrecord_file, random_data)
+
+ yield tfrecord_file
+
+ shutil.rmtree(tmp_path)
+
+
+@pytest.fixture(scope="session", name="test_tfrecord_fp32")
+def fixture_test_tfrecord_fp32(
+ tmp_path_factory: pytest.TempPathFactory,
+) -> Generator[Path, None, None]:
+ """Create tfrecord with random data matching fixture 'test_tflite_model_fp32'."""
+ tmp_path = tmp_path_factory.mktemp("tfrecords")
+ tfrecord_file = tmp_path / "test_fp32.tfrecord"
+
+ def random_data() -> np.ndarray:
+ return np.random.rand(1, 28, 28, 1).astype(np.float32)
+
+ _write_tfrecord(tfrecord_file, random_data)
+
+ yield tfrecord_file
+
+ shutil.rmtree(tmp_path)
diff --git a/tests/test_backend_vela_compat.py b/tests/test_backend_vela_compat.py
index 0c39dd6..bea6a1b 100644
--- a/tests/test_backend_vela_compat.py
+++ b/tests/test_backend_vela_compat.py
@@ -12,13 +12,14 @@ from mlia.backend.vela.compat import Operators
from mlia.backend.vela.compat import supported_operators
from mlia.target.ethos_u.config import EthosUConfiguration
from mlia.utils.filesystem import working_directory
+from tests.conftest import TEST_MODEL_TFLITE_INT8_FILE
@pytest.mark.parametrize(
"model, expected_ops",
[
(
- "test_model.tflite",
+ TEST_MODEL_TFLITE_INT8_FILE,
Operators(
ops=[
Operator(
diff --git a/tests/test_nn_rewrite_core_graph_edit_cut.py b/tests/test_nn_rewrite_core_graph_edit_cut.py
new file mode 100644
index 0000000..914fdfd
--- /dev/null
+++ b/tests/test_nn_rewrite_core_graph_edit_cut.py
@@ -0,0 +1,29 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module mlia.nn.rewrite.graph_edit.cut."""
+from pathlib import Path
+
+import numpy as np
+import tensorflow as tf
+
+from mlia.nn.rewrite.core.graph_edit.cut import cut_model
+
+
+def test_cut_model(test_tflite_model: Path, tmp_path: Path) -> None:
+ """Test the function cut_model()."""
+ output_file = tmp_path / "out.tflite"
+ cut_model(
+ model_file=test_tflite_model,
+ input_names=["serving_default_input:0"],
+ output_names=["sequential/flatten/Reshape"],
+ subgraph_index=0,
+ output_file=output_file,
+ )
+ assert output_file.is_file()
+
+ interpreter = tf.lite.Interpreter(model_path=str(output_file))
+ output_details = interpreter.get_output_details()
+ assert len(output_details) == 1
+ out = output_details[0]
+ assert "Reshape" in out["name"]
+ assert np.prod(out["shape"]) == 1728
diff --git a/tests/test_nn_rewrite_core_graph_edit_join.py b/tests/test_nn_rewrite_core_graph_edit_join.py
new file mode 100644
index 0000000..cbbbeba
--- /dev/null
+++ b/tests/test_nn_rewrite_core_graph_edit_join.py
@@ -0,0 +1,50 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module mlia.nn.rewrite.graph_edit.join."""
+from pathlib import Path
+
+from mlia.nn.rewrite.core.graph_edit.cut import cut_model
+from mlia.nn.rewrite.core.graph_edit.join import join_models
+from mlia.nn.rewrite.core.utils.utils import load
+from tests.utils.rewrite import models_are_equal
+
+
+def test_join_model(test_tflite_model: Path, tmp_path: Path) -> None:
+ """Test the function join_models()."""
+ # Cut model into two parts first
+ first_file = tmp_path / "first_part.tflite"
+ second_file = tmp_path / "second_part.tflite"
+ cut_model(
+ model_file=str(test_tflite_model),
+ input_names=["serving_default_input:0"],
+ output_names=["sequential/flatten/Reshape"],
+ subgraph_index=0,
+ output_file=str(first_file),
+ )
+ cut_model(
+ model_file=str(test_tflite_model),
+ input_names=["sequential/flatten/Reshape"],
+ output_names=["StatefulPartitionedCall:0"],
+ subgraph_index=0,
+ output_file=str(second_file),
+ )
+ assert first_file.is_file()
+ assert second_file.is_file()
+
+ joined_file = tmp_path / "joined.tflite"
+
+ # Now re-join the cut model and check the result is the same as the original
+ for in_src, in_dst in ((first_file, second_file), (second_file, first_file)):
+ join_models(
+ input_src=str(in_src),
+ input_dst=str(in_dst),
+ output_file=str(joined_file),
+ subgraph_src=0,
+ subgraph_dst=0,
+ )
+ assert joined_file.is_file()
+
+ orig_model = load(str(test_tflite_model))
+ joined_model = load(str(joined_file))
+
+ assert models_are_equal(orig_model, joined_model)
diff --git a/tests/test_nn_rewrite_core_graph_edit_record.py b/tests/test_nn_rewrite_core_graph_edit_record.py
new file mode 100644
index 0000000..39aeef5
--- /dev/null
+++ b/tests/test_nn_rewrite_core_graph_edit_record.py
@@ -0,0 +1,52 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module mlia.nn.rewrite.graph_edit.record."""
+from pathlib import Path
+
+import pytest
+import tensorflow as tf
+
+from mlia.nn.rewrite.core.graph_edit.record import record_model
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFReader
+
+
+@pytest.mark.parametrize("batch_size", (None, 1, 2))
+def test_record_model(
+ test_tflite_model: Path,
+ tmp_path: Path,
+ test_tfrecord: Path,
+ batch_size: int,
+) -> None:
+ """Test the function record_model()."""
+ output_file = tmp_path / "out.tfrecord"
+ record_model(
+ input_filename=str(test_tfrecord),
+ model_filename=str(test_tflite_model),
+ output_filename=str(output_file),
+ batch_size=batch_size,
+ )
+ assert output_file.is_file()
+
+ def data_matches_outputs(name: str, tensor: tf.Tensor, model_outputs: list) -> bool:
+ """Check that the name and the tensor match any of the model outputs."""
+ for model_output in model_outputs:
+ if model_output["name"] == name:
+ # If the name is a match, tensor shape and type have to match!
+ tensor_shape = tensor.shape.as_list()
+ tensor_type = tensor.dtype.as_numpy_dtype
+ return all(
+ (
+ tensor_shape == model_output["shape"].tolist(),
+ tensor_type == model_output["dtype"],
+ )
+ )
+ return False
+
+ # Now load model and the data and make sure that the written data matches
+ # any of the model outputs
+ interpreter = tf.lite.Interpreter(str(test_tflite_model))
+ model_outputs = interpreter.get_output_details()
+ dataset = NumpyTFReader(str(output_file))
+ for data in dataset:
+ for name, tensor in data.items():
+ assert data_matches_outputs(name, tensor, model_outputs)
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
new file mode 100644
index 0000000..d2bc1e0
--- /dev/null
+++ b/tests/test_nn_rewrite_core_train.py
@@ -0,0 +1,157 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module mlia.nn.rewrite.train."""
+# pylint: disable=too-many-arguments
+from __future__ import annotations
+
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import numpy as np
+import pytest
+import tensorflow as tf
+
+from mlia.nn.rewrite.core.train import augmentation_presets
+from mlia.nn.rewrite.core.train import mixup
+from mlia.nn.rewrite.core.train import train
+
+
+def replace_fully_connected_with_conv(input_shape, output_shape) -> tf.keras.Model:
+ """Get a replacement model for the fully connected layer."""
+ for name, shape in {
+ "Input": input_shape,
+ "Output": output_shape,
+ }.items():
+ if len(shape) != 1:
+ raise RuntimeError(f"{name}: shape (N,) expected, but it is {input_shape}.")
+
+ model = tf.keras.Sequential(name="RewriteModel")
+ model.add(tf.keras.Input(input_shape))
+ model.add(tf.keras.layers.Reshape((1, 1, input_shape[0])))
+ model.add(tf.keras.layers.Conv2D(filters=output_shape[0], kernel_size=(1, 1)))
+ model.add(tf.keras.layers.Reshape(output_shape))
+
+ return model
+
+
+def check_train(
+ tflite_model: Path,
+ tfrecord: Path,
+ batch_size: int = 1,
+ verbose: bool = False,
+ show_progress: bool = False,
+ augmentation_preset: tuple[float | None, float | None] = augmentation_presets[
+ "none"
+ ],
+ lr_schedule: str = "cosine",
+ use_unmodified_model: bool = False,
+ num_procs: int = 1,
+) -> None:
+ """Test the train() function."""
+ with TemporaryDirectory() as tmp_dir:
+ output_file = Path(tmp_dir, "out.tfrecord")
+ result = train(
+ source_model=str(tflite_model),
+ unmodified_model=str(tflite_model) if use_unmodified_model else None,
+ output_model=str(output_file),
+ input_tfrec=str(tfrecord),
+ replace_fn=replace_fully_connected_with_conv,
+ input_tensors=["sequential/flatten/Reshape"],
+ output_tensors=["StatefulPartitionedCall:0"],
+ augment=augmentation_preset,
+ steps=32,
+ lr=1e-3,
+ batch_size=batch_size,
+ verbose=verbose,
+ show_progress=show_progress,
+ learning_rate_schedule=lr_schedule,
+ num_procs=num_procs,
+ )
+ assert len(result) == 2
+ assert all(res >= 0.0 for res in result), f"Results out of bound: {result}"
+ assert output_file.is_file()
+
+
+@pytest.mark.parametrize(
+ (
+ "batch_size",
+ "verbose",
+ "show_progress",
+ "augmentation_preset",
+ "lr_schedule",
+ "use_unmodified_model",
+ "num_procs",
+ ),
+ (
+ (1, False, False, augmentation_presets["none"], "cosine", False, 2),
+ (32, True, True, augmentation_presets["gaussian"], "late", True, 1),
+ (2, False, False, augmentation_presets["mixup"], "constant", True, 0),
+ (
+ 1,
+ False,
+ False,
+ augmentation_presets["mix_gaussian_large"],
+ "cosine",
+ False,
+ 2,
+ ),
+ ),
+)
+def test_train(
+ test_tflite_model_fp32: Path,
+ test_tfrecord_fp32: Path,
+ batch_size: int,
+ verbose: bool,
+ show_progress: bool,
+ augmentation_preset: tuple[float | None, float | None],
+ lr_schedule: str,
+ use_unmodified_model: bool,
+ num_procs: int,
+) -> None:
+ """Test the train() function with valid parameters."""
+ check_train(
+ tflite_model=test_tflite_model_fp32,
+ tfrecord=test_tfrecord_fp32,
+ batch_size=batch_size,
+ verbose=verbose,
+ show_progress=show_progress,
+ augmentation_preset=augmentation_preset,
+ lr_schedule=lr_schedule,
+ use_unmodified_model=use_unmodified_model,
+ num_procs=num_procs,
+ )
+
+
+def test_train_invalid_schedule(
+ test_tflite_model_fp32: Path,
+ test_tfrecord_fp32: Path,
+) -> None:
+ """Test the train() function with an invalid schedule."""
+ with pytest.raises(ValueError):
+ check_train(
+ tflite_model=test_tflite_model_fp32,
+ tfrecord=test_tfrecord_fp32,
+ lr_schedule="unknown_schedule",
+ )
+
+
+def test_train_invalid_augmentation(
+ test_tflite_model_fp32: Path,
+ test_tfrecord_fp32: Path,
+) -> None:
+ """Test the train() function with an invalid augmentation."""
+ with pytest.raises(ValueError):
+ check_train(
+ tflite_model=test_tflite_model_fp32,
+ tfrecord=test_tfrecord_fp32,
+ augmentation_preset=(1.0, 2.0, 3.0), # type: ignore
+ )
+
+
+def test_mixup() -> None:
+ """Test the mixup() function."""
+ src = np.array((1, 2, 3))
+ dst = mixup(rng=np.random.default_rng(123), batch=src)
+ assert src.shape == dst.shape
+ assert np.all(dst >= 0.0)
+ assert np.all(dst <= 3.0)
diff --git a/tests/test_nn_rewrite_core_utils.py b/tests/test_nn_rewrite_core_utils.py
new file mode 100644
index 0000000..d806a7b
--- /dev/null
+++ b/tests/test_nn_rewrite_core_utils.py
@@ -0,0 +1,33 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module mlia.nn.rewrite.utils."""
+from pathlib import Path
+
+import pytest
+import tensorflow as tf
+from tensorflow.lite.python.schema_py_generated import ModelT
+
+from mlia.nn.rewrite.core.utils.utils import load
+from mlia.nn.rewrite.core.utils.utils import save
+from tests.utils.rewrite import models_are_equal
+
+
+def test_load_save(test_tflite_model: Path, tmp_path: Path) -> None:
+ """Test the load/save functions for TensorFlow Lite models."""
+ with pytest.raises(FileNotFoundError):
+ load("THIS_IS_NOT_A_REAL_FILE")
+
+ model = load(test_tflite_model)
+ assert isinstance(model, ModelT)
+ assert model.subgraphs
+
+ output_file = tmp_path / "test.tflite"
+ assert not output_file.is_file()
+ save(model, output_file)
+ assert output_file.is_file()
+
+ model_copy = load(str(output_file))
+ assert models_are_equal(model, model_copy)
+
+ # Double check that the TensorFlow Lite Interpreter can still load the file.
+ tf.lite.Interpreter(model_path=str(output_file))
diff --git a/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py b/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py
new file mode 100644
index 0000000..7fc8048
--- /dev/null
+++ b/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py
@@ -0,0 +1,18 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module mlia.nn.rewrite.core.utils.numpy_tfrecord."""
+from __future__ import annotations
+
+from pathlib import Path
+
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import sample_tfrec
+
+
+def test_sample_tfrec(test_tfrecord: Path, tmp_path: Path) -> None:
+ """Test function sample_tfrec()."""
+ output_file = tmp_path / "output.tfrecord"
+ # Sample 1 sample from test_tfrecord
+ sample_tfrec(input_file=str(test_tfrecord), k=1, output_file=str(output_file))
+ assert output_file.is_file()
+ assert numpytf_count(str(output_file)) == 1
diff --git a/tests/utils/rewrite.py b/tests/utils/rewrite.py
new file mode 100644
index 0000000..4264b4b
--- /dev/null
+++ b/tests/utils/rewrite.py
@@ -0,0 +1,27 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Common test utils for the rewrite tests."""
+from __future__ import annotations
+
+from tensorflow.lite.python.schema_py_generated import ModelT
+
+
+def models_are_equal(model1: ModelT, model2: ModelT) -> bool:
+ """Check that the two models are equal."""
+ if len(model1.subgraphs) != len(model2.subgraphs):
+ return False
+
+ for graph1, graph2 in zip(model1.subgraphs, model2.subgraphs):
+ if graph1.name != graph2.name or len(graph1.tensors) != len(graph2.tensors):
+ return False
+ for tensor1 in graph1.tensors:
+ for tensor2 in graph2.tensors:
+ if tensor1.name == tensor2.name:
+ if (
+ tensor1.shape == tensor2.shape
+ ).all() or tensor1.type == tensor2.type:
+ break
+ else:
+ return False # Tensor from graph1 not found in other graph.")
+
+ return True