diff options
Diffstat (limited to 'tests/test_nn_rewrite_core_train.py')
-rw-r--r-- | tests/test_nn_rewrite_core_train.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py index ef52320..7fb6f85 100644 --- a/tests/test_nn_rewrite_core_train.py +++ b/tests/test_nn_rewrite_core_train.py @@ -19,7 +19,7 @@ from mlia.nn.rewrite.core.train import LearningRateSchedule from mlia.nn.rewrite.core.train import mixup from mlia.nn.rewrite.core.train import train from mlia.nn.rewrite.core.train import TrainingParameters -from tests.utils.rewrite import TestTrainingParameters +from tests.utils.rewrite import MockTrainingParameters def replace_fully_connected_with_conv( @@ -45,7 +45,7 @@ def replace_fully_connected_with_conv( def check_train( tflite_model: Path, tfrecord: Path, - train_params: TrainingParameters = TestTrainingParameters(), + train_params: TrainingParameters = MockTrainingParameters(), use_unmodified_model: bool = False, quantized: bool = False, ) -> None: @@ -115,7 +115,7 @@ def test_train_fp32( check_train( tflite_model=test_tflite_model_fp32, tfrecord=test_tfrecord_fp32, - train_params=TestTrainingParameters( + train_params=MockTrainingParameters( batch_size=batch_size, show_progress=show_progress, augmentations=augmentation_preset, @@ -163,7 +163,7 @@ def test_train_int8( check_train( tflite_model=test_tflite_model, tfrecord=test_tfrecord, - train_params=TestTrainingParameters( + train_params=MockTrainingParameters( batch_size=batch_size, show_progress=show_progress, augmentations=augmentation_preset, @@ -184,7 +184,7 @@ def test_train_invalid_schedule( check_train( tflite_model=test_tflite_model_fp32, tfrecord=test_tfrecord_fp32, - train_params=TestTrainingParameters( + train_params=MockTrainingParameters( learning_rate_schedule="unknown_schedule", ), ) @@ -199,7 +199,7 @@ def test_train_invalid_augmentation( check_train( tflite_model=test_tflite_model_fp32, tfrecord=test_tfrecord_fp32, - train_params=TestTrainingParameters( + train_params=MockTrainingParameters( augmentations=(1.0, 2.0, 3.0), ), ) |