aboutsummaryrefslogtreecommitdiff
path: root/tests/conftest.py
diff options
context:
space:
mode:
authorNathan Bailey <nathan.bailey@arm.com>2024-02-15 14:50:58 +0000
committerNathan Bailey <nathan.bailey@arm.com>2024-03-14 15:45:40 +0000
commit0b552d2ae47da4fb9c16d2a59d6ebe12c8307771 (patch)
tree09b40b939acbe0bcf02dcc77a7ed7ce4aba94322 /tests/conftest.py
parent09b272be6e88d84a30cb89fb71f3fc3c64d20d2e (diff)
downloadmlia-0b552d2ae47da4fb9c16d2a59d6ebe12c8307771.tar.gz
feat: Enable rewrite parameterisation
Enables user to provide a toml or default profile to change training settings for rewrite optimization Resolves: MLIA-1004 Signed-off-by: Nathan Bailey <nathan.bailey@arm.com> Change-Id: I3bf9f44b9a2062fb71ef36eb32c9a69edcc48061
Diffstat (limited to 'tests/conftest.py')
-rw-r--r--tests/conftest.py18
1 files changed, 12 insertions, 6 deletions
diff --git a/tests/conftest.py b/tests/conftest.py
index 1092979..53bfb0c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,6 +7,7 @@ from typing import Callable
from typing import Generator
from unittest.mock import MagicMock
+import _pytest
import numpy as np
import pytest
import tensorflow as tf
@@ -256,11 +257,16 @@ def fixture_test_tfrecord_fp32(
@pytest.fixture(scope="session", autouse=True)
-def set_training_steps() -> Generator[None, None, None]:
+def set_training_steps(
+ request: _pytest.fixtures.SubRequest,
+) -> Generator[None, None, None]:
"""Speed up tests by using MockTrainingParameters."""
- with pytest.MonkeyPatch.context() as monkeypatch:
- monkeypatch.setattr(
- "mlia.nn.select._get_rewrite_train_params",
- MagicMock(return_value=MockTrainingParameters()),
- )
+ if "set_training_steps" == request.fixturename:
yield
+ else:
+ with pytest.MonkeyPatch.context() as monkeypatch:
+ monkeypatch.setattr(
+ "mlia.nn.select._get_rewrite_params",
+ MagicMock(return_value=[MockTrainingParameters(), None, None]),
+ )
+ yield