diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-02-15 14:50:58 +0000 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-03-14 15:45:40 +0000 |
commit | 0b552d2ae47da4fb9c16d2a59d6ebe12c8307771 (patch) | |
tree | 09b40b939acbe0bcf02dcc77a7ed7ce4aba94322 /tests/conftest.py | |
parent | 09b272be6e88d84a30cb89fb71f3fc3c64d20d2e (diff) | |
download | mlia-0b552d2ae47da4fb9c16d2a59d6ebe12c8307771.tar.gz |
feat: Enable rewrite parameterisation
Enables user to provide a toml or default profile to change training settings for rewrite optimization
Resolves: MLIA-1004
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: I3bf9f44b9a2062fb71ef36eb32c9a69edcc48061
Diffstat (limited to 'tests/conftest.py')
-rw-r--r-- | tests/conftest.py | 18 |
1 files changed, 12 insertions, 6 deletions
diff --git a/tests/conftest.py b/tests/conftest.py index 1092979..53bfb0c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ from typing import Callable from typing import Generator from unittest.mock import MagicMock +import _pytest import numpy as np import pytest import tensorflow as tf @@ -256,11 +257,16 @@ def fixture_test_tfrecord_fp32( @pytest.fixture(scope="session", autouse=True) -def set_training_steps() -> Generator[None, None, None]: +def set_training_steps( + request: _pytest.fixtures.SubRequest, +) -> Generator[None, None, None]: """Speed up tests by using MockTrainingParameters.""" - with pytest.MonkeyPatch.context() as monkeypatch: - monkeypatch.setattr( - "mlia.nn.select._get_rewrite_train_params", - MagicMock(return_value=MockTrainingParameters()), - ) + if "set_training_steps" == request.fixturename: yield + else: + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "mlia.nn.select._get_rewrite_params", + MagicMock(return_value=[MockTrainingParameters(), None, None]), + ) + yield |