# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Pytest conf module."""
import shutil
from pathlib import Path
from typing import Callable
from typing import Generator
from unittest.mock import MagicMock

import _pytest
import numpy as np
import pytest
import tensorflow as tf
from keras.api._v2 import keras  # Temporary workaround for now: MLIA-1107

from mlia.backend.vela.compiler import compile_model
from mlia.core.context import ExecutionContext
from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFWriter
from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
from mlia.nn.tensorflow.utils import save_keras_model
from mlia.target.ethos_u.config import EthosUConfiguration
from tests.utils.rewrite import MockTrainingParameters


@pytest.fixture(scope="session", name="test_resources_path")
def fixture_test_resources_path() -> Path:
    """Return the 'test_resources' directory located next to this file."""
    return Path(__file__).parent / "test_resources"


@pytest.fixture(name="sample_context")
def fixture_sample_context(tmpdir: str) -> ExecutionContext:
    """Return an execution context whose output directory is 'tmpdir'."""
    return ExecutionContext(output_dir=tmpdir)


@pytest.fixture(scope="session")
def non_optimised_input_model_file(test_tflite_model: Path) -> Path:
    """Provide the path to a quantized test model file."""
    return test_tflite_model


@pytest.fixture(scope="session")
def optimised_input_model_file(test_tflite_vela_model: Path) -> Path:
    """Provide path to Vela-optimised test model file."""
    return test_tflite_vela_model


@pytest.fixture(scope="session")
def invalid_input_model_file(test_tflite_invalid_model: Path) -> Path:
    """Provide the path to an invalid test model file."""
    return test_tflite_invalid_model


@pytest.fixture(scope="session", name="empty_test_csv_file")
def fixture_empty_test_csv_file(
    test_csv_path: Path,
) -> Path:
    """Return the path of the empty test csv file.

    Only the path is built here; the file itself is not created.
    """
    return test_csv_path / "empty_test_csv_file.csv"


@pytest.fixture(scope="session", name="test_csv_file")
def fixture_test_csv_file(
    test_csv_path: Path,
) -> Path:
    """Return the path of the test csv file.

    Only the path is built here; the file itself is not created.
    """
    return test_csv_path / "test_csv_file.csv"


@pytest.fixture(scope="session", name="test_csv_path")
def fixture_test_csv_path(
    tmp_path_factory: pytest.TempPathFactory,
) -> Generator[Path, None, None]:
    """Provide a temporary directory for csv files, removed on teardown."""
    tmp_path = tmp_path_factory.mktemp("csv_files")
    yield tmp_path
    shutil.rmtree(tmp_path)


@pytest.fixture(scope="session", name="test_vela_path")
def fixture_test_vela_path(
    tmp_path_factory: pytest.TempPathFactory,
) -> Generator[Path, None, None]:
    """Provide a temporary directory for vela files, removed on teardown."""
    tmp_path = tmp_path_factory.mktemp("vela_file")
    yield tmp_path
    shutil.rmtree(tmp_path)


@pytest.fixture(scope="session", name="empty_vela_ini_file")
def fixture_empty_vela_ini_file(
    test_vela_path: Path,
) -> Path:
    """Return the path of the empty test vela ini file."""
    return test_vela_path / "empty_vela.ini"


@pytest.fixture(scope="session", name="vela_ini_file")
def fixture_vela_ini_file(
    test_vela_path: Path,
) -> Path:
    """Return the path of the test vela ini file."""
    return test_vela_path / "vela.ini"


def get_test_keras_model() -> keras.Model:
    """Return a small compiled Keras model used as the basis for test models.

    The model takes a (28, 28, 1) input with batch size 1, applies two 3x3
    convolutions, max pooling and flattening, and ends in a 10-unit dense
    layer. It is compiled with the SGD optimizer and mean squared error loss.
    """
    model = keras.Sequential(
        [
            keras.Input(shape=(28, 28, 1), batch_size=1, name="input"),
            keras.layers.Reshape((28, 28, 1)),
            keras.layers.Conv2D(
                filters=12, kernel_size=(3, 3), activation="relu", name="conv1"
            ),
            keras.layers.Conv2D(
                filters=12, kernel_size=(3, 3), activation="relu", name="conv2"
            ),
            keras.layers.MaxPool2D(2, 2),
            keras.layers.Flatten(),
            keras.layers.Dense(10, name="output"),
        ]
    )

    model.compile(optimizer="sgd", loss="mean_squared_error")
    return model


# File (or directory) names of the generated models inside 'test_models_path'.
TEST_MODEL_KERAS_FILE = "test_model.h5"
TEST_MODEL_TFLITE_FP32_FILE = "test_model_fp32.tflite"
TEST_MODEL_TFLITE_INT8_FILE = "test_model_int8.tflite"
TEST_MODEL_TFLITE_VELA_FILE = "test_model_vela.tflite"
TEST_MODEL_TF_SAVED_MODEL_FILE = "tf_model_test_model"
TEST_MODEL_INVALID_FILE = "invalid.tflite"


@pytest.fixture(scope="session", name="test_models_path")
def fixture_test_models_path(
    tmp_path_factory: pytest.TempPathFactory,
) -> Generator[Path, None, None]:
    """Provide path to the test models.

    All model variants are generated once per session into a temporary
    directory, which is removed on teardown.
    """
    tmp_path = tmp_path_factory.mktemp("models")

    # Need an output directory for verbose performance
    Path("output").mkdir(exist_ok=True)

    # Keras Model
    keras_model = get_test_keras_model()
    save_keras_model(keras_model, tmp_path / TEST_MODEL_KERAS_FILE)

    # Un-quantized TensorFlow Lite model (fp32)
    convert_to_tflite(
        keras_model, quantized=False, output_path=tmp_path / TEST_MODEL_TFLITE_FP32_FILE
    )

    # Quantized TensorFlow Lite model (int8)
    tflite_model_path = tmp_path / TEST_MODEL_TFLITE_INT8_FILE
    convert_to_tflite(keras_model, quantized=True, output_path=tflite_model_path)

    # Vela-optimized TensorFlow Lite model (int8)
    # NOTE(review): no explicit output path is passed here, so nothing in this
    # fixture visibly writes TEST_MODEL_TFLITE_VELA_FILE into tmp_path -
    # confirm where compile_model places its result.
    target_config = EthosUConfiguration.load_profile("ethos-u55-256")
    compile_model(tflite_model_path, target_config.compiler_options)

    # TensorFlow SavedModel (a directory, not a single file)
    tf.saved_model.save(keras_model, str(tmp_path / TEST_MODEL_TF_SAVED_MODEL_FILE))

    # Zero-byte file standing in for an invalid TensorFlow Lite model
    invalid_tflite_model = tmp_path / TEST_MODEL_INVALID_FILE
    invalid_tflite_model.touch()

    yield tmp_path

    shutil.rmtree(tmp_path)


@pytest.fixture(scope="session", name="test_keras_model")
def fixture_test_keras_model(test_models_path: Path) -> Path:
    """Return the path of the test Keras model."""
    return test_models_path / TEST_MODEL_KERAS_FILE


@pytest.fixture(scope="session", name="test_tflite_model")
def fixture_test_tflite_model(test_models_path: Path) -> Path:
    """Return the path of the quantized (int8) test TensorFlow Lite model."""
    return test_models_path / TEST_MODEL_TFLITE_INT8_FILE


@pytest.fixture(scope="session", name="test_tflite_model_fp32")
def fixture_test_tflite_model_fp32(test_models_path: Path) -> Path:
    """Return the path of the un-quantized (fp32) test TensorFlow Lite model."""
    return test_models_path / TEST_MODEL_TFLITE_FP32_FILE


@pytest.fixture(scope="session", name="test_tflite_vela_model")
def fixture_test_tflite_vela_model(test_models_path: Path) -> Path:
    """Return the path of the test Vela-optimized TensorFlow Lite model."""
    return test_models_path / TEST_MODEL_TFLITE_VELA_FILE


@pytest.fixture(scope="session", name="test_tf_model")
def fixture_test_tf_model(test_models_path: Path) -> Path:
    """Return the path of the test TensorFlow SavedModel."""
    return test_models_path / TEST_MODEL_TF_SAVED_MODEL_FILE


@pytest.fixture(scope="session", name="test_tflite_invalid_model")
def fixture_test_tflite_invalid_model(test_models_path: Path) -> Path:
    """Return the path of the invalid test TensorFlow Lite model."""
    return test_models_path / TEST_MODEL_INVALID_FILE


def _write_tfrecord(
    tfrecord_file: Path,
    data_generator: Callable[[], np.ndarray],
    input_name: str = "serving_default_input:0",
    num_records: int = 3,
) -> None:
    """Write 'num_records' records to a tfrecord file.

    Each record maps 'input_name' to a fresh value obtained by calling
    'data_generator' without arguments.
    """
    with NumpyTFWriter(str(tfrecord_file)) as writer:
        for _ in range(num_records):
            writer.write({input_name: data_generator()})


def create_tfrecord(
    tmp_path_factory: pytest.TempPathFactory,
    random_data: Callable[[], np.ndarray],
) -> Generator[Path, None, None]:
    """Create a tfrecord from 'random_data' in a temporary directory.

    Yields the path of the tfrecord file and removes the temporary directory
    once the generator is resumed (i.e. on fixture teardown).
    """
    tmp_path = tmp_path_factory.mktemp("tfrecords")
    tfrecord_file = tmp_path / "test.tfrecord"

    _write_tfrecord(tfrecord_file, random_data)

    yield tfrecord_file

    shutil.rmtree(tmp_path)


@pytest.fixture(scope="session", name="test_tfrecord")
def fixture_test_tfrecord(
    tmp_path_factory: pytest.TempPathFactory,
) -> Generator[Path, None, None]:
    """Create a tfrecord with random data matching fixture 'test_tflite_model'."""

    def random_data() -> np.ndarray:
        return np.random.randint(low=-127, high=128, size=(1, 28, 28, 1), dtype=np.int8)

    yield from create_tfrecord(tmp_path_factory, random_data)


@pytest.fixture(scope="session", name="test_tfrecord_fp32")
def fixture_test_tfrecord_fp32(
    tmp_path_factory: pytest.TempPathFactory,
) -> Generator[Path, None, None]:
    """Create tfrecord with random data matching fixture 'test_tflite_model_fp32'."""

    def random_data() -> np.ndarray:
        return np.random.rand(1, 28, 28, 1).astype(np.float32)

    yield from create_tfrecord(tmp_path_factory, random_data)


@pytest.fixture(scope="function", autouse=True)
def set_training_steps(
    request: _pytest.fixtures.SubRequest,
) -> Generator[None, None, None]:
    """Speed up tests by using MockTrainingParameters.

    'mlia.nn.select._get_rewrite_params' is replaced for the duration of each
    test by a mock returning MockTrainingParameters(). Tests that carry the
    'skip_set_training_steps' keyword run without the patch.
    """
    if "skip_set_training_steps" in request.keywords:
        # Opted out: the fixture must still yield exactly once.
        yield
        return

    with pytest.MonkeyPatch.context() as monkeypatch:
        monkeypatch.setattr(
            "mlia.nn.select._get_rewrite_params",
            MagicMock(return_value=MockTrainingParameters()),
        )
        # Yield inside the context so the patch is active while the test runs
        # and is undone on teardown.
        yield