aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_train.py
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2023-10-25 16:26:24 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-11-03 10:13:55 +0000
commit2dd8f2e64e4ecfbd49ef05d9f6d2644dd11a0462 (patch)
tree5b86753c82c193edd59053d33db59c1dcf0f366a /tests/test_nn_rewrite_core_train.py
parent0655c1cc43108c741a77e775b2a4ef46984829ce (diff)
downloadmlia-2dd8f2e64e4ecfbd49ef05d9f6d2644dd11a0462.tar.gz
Fix PytestCollectionWarning in unit tests
Rename 'TestTrainingParameters' to 'MockTrainingParameters' to avoid a PytestCollectionWarning during test parsing Change-Id: I26b52d46aa71bcc6748e38e92331be21a667e8c9 Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Diffstat (limited to 'tests/test_nn_rewrite_core_train.py')
-rw-r--r--tests/test_nn_rewrite_core_train.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index ef52320..7fb6f85 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -19,7 +19,7 @@ from mlia.nn.rewrite.core.train import LearningRateSchedule
from mlia.nn.rewrite.core.train import mixup
from mlia.nn.rewrite.core.train import train
from mlia.nn.rewrite.core.train import TrainingParameters
-from tests.utils.rewrite import TestTrainingParameters
+from tests.utils.rewrite import MockTrainingParameters
def replace_fully_connected_with_conv(
@@ -45,7 +45,7 @@ def replace_fully_connected_with_conv(
def check_train(
tflite_model: Path,
tfrecord: Path,
- train_params: TrainingParameters = TestTrainingParameters(),
+ train_params: TrainingParameters = MockTrainingParameters(),
use_unmodified_model: bool = False,
quantized: bool = False,
) -> None:
@@ -115,7 +115,7 @@ def test_train_fp32(
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- train_params=TestTrainingParameters(
+ train_params=MockTrainingParameters(
batch_size=batch_size,
show_progress=show_progress,
augmentations=augmentation_preset,
@@ -163,7 +163,7 @@ def test_train_int8(
check_train(
tflite_model=test_tflite_model,
tfrecord=test_tfrecord,
- train_params=TestTrainingParameters(
+ train_params=MockTrainingParameters(
batch_size=batch_size,
show_progress=show_progress,
augmentations=augmentation_preset,
@@ -184,7 +184,7 @@ def test_train_invalid_schedule(
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- train_params=TestTrainingParameters(
+ train_params=MockTrainingParameters(
learning_rate_schedule="unknown_schedule",
),
)
@@ -199,7 +199,7 @@ def test_train_invalid_augmentation(
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- train_params=TestTrainingParameters(
+ train_params=MockTrainingParameters(
augmentations=(1.0, 2.0, 3.0),
),
)