diff options
Diffstat (limited to 'tests/conftest.py')
-rw-r--r-- | tests/conftest.py | 39 |
1 files changed, 27 insertions, 12 deletions
diff --git a/tests/conftest.py b/tests/conftest.py index c42b8cb..bb2423f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ import shutil from pathlib import Path from typing import Callable from typing import Generator +from unittest.mock import MagicMock import numpy as np import pytest @@ -17,6 +18,7 @@ from mlia.nn.tensorflow.utils import convert_to_tflite from mlia.nn.tensorflow.utils import save_keras_model from mlia.nn.tensorflow.utils import save_tflite_model from mlia.target.ethos_u.config import EthosUConfiguration +from tests.utils.rewrite import TestTrainingParameters @pytest.fixture(scope="session", name="test_resources_path") @@ -168,16 +170,12 @@ def _write_tfrecord( writer.write({input_name: data_generator()}) -@pytest.fixture(scope="session", name="test_tfrecord") -def fixture_test_tfrecord( - tmp_path_factory: pytest.TempPathFactory, +def create_tfrecord( + tmp_path_factory: pytest.TempPathFactory, random_data: Callable ) -> Generator[Path, None, None]: """Create a tfrecord with random data matching fixture 'test_tflite_model'.""" tmp_path = tmp_path_factory.mktemp("tfrecords") - tfrecord_file = tmp_path / "test_int8.tfrecord" - - def random_data() -> np.ndarray: - return np.random.randint(low=-127, high=128, size=(1, 28, 28, 1), dtype=np.int8) + tfrecord_file = tmp_path / "test.tfrecord" _write_tfrecord(tfrecord_file, random_data) @@ -186,19 +184,36 @@ def fixture_test_tfrecord( shutil.rmtree(tmp_path) +@pytest.fixture(scope="session", name="test_tfrecord") +def fixture_test_tfrecord( + tmp_path_factory: pytest.TempPathFactory, +) -> Generator[Path, None, None]: + """Create a tfrecord with random data matching fixture 'test_tflite_model'.""" + + def random_data() -> np.ndarray: + return np.random.randint(low=-127, high=128, size=(1, 28, 28, 1), dtype=np.int8) + + yield from create_tfrecord(tmp_path_factory, random_data) + + @pytest.fixture(scope="session", name="test_tfrecord_fp32") def fixture_test_tfrecord_fp32( tmp_path_factory: pytest.TempPathFactory, ) -> Generator[Path, None, None]: """Create tfrecord with random data matching fixture 'test_tflite_model_fp32'.""" - tmp_path = tmp_path_factory.mktemp("tfrecords") - tfrecord_file = tmp_path / "test_fp32.tfrecord" def random_data() -> np.ndarray: return np.random.rand(1, 28, 28, 1).astype(np.float32) - _write_tfrecord(tfrecord_file, random_data) + yield from create_tfrecord(tmp_path_factory, random_data) - yield tfrecord_file - shutil.rmtree(tmp_path) +@pytest.fixture(scope="session", autouse=True) +def set_training_steps() -> Generator[None, None, None]: + """Speed up tests by using TestTrainingParameters.""" + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "mlia.nn.select._get_rewrite_train_params", + MagicMock(return_value=TestTrainingParameters()), + ) + yield |