diff options
Diffstat (limited to 'tests/conftest.py')
-rw-r--r-- | tests/conftest.py | 95 |
1 files changed, 84 insertions, 11 deletions
diff --git a/tests/conftest.py b/tests/conftest.py index 30889ca..c42b8cb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,13 +3,16 @@ """Pytest conf module.""" import shutil from pathlib import Path +from typing import Callable from typing import Generator +import numpy as np import pytest import tensorflow as tf from mlia.backend.vela.compiler import optimize_model from mlia.core.context import ExecutionContext +from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFWriter from mlia.nn.tensorflow.utils import convert_to_tflite from mlia.nn.tensorflow.utils import save_keras_model from mlia.nn.tensorflow.utils import save_tflite_model @@ -68,6 +71,14 @@ def get_test_keras_model() -> tf.keras.Model: return model +TEST_MODEL_KERAS_FILE = "test_model.h5" +TEST_MODEL_TFLITE_FP32_FILE = "test_model_fp32.tflite" +TEST_MODEL_TFLITE_INT8_FILE = "test_model_int8.tflite" +TEST_MODEL_TFLITE_VELA_FILE = "test_model_vela.tflite" +TEST_MODEL_TF_SAVED_MODEL_FILE = "tf_model_test_model" +TEST_MODEL_INVALID_FILE = "invalid.tflite" + + @pytest.fixture(scope="session", name="test_models_path") def fixture_test_models_path( tmp_path_factory: pytest.TempPathFactory, @@ -75,15 +86,23 @@ def fixture_test_models_path( """Provide path to the test models.""" tmp_path = tmp_path_factory.mktemp("models") + # Keras Model keras_model = get_test_keras_model() - save_keras_model(keras_model, tmp_path / "test_model.h5") + save_keras_model(keras_model, tmp_path / TEST_MODEL_KERAS_FILE) + + # Un-quantized TensorFlow Lite model (fp32) + save_tflite_model( + convert_to_tflite(keras_model, quantized=False), + tmp_path / TEST_MODEL_TFLITE_FP32_FILE, + ) + # Quantized TensorFlow Lite model (int8) tflite_model = convert_to_tflite(keras_model, quantized=True) - tflite_model_path = tmp_path / "test_model.tflite" + tflite_model_path = tmp_path / TEST_MODEL_TFLITE_INT8_FILE save_tflite_model(tflite_model, tflite_model_path) - tflite_vela_model = tmp_path / "test_model_vela.tflite" - + # Vela-optimized TensorFlow Lite model (int8) + tflite_vela_model = tmp_path / TEST_MODEL_TFLITE_VELA_FILE target_config = EthosUConfiguration.load_profile("ethos-u55-256") optimize_model( tflite_model_path, @@ -91,9 +110,9 @@ def fixture_test_models_path( tflite_vela_model, ) - tf.saved_model.save(keras_model, str(tmp_path / "tf_model_test_model")) + tf.saved_model.save(keras_model, str(tmp_path / TEST_MODEL_TF_SAVED_MODEL_FILE)) - invalid_tflite_model = tmp_path / "invalid.tflite" + invalid_tflite_model = tmp_path / TEST_MODEL_INVALID_FILE invalid_tflite_model.touch() yield tmp_path @@ -104,28 +123,82 @@ def fixture_test_models_path( @pytest.fixture(scope="session", name="test_keras_model") def fixture_test_keras_model(test_models_path: Path) -> Path: """Return test Keras model.""" - return test_models_path / "test_model.h5" + return test_models_path / TEST_MODEL_KERAS_FILE @pytest.fixture(scope="session", name="test_tflite_model") def fixture_test_tflite_model(test_models_path: Path) -> Path: """Return test TensorFlow Lite model.""" - return test_models_path / "test_model.tflite" + return test_models_path / TEST_MODEL_TFLITE_INT8_FILE + + +@pytest.fixture(scope="session", name="test_tflite_model_fp32") +def fixture_test_tflite_model_fp32(test_models_path: Path) -> Path: + """Return test TensorFlow Lite model.""" + return test_models_path / TEST_MODEL_TFLITE_FP32_FILE @pytest.fixture(scope="session", name="test_tflite_vela_model") def fixture_test_tflite_vela_model(test_models_path: Path) -> Path: """Return test Vela-optimized TensorFlow Lite model.""" - return test_models_path / "test_model_vela.tflite" + return test_models_path / TEST_MODEL_TFLITE_VELA_FILE @pytest.fixture(scope="session", name="test_tf_model") def fixture_test_tf_model(test_models_path: Path) -> Path: """Return test TensorFlow Lite model.""" - return test_models_path / "tf_model_test_model" + return test_models_path / TEST_MODEL_TF_SAVED_MODEL_FILE @pytest.fixture(scope="session", name="test_tflite_invalid_model") def fixture_test_tflite_invalid_model(test_models_path: Path) -> Path: """Return test invalid TensorFlow Lite model.""" - return test_models_path / "invalid.tflite" + return test_models_path / TEST_MODEL_INVALID_FILE + + +def _write_tfrecord( + tfrecord_file: Path, + data_generator: Callable, + input_name: str = "serving_default_input:0", + num_records: int = 3, +) -> None: + """Write data to a tfrecord.""" + with NumpyTFWriter(str(tfrecord_file)) as writer: + for _ in range(num_records): + writer.write({input_name: data_generator()}) + + +@pytest.fixture(scope="session", name="test_tfrecord") +def fixture_test_tfrecord( + tmp_path_factory: pytest.TempPathFactory, +) -> Generator[Path, None, None]: + """Create a tfrecord with random data matching fixture 'test_tflite_model'.""" + tmp_path = tmp_path_factory.mktemp("tfrecords") + tfrecord_file = tmp_path / "test_int8.tfrecord" + + def random_data() -> np.ndarray: + return np.random.randint(low=-127, high=128, size=(1, 28, 28, 1), dtype=np.int8) + + _write_tfrecord(tfrecord_file, random_data) + + yield tfrecord_file + + shutil.rmtree(tmp_path) + + +@pytest.fixture(scope="session", name="test_tfrecord_fp32") +def fixture_test_tfrecord_fp32( + tmp_path_factory: pytest.TempPathFactory, +) -> Generator[Path, None, None]: + """Create tfrecord with random data matching fixture 'test_tflite_model_fp32'.""" + tmp_path = tmp_path_factory.mktemp("tfrecords") + tfrecord_file = tmp_path / "test_fp32.tfrecord" + + def random_data() -> np.ndarray: + return np.random.rand(1, 28, 28, 1).astype(np.float32) + + _write_tfrecord(tfrecord_file, random_data) + + yield tfrecord_file + + shutil.rmtree(tmp_path) |