diff options
Diffstat (limited to 'tests/conftest.py')
-rw-r--r-- | tests/conftest.py | 41 |
1 files changed, 31 insertions, 10 deletions
diff --git a/tests/conftest.py b/tests/conftest.py index 9dc1d16..1092979 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,7 +11,7 @@ import numpy as np import pytest import tensorflow as tf -from mlia.backend.vela.compiler import optimize_model +from mlia.backend.vela.compiler import compile_model from mlia.core.context import ExecutionContext from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFWriter from mlia.nn.tensorflow.tflite_convert import convert_to_tflite @@ -51,7 +51,7 @@ def invalid_input_model_file(test_tflite_invalid_model: Path) -> Path: @pytest.fixture(scope="session", name="empty_test_csv_file") -def fixture_empty_test_csv_file( # pylint: disable=too-many-locals +def fixture_empty_test_csv_file( test_csv_path: Path, ) -> Path: """Return empty test csv file path.""" @@ -59,7 +59,7 @@ def fixture_empty_test_csv_file( # pylint: disable=too-many-locals @pytest.fixture(scope="session", name="test_csv_file") -def fixture_test_csv_file( # pylint: disable=too-many-locals +def fixture_test_csv_file( test_csv_path: Path, ) -> Path: """Return test csv file path.""" @@ -67,7 +67,7 @@ def fixture_test_csv_file( # pylint: disable=too-many-locals @pytest.fixture(scope="session", name="test_csv_path") -def fixture_test_csv_path( # pylint: disable=too-many-locals +def fixture_test_csv_path( tmp_path_factory: pytest.TempPathFactory, ) -> Generator[Path, None, None]: """Return test csv file path.""" @@ -76,6 +76,32 @@ def fixture_test_csv_path( # pylint: disable=too-many-locals shutil.rmtree(tmp_path) +@pytest.fixture(scope="session", name="test_vela_path") +def fixture_test_vela_path( + tmp_path_factory: pytest.TempPathFactory, +) -> Generator[Path, None, None]: + """Return test vela file path.""" + tmp_path = tmp_path_factory.mktemp("vela_file") + yield tmp_path + shutil.rmtree(tmp_path) + + +@pytest.fixture(scope="session", name="empty_vela_ini_file") +def fixture_empty_vela_ini_file( + test_vela_path: Path, +) -> Path: + """Return empty test vela file path.""" + return test_vela_path / "empty_vela.ini" + + +@pytest.fixture(scope="session", name="vela_ini_file") +def fixture_vela_ini_file( + test_vela_path: Path, +) -> Path: + """Return empty test vela file path.""" + return test_vela_path / "vela.ini" + + def get_test_keras_model() -> tf.keras.Model: """Return test Keras model.""" model = tf.keras.Sequential( @@ -130,13 +156,8 @@ def fixture_test_models_path( convert_to_tflite(keras_model, quantized=True, output_path=tflite_model_path) # Vela-optimized TensorFlow Lite model (int8) - tflite_vela_model = tmp_path / TEST_MODEL_TFLITE_VELA_FILE target_config = EthosUConfiguration.load_profile("ethos-u55-256") - optimize_model( - tflite_model_path, - target_config.compiler_options, - tflite_vela_model, - ) + compile_model(tflite_model_path, target_config.compiler_options) tf.saved_model.save(keras_model, str(tmp_path / TEST_MODEL_TF_SAVED_MODEL_FILE)) |