aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_train.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_rewrite_core_train.py')
-rw-r--r--tests/test_nn_rewrite_core_train.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index ef52320..7fb6f85 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -19,7 +19,7 @@ from mlia.nn.rewrite.core.train import LearningRateSchedule
from mlia.nn.rewrite.core.train import mixup
from mlia.nn.rewrite.core.train import train
from mlia.nn.rewrite.core.train import TrainingParameters
-from tests.utils.rewrite import TestTrainingParameters
+from tests.utils.rewrite import MockTrainingParameters
def replace_fully_connected_with_conv(
@@ -45,7 +45,7 @@ def replace_fully_connected_with_conv(
def check_train(
tflite_model: Path,
tfrecord: Path,
- train_params: TrainingParameters = TestTrainingParameters(),
+ train_params: TrainingParameters = MockTrainingParameters(),
use_unmodified_model: bool = False,
quantized: bool = False,
) -> None:
@@ -115,7 +115,7 @@ def test_train_fp32(
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- train_params=TestTrainingParameters(
+ train_params=MockTrainingParameters(
batch_size=batch_size,
show_progress=show_progress,
augmentations=augmentation_preset,
@@ -163,7 +163,7 @@ def test_train_int8(
check_train(
tflite_model=test_tflite_model,
tfrecord=test_tfrecord,
- train_params=TestTrainingParameters(
+ train_params=MockTrainingParameters(
batch_size=batch_size,
show_progress=show_progress,
augmentations=augmentation_preset,
@@ -184,7 +184,7 @@ def test_train_invalid_schedule(
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- train_params=TestTrainingParameters(
+ train_params=MockTrainingParameters(
learning_rate_schedule="unknown_schedule",
),
)
@@ -199,7 +199,7 @@ def test_train_invalid_augmentation(
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- train_params=TestTrainingParameters(
+ train_params=MockTrainingParameters(
augmentations=(1.0, 2.0, 3.0),
),
)