aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2023-10-25 16:26:24 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-11-03 10:13:55 +0000
commit2dd8f2e64e4ecfbd49ef05d9f6d2644dd11a0462 (patch)
tree5b86753c82c193edd59053d33db59c1dcf0f366a
parent0655c1cc43108c741a77e775b2a4ef46984829ce (diff)
downloadmlia-2dd8f2e64e4ecfbd49ef05d9f6d2644dd11a0462.tar.gz
Fix PytestCollectionWarning in unit tests
Rename 'TestTrainingParameters' to 'MockTrainingParameters' to avoid a PytestCollectionWarning during test parsing Change-Id: I26b52d46aa71bcc6748e38e92331be21a667e8c9 Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
-rw-r--r--tests/conftest.py6
-rw-r--r--tests/test_nn_rewrite_core_rewrite.py4
-rw-r--r--tests/test_nn_rewrite_core_train.py12
-rw-r--r--tests/utils/rewrite.py2
4 files changed, 12 insertions, 12 deletions
diff --git a/tests/conftest.py b/tests/conftest.py
index bb2423f..d700206 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -18,7 +18,7 @@ from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import save_keras_model
from mlia.nn.tensorflow.utils import save_tflite_model
from mlia.target.ethos_u.config import EthosUConfiguration
-from tests.utils.rewrite import TestTrainingParameters
+from tests.utils.rewrite import MockTrainingParameters
@pytest.fixture(scope="session", name="test_resources_path")
@@ -210,10 +210,10 @@ def fixture_test_tfrecord_fp32(
@pytest.fixture(scope="session", autouse=True)
def set_training_steps() -> Generator[None, None, None]:
- """Speed up tests by using TestTrainingParameters."""
+ """Speed up tests by using MockTrainingParameters."""
with pytest.MonkeyPatch.context() as monkeypatch:
monkeypatch.setattr(
"mlia.nn.select._get_rewrite_train_params",
- MagicMock(return_value=TestTrainingParameters()),
+ MagicMock(return_value=MockTrainingParameters()),
)
yield
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index d4aac56..487784d 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -17,7 +17,7 @@ from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import RewriteRegistry
from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
from mlia.nn.tensorflow.config import TFLiteModel
-from tests.utils.rewrite import TestTrainingParameters
+from tests.utils.rewrite import MockTrainingParameters
def mock_rewrite_function(*_: Any) -> Any:
@@ -69,7 +69,7 @@ def test_rewriting_optimizer(
"fully_connected",
["sequential/flatten/Reshape", "StatefulPartitionedCall:0"],
test_tfrecord_fp32,
- train_params=TestTrainingParameters(),
+ train_params=MockTrainingParameters(),
)
test_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj)
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index ef52320..7fb6f85 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -19,7 +19,7 @@ from mlia.nn.rewrite.core.train import LearningRateSchedule
from mlia.nn.rewrite.core.train import mixup
from mlia.nn.rewrite.core.train import train
from mlia.nn.rewrite.core.train import TrainingParameters
-from tests.utils.rewrite import TestTrainingParameters
+from tests.utils.rewrite import MockTrainingParameters
def replace_fully_connected_with_conv(
@@ -45,7 +45,7 @@ def replace_fully_connected_with_conv(
def check_train(
tflite_model: Path,
tfrecord: Path,
- train_params: TrainingParameters = TestTrainingParameters(),
+ train_params: TrainingParameters = MockTrainingParameters(),
use_unmodified_model: bool = False,
quantized: bool = False,
) -> None:
@@ -115,7 +115,7 @@ def test_train_fp32(
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- train_params=TestTrainingParameters(
+ train_params=MockTrainingParameters(
batch_size=batch_size,
show_progress=show_progress,
augmentations=augmentation_preset,
@@ -163,7 +163,7 @@ def test_train_int8(
check_train(
tflite_model=test_tflite_model,
tfrecord=test_tfrecord,
- train_params=TestTrainingParameters(
+ train_params=MockTrainingParameters(
batch_size=batch_size,
show_progress=show_progress,
augmentations=augmentation_preset,
@@ -184,7 +184,7 @@ def test_train_invalid_schedule(
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- train_params=TestTrainingParameters(
+ train_params=MockTrainingParameters(
learning_rate_schedule="unknown_schedule",
),
)
@@ -199,7 +199,7 @@ def test_train_invalid_augmentation(
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- train_params=TestTrainingParameters(
+ train_params=MockTrainingParameters(
augmentations=(1.0, 2.0, 3.0),
),
)
diff --git a/tests/utils/rewrite.py b/tests/utils/rewrite.py
index 739bb11..29351bf 100644
--- a/tests/utils/rewrite.py
+++ b/tests/utils/rewrite.py
@@ -31,7 +31,7 @@ def models_are_equal(model1: ModelT, model2: ModelT) -> bool:
return True
-class TestTrainingParameters(
+class MockTrainingParameters(
TrainingParameters
): # pylint: disable=too-few-public-methods
"""