aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_train.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_rewrite_core_train.py')
-rw-r--r--tests/test_nn_rewrite_core_train.py157
1 file changed, 157 insertions, 0 deletions
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
new file mode 100644
index 0000000..d2bc1e0
--- /dev/null
+++ b/tests/test_nn_rewrite_core_train.py
@@ -0,0 +1,157 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module mlia.nn.rewrite.train."""
+# pylint: disable=too-many-arguments
+from __future__ import annotations
+
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import numpy as np
+import pytest
+import tensorflow as tf
+
+from mlia.nn.rewrite.core.train import augmentation_presets
+from mlia.nn.rewrite.core.train import mixup
+from mlia.nn.rewrite.core.train import train
+
+
def replace_fully_connected_with_conv(input_shape, output_shape) -> tf.keras.Model:
    """Get a replacement model for the fully connected layer.

    Builds a tiny Sequential model that reshapes the flat input to a
    1x1 spatial map, applies a 1x1 Conv2D producing ``output_shape[0]``
    channels, and reshapes back — functionally equivalent to a fully
    connected layer.

    :param input_shape: expected input shape, must be rank 1, i.e. (N,)
    :param output_shape: expected output shape, must be rank 1, i.e. (N,)
    :raises RuntimeError: if either shape is not one-dimensional
    """
    for name, shape in {
        "Input": input_shape,
        "Output": output_shape,
    }.items():
        if len(shape) != 1:
            # Bug fix: report the offending shape itself, not always the
            # input shape (previously the message interpolated input_shape
            # even when the output shape was invalid).
            raise RuntimeError(f"{name}: shape (N,) expected, but it is {shape}.")

    model = tf.keras.Sequential(name="RewriteModel")
    model.add(tf.keras.Input(input_shape))
    model.add(tf.keras.layers.Reshape((1, 1, input_shape[0])))
    model.add(tf.keras.layers.Conv2D(filters=output_shape[0], kernel_size=(1, 1)))
    model.add(tf.keras.layers.Reshape(output_shape))

    return model
+
+
def check_train(
    tflite_model: Path,
    tfrecord: Path,
    batch_size: int = 1,
    verbose: bool = False,
    show_progress: bool = False,
    augmentation_preset: tuple[float | None, float | None] = augmentation_presets[
        "none"
    ],
    lr_schedule: str = "cosine",
    use_unmodified_model: bool = False,
    num_procs: int = 1,
) -> None:
    """Run train() with the given settings and verify its results.

    Trains a replacement for the fully connected layer on the sample
    TFLite model/TFRecord, then checks that two non-negative metrics are
    returned and that the rewritten model file was written.
    """
    with TemporaryDirectory() as workdir:
        rewritten_model = Path(workdir, "out.tfrecord")
        metrics = train(
            source_model=str(tflite_model),
            unmodified_model=str(tflite_model) if use_unmodified_model else None,
            output_model=str(rewritten_model),
            input_tfrec=str(tfrecord),
            replace_fn=replace_fully_connected_with_conv,
            input_tensors=["sequential/flatten/Reshape"],
            output_tensors=["StatefulPartitionedCall:0"],
            augment=augmentation_preset,
            steps=32,
            lr=1e-3,
            batch_size=batch_size,
            verbose=verbose,
            show_progress=show_progress,
            learning_rate_schedule=lr_schedule,
            num_procs=num_procs,
        )
        assert len(metrics) == 2
        assert all(value >= 0.0 for value in metrics), f"Results out of bound: {metrics}"
        assert rewritten_model.is_file()
+
+
@pytest.mark.parametrize(
    (
        "batch_size",
        "verbose",
        "show_progress",
        "augmentation_preset",
        "lr_schedule",
        "use_unmodified_model",
        "num_procs",
    ),
    (
        (1, False, False, augmentation_presets["none"], "cosine", False, 2),
        (32, True, True, augmentation_presets["gaussian"], "late", True, 1),
        (2, False, False, augmentation_presets["mixup"], "constant", True, 0),
        (
            1,
            False,
            False,
            augmentation_presets["mix_gaussian_large"],
            "cosine",
            False,
            2,
        ),
    ),
)
def test_train(
    test_tflite_model_fp32: Path,
    test_tfrecord_fp32: Path,
    batch_size: int,
    verbose: bool,
    show_progress: bool,
    augmentation_preset: tuple[float | None, float | None],
    lr_schedule: str,
    use_unmodified_model: bool,
    num_procs: int,
) -> None:
    """Test the train() function with valid parameters."""
    # Collect the parametrized settings and forward them to the shared
    # check_train() helper, which performs the actual assertions.
    settings = {
        "batch_size": batch_size,
        "verbose": verbose,
        "show_progress": show_progress,
        "augmentation_preset": augmentation_preset,
        "lr_schedule": lr_schedule,
        "use_unmodified_model": use_unmodified_model,
        "num_procs": num_procs,
    }
    check_train(
        tflite_model=test_tflite_model_fp32,
        tfrecord=test_tfrecord_fp32,
        **settings,
    )
+
+
def test_train_invalid_schedule(
    test_tflite_model_fp32: Path,
    test_tfrecord_fp32: Path,
) -> None:
    """Verify that train() rejects an unknown learning rate schedule."""
    # An unrecognized schedule name must surface as a ValueError.
    with pytest.raises(ValueError):
        check_train(
            lr_schedule="unknown_schedule",
            tflite_model=test_tflite_model_fp32,
            tfrecord=test_tfrecord_fp32,
        )
+
+
def test_train_invalid_augmentation(
    test_tflite_model_fp32: Path,
    test_tfrecord_fp32: Path,
) -> None:
    """Verify that train() rejects a malformed augmentation preset."""
    # Presets are (float|None, float|None) pairs; a 3-tuple is invalid.
    with pytest.raises(ValueError):
        check_train(
            augmentation_preset=(1.0, 2.0, 3.0),  # type: ignore
            tflite_model=test_tflite_model_fp32,
            tfrecord=test_tfrecord_fp32,
        )
+
+
def test_mixup() -> None:
    """Check that mixup() keeps the batch shape and value range."""
    batch = np.array((1, 2, 3))
    mixed = mixup(rng=np.random.default_rng(123), batch=batch)
    # Shape must be preserved and values must stay within [0, max(batch)].
    assert batch.shape == mixed.shape
    assert np.all(mixed >= 0.0)
    assert np.all(mixed <= 3.0)