aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_train.py
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2023-07-19 16:35:57 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-10-11 16:06:17 +0100
commit3cd84481fa25e64c29e57396d4bf32d7a3ca490a (patch)
treead81fb520a965bd3a3c7c983833b7cd48f9b8dea /tests/test_nn_rewrite_core_train.py
parentf3e6597dd50ec70f043d692b773f2d9fd31519ae (diff)
downloadmlia-3cd84481fa25e64c29e57396d4bf32d7a3ca490a.tar.gz
Bug-fixes and re-factoring for the rewrite module
- Fix input shape of rewrite replacement: During and after training of the replacement model for a rewrite the Keras model is converted and saved in TensorFlow Lite format. If the input shape does not match the teacher model exactly, e.g. if the batch size is undefined, the TFLiteConverter adds extra operators during conversion. - Fix rewritten model output - Save the model output with the rewritten operator in the output dir - Log MAE and NRMSE of the rewrite - Remove 'verbose' flag from rewrite module and rely on the logging mechanism to control verbose output. - Re-factor utility classes for rewrites - Merge the two TFLiteModel classes - Move functionality to load/save TensorFlow Lite flatbuffers to nn/tensorflow/tflite_graph - Fix issue with unknown shape in datasets After upgrading to TensorFlow 2.12 the unknown shape of the TFRecordDataset is causing problems when training the replacement models for rewrites. By explicitly setting the right shape of the tensors we can work around the issue. - Adapt default parameters for rewrites. The training steps especially had to be increased significantly to be effective. Resolves: MLIA-895, MLIA-907, MLIA-946, MLIA-979 Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com> Change-Id: I887ad165aed0f2c6e5a0041f64cec5e6c5ab5c5c
Diffstat (limited to 'tests/test_nn_rewrite_core_train.py')
-rw-r--r--tests/test_nn_rewrite_core_train.py76
1 files changed, 42 insertions, 34 deletions
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index 3c2ef3e..4493671 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -4,6 +4,7 @@
# pylint: disable=too-many-arguments
from __future__ import annotations
+from contextlib import ExitStack as does_not_raise
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any
@@ -12,10 +13,13 @@ import numpy as np
import pytest
import tensorflow as tf
-from mlia.nn.rewrite.core.train import augmentation_presets
+from mlia.nn.rewrite.core.train import augment_fn_twins
+from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS
from mlia.nn.rewrite.core.train import LearningRateSchedule
from mlia.nn.rewrite.core.train import mixup
from mlia.nn.rewrite.core.train import train
+from mlia.nn.rewrite.core.train import TrainingParameters
+from tests.utils.rewrite import TestTrainingParameters
def replace_fully_connected_with_conv(
@@ -41,15 +45,8 @@ def replace_fully_connected_with_conv(
def check_train(
tflite_model: Path,
tfrecord: Path,
- batch_size: int = 1,
- verbose: bool = False,
- show_progress: bool = False,
- augmentation_preset: tuple[float | None, float | None] = augmentation_presets[
- "none"
- ],
- lr_schedule: LearningRateSchedule = "cosine",
+ train_params: TrainingParameters = TestTrainingParameters(),
use_unmodified_model: bool = False,
- num_procs: int = 1,
) -> None:
"""Test the train() function."""
with TemporaryDirectory() as tmp_dir:
@@ -62,14 +59,7 @@ def check_train(
replace_fn=replace_fully_connected_with_conv,
input_tensors=["sequential/flatten/Reshape"],
output_tensors=["StatefulPartitionedCall:0"],
- augment=augmentation_preset,
- steps=32,
- learning_rate=1e-3,
- batch_size=batch_size,
- verbose=verbose,
- show_progress=show_progress,
- learning_rate_schedule=lr_schedule,
- num_procs=num_procs,
+ train_params=train_params,
)
assert len(result) == 2
assert all(res >= 0.0 for res in result), f"Results out of bound: {result}"
@@ -79,7 +69,6 @@ def check_train(
@pytest.mark.parametrize(
(
"batch_size",
- "verbose",
"show_progress",
"augmentation_preset",
"lr_schedule",
@@ -87,14 +76,13 @@ def check_train(
"num_procs",
),
(
- (1, False, False, augmentation_presets["none"], "cosine", False, 2),
- (32, True, True, augmentation_presets["gaussian"], "late", True, 1),
- (2, False, False, augmentation_presets["mixup"], "constant", True, 0),
+ (1, False, AUGMENTATION_PRESETS["none"], "cosine", False, 2),
+ (32, True, AUGMENTATION_PRESETS["gaussian"], "late", True, 1),
+ (2, False, AUGMENTATION_PRESETS["mixup"], "constant", True, 0),
(
1,
False,
- False,
- augmentation_presets["mix_gaussian_large"],
+ AUGMENTATION_PRESETS["mix_gaussian_large"],
"cosine",
False,
2,
@@ -105,7 +93,6 @@ def test_train(
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,
batch_size: int,
- verbose: bool,
show_progress: bool,
augmentation_preset: tuple[float | None, float | None],
lr_schedule: LearningRateSchedule,
@@ -116,13 +103,14 @@ def test_train(
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- batch_size=batch_size,
- verbose=verbose,
- show_progress=show_progress,
- augmentation_preset=augmentation_preset,
- lr_schedule=lr_schedule,
+ train_params=TestTrainingParameters(
+ batch_size=batch_size,
+ show_progress=show_progress,
+ augmentations=augmentation_preset,
+ learning_rate_schedule=lr_schedule,
+ num_procs=num_procs,
+ ),
use_unmodified_model=use_unmodified_model,
- num_procs=num_procs,
)
@@ -131,11 +119,13 @@ def test_train_invalid_schedule(
test_tfrecord_fp32: Path,
) -> None:
"""Test the train() function with an invalid schedule."""
- with pytest.raises(ValueError):
+ with pytest.raises(AssertionError):
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- lr_schedule="unknown_schedule", # type: ignore
+ train_params=TestTrainingParameters(
+ learning_rate_schedule="unknown_schedule",
+ ),
)
@@ -144,11 +134,13 @@ def test_train_invalid_augmentation(
test_tfrecord_fp32: Path,
) -> None:
"""Test the train() function with an invalid augmentation."""
- with pytest.raises(ValueError):
+ with pytest.raises(AssertionError):
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- augmentation_preset=(1.0, 2.0, 3.0), # type: ignore
+ train_params=TestTrainingParameters(
+ augmentations=(1.0, 2.0, 3.0),
+ ),
)
@@ -159,3 +151,19 @@ def test_mixup() -> None:
assert src.shape == dst.shape
assert np.all(dst >= 0.0)
assert np.all(dst <= 3.0)
+
+
+@pytest.mark.parametrize(
+ "augmentations, expected_error",
+ [
+ (AUGMENTATION_PRESETS["none"], does_not_raise()),
+ (AUGMENTATION_PRESETS["mix_gaussian_large"], does_not_raise()),
+ ((None,) * 3, pytest.raises(AssertionError)),
+ ],
+)
+def test_augment_fn_twins(augmentations: tuple, expected_error: Any) -> None:
+ """Test function augment_fn()."""
+ dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2, 3], "b": [4, 5, 6]})
+ with expected_error:
+ fn_twins = augment_fn_twins(dataset, augmentations) # type: ignore
+ assert len(fn_twins) == 2