aboutsummaryrefslogtreecommitdiff
path: root/tests/conftest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/conftest.py')
-rw-r--r--tests/conftest.py39
1 files changed, 27 insertions, 12 deletions
diff --git a/tests/conftest.py b/tests/conftest.py
index c42b8cb..bb2423f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,6 +5,7 @@ import shutil
from pathlib import Path
from typing import Callable
from typing import Generator
+from unittest.mock import MagicMock
import numpy as np
import pytest
@@ -17,6 +18,7 @@ from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import save_keras_model
from mlia.nn.tensorflow.utils import save_tflite_model
from mlia.target.ethos_u.config import EthosUConfiguration
+from tests.utils.rewrite import TestTrainingParameters
@pytest.fixture(scope="session", name="test_resources_path")
@@ -168,16 +170,12 @@ def _write_tfrecord(
writer.write({input_name: data_generator()})
-@pytest.fixture(scope="session", name="test_tfrecord")
-def fixture_test_tfrecord(
- tmp_path_factory: pytest.TempPathFactory,
+def create_tfrecord(
+ tmp_path_factory: pytest.TempPathFactory, random_data: Callable
) -> Generator[Path, None, None]:
"""Create a tfrecord with random data matching fixture 'test_tflite_model'."""
tmp_path = tmp_path_factory.mktemp("tfrecords")
- tfrecord_file = tmp_path / "test_int8.tfrecord"
-
- def random_data() -> np.ndarray:
- return np.random.randint(low=-127, high=128, size=(1, 28, 28, 1), dtype=np.int8)
+ tfrecord_file = tmp_path / "test.tfrecord"
_write_tfrecord(tfrecord_file, random_data)
@@ -186,19 +184,36 @@ def fixture_test_tfrecord(
shutil.rmtree(tmp_path)
+@pytest.fixture(scope="session", name="test_tfrecord")
+def fixture_test_tfrecord(
+ tmp_path_factory: pytest.TempPathFactory,
+) -> Generator[Path, None, None]:
+ """Create a tfrecord with random data matching fixture 'test_tflite_model'."""
+
+ def random_data() -> np.ndarray:
+ return np.random.randint(low=-127, high=128, size=(1, 28, 28, 1), dtype=np.int8)
+
+ yield from create_tfrecord(tmp_path_factory, random_data)
+
+
@pytest.fixture(scope="session", name="test_tfrecord_fp32")
def fixture_test_tfrecord_fp32(
tmp_path_factory: pytest.TempPathFactory,
) -> Generator[Path, None, None]:
"""Create tfrecord with random data matching fixture 'test_tflite_model_fp32'."""
- tmp_path = tmp_path_factory.mktemp("tfrecords")
- tfrecord_file = tmp_path / "test_fp32.tfrecord"
def random_data() -> np.ndarray:
return np.random.rand(1, 28, 28, 1).astype(np.float32)
- _write_tfrecord(tfrecord_file, random_data)
+ yield from create_tfrecord(tmp_path_factory, random_data)
- yield tfrecord_file
- shutil.rmtree(tmp_path)
+@pytest.fixture(scope="session", autouse=True)
+def set_training_steps() -> Generator[None, None, None]:
+ """Speed up tests by using TestTrainingParameters."""
+ with pytest.MonkeyPatch.context() as monkeypatch:
+ monkeypatch.setattr(
+ "mlia.nn.select._get_rewrite_train_params",
+ MagicMock(return_value=TestTrainingParameters()),
+ )
+ yield