aboutsummaryrefslogtreecommitdiff
path: root/tests/conftest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/conftest.py')
-rw-r--r--tests/conftest.py18
1 files changed, 12 insertions, 6 deletions
diff --git a/tests/conftest.py b/tests/conftest.py
index 1092979..53bfb0c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,6 +7,7 @@ from typing import Callable
from typing import Generator
from unittest.mock import MagicMock
+import _pytest
import numpy as np
import pytest
import tensorflow as tf
@@ -256,11 +257,16 @@ def fixture_test_tfrecord_fp32(
@pytest.fixture(scope="session", autouse=True)
-def set_training_steps() -> Generator[None, None, None]:
+def set_training_steps(
+ request: _pytest.fixtures.SubRequest,
+) -> Generator[None, None, None]:
"""Speed up tests by using MockTrainingParameters."""
- with pytest.MonkeyPatch.context() as monkeypatch:
- monkeypatch.setattr(
- "mlia.nn.select._get_rewrite_train_params",
- MagicMock(return_value=MockTrainingParameters()),
- )
+ if "set_training_steps" == request.fixturename:
yield
+ else:
+ with pytest.MonkeyPatch.context() as monkeypatch:
+ monkeypatch.setattr(
+ "mlia.nn.select._get_rewrite_params",
+ MagicMock(return_value=[MockTrainingParameters(), None, None]),
+ )
+ yield