aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_train.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_rewrite_core_train.py')
-rw-r--r--tests/test_nn_rewrite_core_train.py67
1 files changed, 64 insertions, 3 deletions
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index b001a09..ef52320 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -1,6 +1,6 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
-"""Tests for module mlia.nn.rewrite.train."""
+"""Tests for module mlia.nn.rewrite.core.train."""
# pylint: disable=too-many-arguments
from __future__ import annotations
@@ -47,10 +47,11 @@ def check_train(
tfrecord: Path,
train_params: TrainingParameters = TestTrainingParameters(),
use_unmodified_model: bool = False,
+ quantized: bool = False,
) -> None:
"""Test the train() function."""
with TemporaryDirectory() as tmp_dir:
- output_file = Path(tmp_dir, "out.tfrecord")
+ output_file = Path(tmp_dir, "out.tflite")
result = train(
source_model=str(tflite_model),
unmodified_model=str(tflite_model) if use_unmodified_model else None,
@@ -65,6 +66,17 @@ def check_train(
assert all(res >= 0.0 for res in result[0]), f"Results out of bound: {result}"
assert output_file.is_file()
+ if quantized:
+ interpreter = tf.lite.Interpreter(model_path=str(output_file))
+ interpreter.allocate_tensors()
+ # Check that the quantization parameters are non-zero
+ assert all(interpreter.get_output_details()[0]["quantization"])
+ assert all(interpreter.get_input_details()[0]["quantization"])
+ dtypes = []
+ for tensor_detail in interpreter.get_tensor_details():
+ dtypes.append(tensor_detail["dtype"])
+ assert all(np.issubdtype(dtype, np.integer) for dtype in dtypes)
+
@pytest.mark.parametrize(
(
@@ -89,7 +101,7 @@ def check_train(
),
),
)
-def test_train(
+def test_train_fp32(
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,
batch_size: int,
@@ -114,6 +126,55 @@ def test_train(
)
+@pytest.mark.parametrize(
+ (
+ "batch_size",
+ "show_progress",
+ "augmentation_preset",
+ "lr_schedule",
+ "use_unmodified_model",
+ "num_procs",
+ ),
+ (
+ (1, False, AUGMENTATION_PRESETS["none"], "cosine", False, 2),
+ (32, True, AUGMENTATION_PRESETS["gaussian"], "late", True, 1),
+ (2, False, AUGMENTATION_PRESETS["mixup"], "constant", True, 0),
+ (
+ 1,
+ False,
+ AUGMENTATION_PRESETS["mix_gaussian_large"],
+ "cosine",
+ False,
+ 2,
+ ),
+ ),
+)
+def test_train_int8(
+ test_tflite_model: Path,
+ test_tfrecord: Path,
+ batch_size: int,
+ show_progress: bool,
+ augmentation_preset: tuple[float | None, float | None],
+ lr_schedule: LearningRateSchedule,
+ use_unmodified_model: bool,
+ num_procs: int,
+) -> None:
+ """Test the train() function with valid parameters."""
+ check_train(
+ tflite_model=test_tflite_model,
+ tfrecord=test_tfrecord,
+ train_params=TestTrainingParameters(
+ batch_size=batch_size,
+ show_progress=show_progress,
+ augmentations=augmentation_preset,
+ learning_rate_schedule=lr_schedule,
+ num_procs=num_procs,
+ ),
+ use_unmodified_model=use_unmodified_model,
+ quantized=True,
+ )
+
+
def test_train_invalid_schedule(
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,