From 62768232c5fe4ed6b87136c336b65e13d030e9d4 Mon Sep 17 00:00:00 2001 From: Benjamin Klimczak Date: Mon, 20 Mar 2023 18:07:54 +0000 Subject: MLIA-843 Add unit tests for module mlia.nn.rewrite Note: The unit tests mostly call the main functions from the respective modules only. Change-Id: Ib2ce5c53d0c3eb222b8b8be42fba33ac8e007574 Signed-off-by: Benjamin Klimczak --- tests/test_nn_rewrite_core_train.py | 157 ++++++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 tests/test_nn_rewrite_core_train.py (limited to 'tests/test_nn_rewrite_core_train.py') diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py new file mode 100644 index 0000000..d2bc1e0 --- /dev/null +++ b/tests/test_nn_rewrite_core_train.py @@ -0,0 +1,157 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module mlia.nn.rewrite.train.""" +# pylint: disable=too-many-arguments +from __future__ import annotations + +from pathlib import Path +from tempfile import TemporaryDirectory + +import numpy as np +import pytest +import tensorflow as tf + +from mlia.nn.rewrite.core.train import augmentation_presets +from mlia.nn.rewrite.core.train import mixup +from mlia.nn.rewrite.core.train import train + + +def replace_fully_connected_with_conv(input_shape, output_shape) -> tf.keras.Model: + """Get a replacement model for the fully connected layer.""" + for name, shape in { + "Input": input_shape, + "Output": output_shape, + }.items(): + if len(shape) != 1: + raise RuntimeError(f"{name}: shape (N,) expected, but it is {input_shape}.") + + model = tf.keras.Sequential(name="RewriteModel") + model.add(tf.keras.Input(input_shape)) + model.add(tf.keras.layers.Reshape((1, 1, input_shape[0]))) + model.add(tf.keras.layers.Conv2D(filters=output_shape[0], kernel_size=(1, 1))) + model.add(tf.keras.layers.Reshape(output_shape)) + + return model + + +def check_train( + tflite_model: Path, + tfrecord: Path, + batch_size: int = 1, + verbose: bool = False, + show_progress: bool = False, + augmentation_preset: tuple[float | None, float | None] = augmentation_presets[ + "none" + ], + lr_schedule: str = "cosine", + use_unmodified_model: bool = False, + num_procs: int = 1, +) -> None: + """Test the train() function.""" + with TemporaryDirectory() as tmp_dir: + output_file = Path(tmp_dir, "out.tfrecord") + result = train( + source_model=str(tflite_model), + unmodified_model=str(tflite_model) if use_unmodified_model else None, + output_model=str(output_file), + input_tfrec=str(tfrecord), + replace_fn=replace_fully_connected_with_conv, + input_tensors=["sequential/flatten/Reshape"], + output_tensors=["StatefulPartitionedCall:0"], + augment=augmentation_preset, + steps=32, + lr=1e-3, + batch_size=batch_size, + verbose=verbose, + show_progress=show_progress, + learning_rate_schedule=lr_schedule, + num_procs=num_procs, + ) + assert len(result) == 2 + assert all(res >= 0.0 for res in result), f"Results out of bound: {result}" + assert output_file.is_file() + + +@pytest.mark.parametrize( + ( + "batch_size", + "verbose", + "show_progress", + "augmentation_preset", + "lr_schedule", + "use_unmodified_model", + "num_procs", + ), + ( + (1, False, False, augmentation_presets["none"], "cosine", False, 2), + (32, True, True, augmentation_presets["gaussian"], "late", True, 1), + (2, False, False, augmentation_presets["mixup"], "constant", True, 0), + ( + 1, + False, + False, + augmentation_presets["mix_gaussian_large"], + "cosine", + False, + 2, + ), + ), +) +def test_train( + test_tflite_model_fp32: Path, + test_tfrecord_fp32: Path, + batch_size: int, + verbose: bool, + show_progress: bool, + augmentation_preset: tuple[float | None, float | None], + lr_schedule: str, + use_unmodified_model: bool, + num_procs: int, +) -> None: + """Test the train() function with valid parameters.""" + check_train( + tflite_model=test_tflite_model_fp32, + tfrecord=test_tfrecord_fp32, + batch_size=batch_size, + verbose=verbose, + show_progress=show_progress, + augmentation_preset=augmentation_preset, + lr_schedule=lr_schedule, + use_unmodified_model=use_unmodified_model, + num_procs=num_procs, + ) + + +def test_train_invalid_schedule( + test_tflite_model_fp32: Path, + test_tfrecord_fp32: Path, +) -> None: + """Test the train() function with an invalid schedule.""" + with pytest.raises(ValueError): + check_train( + tflite_model=test_tflite_model_fp32, + tfrecord=test_tfrecord_fp32, + lr_schedule="unknown_schedule", + ) + + +def test_train_invalid_augmentation( + test_tflite_model_fp32: Path, + test_tfrecord_fp32: Path, +) -> None: + """Test the train() function with an invalid augmentation.""" + with pytest.raises(ValueError): + check_train( + tflite_model=test_tflite_model_fp32, + tfrecord=test_tfrecord_fp32, + augmentation_preset=(1.0, 2.0, 3.0), # type: ignore + ) + + +def test_mixup() -> None: + """Test the mixup() function.""" + src = np.array((1, 2, 3)) + dst = mixup(rng=np.random.default_rng(123), batch=src) + assert src.shape == dst.shape + assert np.all(dst >= 0.0) + assert np.all(dst <= 3.0) -- cgit v1.2.1