From 2dd8f2e64e4ecfbd49ef05d9f6d2644dd11a0462 Mon Sep 17 00:00:00 2001 From: Benjamin Klimczak Date: Wed, 25 Oct 2023 16:26:24 +0100 Subject: Fix PytestCollectionWarning in unit tests Rename 'TestTrainingParameters' to 'MockTrainingParameters' to avoid a PytestCollectionWarning during test parsing Change-Id: I26b52d46aa71bcc6748e38e92331be21a667e8c9 Signed-off-by: Benjamin Klimczak --- tests/conftest.py | 6 +++--- tests/test_nn_rewrite_core_rewrite.py | 4 ++-- tests/test_nn_rewrite_core_train.py | 12 ++++++------ tests/utils/rewrite.py | 2 +- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index bb2423f..d700206 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,7 +18,7 @@ from mlia.nn.tensorflow.utils import convert_to_tflite from mlia.nn.tensorflow.utils import save_keras_model from mlia.nn.tensorflow.utils import save_tflite_model from mlia.target.ethos_u.config import EthosUConfiguration -from tests.utils.rewrite import TestTrainingParameters +from tests.utils.rewrite import MockTrainingParameters @pytest.fixture(scope="session", name="test_resources_path") @@ -210,10 +210,10 @@ def fixture_test_tfrecord_fp32( @pytest.fixture(scope="session", autouse=True) def set_training_steps() -> Generator[None, None, None]: - """Speed up tests by using TestTrainingParameters.""" + """Speed up tests by using MockTrainingParameters.""" with pytest.MonkeyPatch.context() as monkeypatch: monkeypatch.setattr( "mlia.nn.select._get_rewrite_train_params", - MagicMock(return_value=TestTrainingParameters()), + MagicMock(return_value=MockTrainingParameters()), ) yield diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py index d4aac56..487784d 100644 --- a/tests/test_nn_rewrite_core_rewrite.py +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -17,7 +17,7 @@ from mlia.nn.rewrite.core.rewrite import RewriteConfiguration from mlia.nn.rewrite.core.rewrite import RewriteRegistry from mlia.nn.rewrite.core.rewrite import RewritingOptimizer from mlia.nn.tensorflow.config import TFLiteModel -from tests.utils.rewrite import TestTrainingParameters +from tests.utils.rewrite import MockTrainingParameters def mock_rewrite_function(*_: Any) -> Any: @@ -69,7 +69,7 @@ def test_rewriting_optimizer( "fully_connected", ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"], test_tfrecord_fp32, - train_params=TestTrainingParameters(), + train_params=MockTrainingParameters(), ) test_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj) diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py index ef52320..7fb6f85 100644 --- a/tests/test_nn_rewrite_core_train.py +++ b/tests/test_nn_rewrite_core_train.py @@ -19,7 +19,7 @@ from mlia.nn.rewrite.core.train import LearningRateSchedule from mlia.nn.rewrite.core.train import mixup from mlia.nn.rewrite.core.train import train from mlia.nn.rewrite.core.train import TrainingParameters -from tests.utils.rewrite import TestTrainingParameters +from tests.utils.rewrite import MockTrainingParameters def replace_fully_connected_with_conv( @@ -45,7 +45,7 @@ def replace_fully_connected_with_conv( def check_train( tflite_model: Path, tfrecord: Path, - train_params: TrainingParameters = TestTrainingParameters(), + train_params: TrainingParameters = MockTrainingParameters(), use_unmodified_model: bool = False, quantized: bool = False, ) -> None: @@ -115,7 +115,7 @@ def test_train_fp32( check_train( tflite_model=test_tflite_model_fp32, tfrecord=test_tfrecord_fp32, - train_params=TestTrainingParameters( + train_params=MockTrainingParameters( batch_size=batch_size, show_progress=show_progress, augmentations=augmentation_preset, @@ -163,7 +163,7 @@ def test_train_int8( check_train( tflite_model=test_tflite_model, tfrecord=test_tfrecord, - train_params=TestTrainingParameters( + train_params=MockTrainingParameters( batch_size=batch_size, show_progress=show_progress, augmentations=augmentation_preset, @@ -184,7 +184,7 @@ def test_train_invalid_schedule( check_train( tflite_model=test_tflite_model_fp32, tfrecord=test_tfrecord_fp32, - train_params=TestTrainingParameters( + train_params=MockTrainingParameters( learning_rate_schedule="unknown_schedule", ), ) @@ -199,7 +199,7 @@ def test_train_invalid_augmentation( check_train( tflite_model=test_tflite_model_fp32, tfrecord=test_tfrecord_fp32, - train_params=TestTrainingParameters( + train_params=MockTrainingParameters( augmentations=(1.0, 2.0, 3.0), ), ) diff --git a/tests/utils/rewrite.py b/tests/utils/rewrite.py index 739bb11..29351bf 100644 --- a/tests/utils/rewrite.py +++ b/tests/utils/rewrite.py @@ -31,7 +31,7 @@ def models_are_equal(model1: ModelT, model2: ModelT) -> bool: return True -class TestTrainingParameters( +class MockTrainingParameters( TrainingParameters ): # pylint: disable=too-few-public-methods """ -- cgit v1.2.1