Diffstat (limited to 'tests/test_nn_rewrite_core_train.py')
-rw-r--r--   tests/test_nn_rewrite_core_train.py | 67
1 file changed, 64 insertions(+), 3 deletions(-)
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index b001a09..ef52320 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -1,6 +1,6 @@
 # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
 # SPDX-License-Identifier: Apache-2.0
-"""Tests for module mlia.nn.rewrite.train."""
+"""Tests for module mlia.nn.rewrite.core.train."""
 # pylint: disable=too-many-arguments
 from __future__ import annotations
 
@@ -47,10 +47,11 @@ def check_train(
     tfrecord: Path,
     train_params: TrainingParameters = TestTrainingParameters(),
     use_unmodified_model: bool = False,
+    quantized: bool = False,
 ) -> None:
     """Test the train() function."""
     with TemporaryDirectory() as tmp_dir:
-        output_file = Path(tmp_dir, "out.tfrecord")
+        output_file = Path(tmp_dir, "out.tflite")
         result = train(
             source_model=str(tflite_model),
             unmodified_model=str(tflite_model) if use_unmodified_model else None,
@@ -65,6 +66,17 @@ def check_train(
         assert all(res >= 0.0 for res in result[0]), f"Results out of bound: {result}"
         assert output_file.is_file()
 
+        if quantized:
+            interpreter = tf.lite.Interpreter(model_path=str(output_file))
+            interpreter.allocate_tensors()
+            # Check that the quantization parameters are non-zero
+            assert all(interpreter.get_output_details()[0]["quantization"])
+            assert all(interpreter.get_input_details()[0]["quantization"])
+            dtypes = []
+            for tensor_detail in interpreter.get_tensor_details():
+                dtypes.append(tensor_detail["dtype"])
+            assert all(np.issubdtype(dtype, np.integer) for dtype in dtypes)
+
 
 @pytest.mark.parametrize(
     (
@@ -89,7 +101,7 @@ def check_train(
         ),
     ),
 )
-def test_train(
+def test_train_fp32(
     test_tflite_model_fp32: Path,
     test_tfrecord_fp32: Path,
     batch_size: int,
@@ -114,6 +126,55 @@ def test_train(
     )
 
 
+@pytest.mark.parametrize(
+    (
+        "batch_size",
+        "show_progress",
+        "augmentation_preset",
+        "lr_schedule",
+        "use_unmodified_model",
+        "num_procs",
+    ),
+    (
+        (1, False, AUGMENTATION_PRESETS["none"], "cosine", False, 2),
+        (32, True, AUGMENTATION_PRESETS["gaussian"], "late", True, 1),
+        (2, False, AUGMENTATION_PRESETS["mixup"], "constant", True, 0),
+        (
+            1,
+            False,
+            AUGMENTATION_PRESETS["mix_gaussian_large"],
+            "cosine",
+            False,
+            2,
+        ),
+    ),
+)
+def test_train_int8(
+    test_tflite_model: Path,
+    test_tfrecord: Path,
+    batch_size: int,
+    show_progress: bool,
+    augmentation_preset: tuple[float | None, float | None],
+    lr_schedule: LearningRateSchedule,
+    use_unmodified_model: bool,
+    num_procs: int,
+) -> None:
+    """Test the train() function with valid parameters."""
+    check_train(
+        tflite_model=test_tflite_model,
+        tfrecord=test_tfrecord,
+        train_params=TestTrainingParameters(
+            batch_size=batch_size,
+            show_progress=show_progress,
+            augmentations=augmentation_preset,
+            learning_rate_schedule=lr_schedule,
+            num_procs=num_procs,
+        ),
+        use_unmodified_model=use_unmodified_model,
+        quantized=True,
+    )
+
+
 def test_train_invalid_schedule(
     test_tflite_model_fp32: Path,
     test_tfrecord_fp32: Path,
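
Note: the quantization check added to check_train() above can be read as a standalone pattern: load the trained TensorFlow Lite file into an interpreter, confirm the input and output tensors carry (scale, zero_point) quantization parameters, and confirm that every tensor has an integer dtype. Below is a minimal sketch of that pattern, assuming an already-quantized model at the hypothetical path "model.tflite"; only the public tf.lite.Interpreter API already used in the diff is relied on.

    import numpy as np
    import tensorflow as tf

    # Hypothetical path; in the test above this is the file written by train().
    interpreter = tf.lite.Interpreter(model_path="model.tflite")
    interpreter.allocate_tensors()

    # "quantization" holds a tensor's (scale, zero_point) pair; the check
    # treats an all-non-zero pair on the input and output tensors as evidence
    # that the quantization parameters survived training.
    assert all(interpreter.get_input_details()[0]["quantization"])
    assert all(interpreter.get_output_details()[0]["quantization"])

    # In a fully quantized model every tensor dtype is an integer type,
    # e.g. np.int8 for activations/weights or np.int32 for biases.
    dtypes = [detail["dtype"] for detail in interpreter.get_tensor_details()]
    assert all(np.issubdtype(dtype, np.integer) for dtype in dtypes)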