aboutsummaryrefslogtreecommitdiff
path: root/tests/conftest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/conftest.py')
-rw-r--r--tests/conftest.py41
1 files changed, 31 insertions, 10 deletions
diff --git a/tests/conftest.py b/tests/conftest.py
index 9dc1d16..1092979 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -11,7 +11,7 @@ import numpy as np
import pytest
import tensorflow as tf
-from mlia.backend.vela.compiler import optimize_model
+from mlia.backend.vela.compiler import compile_model
from mlia.core.context import ExecutionContext
from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFWriter
from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
@@ -51,7 +51,7 @@ def invalid_input_model_file(test_tflite_invalid_model: Path) -> Path:
@pytest.fixture(scope="session", name="empty_test_csv_file")
-def fixture_empty_test_csv_file( # pylint: disable=too-many-locals
+def fixture_empty_test_csv_file(
test_csv_path: Path,
) -> Path:
"""Return empty test csv file path."""
@@ -59,7 +59,7 @@ def fixture_empty_test_csv_file( # pylint: disable=too-many-locals
@pytest.fixture(scope="session", name="test_csv_file")
-def fixture_test_csv_file( # pylint: disable=too-many-locals
+def fixture_test_csv_file(
test_csv_path: Path,
) -> Path:
"""Return test csv file path."""
@@ -67,7 +67,7 @@ def fixture_test_csv_file( # pylint: disable=too-many-locals
@pytest.fixture(scope="session", name="test_csv_path")
-def fixture_test_csv_path( # pylint: disable=too-many-locals
+def fixture_test_csv_path(
tmp_path_factory: pytest.TempPathFactory,
) -> Generator[Path, None, None]:
"""Return test csv file path."""
@@ -76,6 +76,32 @@ def fixture_test_csv_path( # pylint: disable=too-many-locals
shutil.rmtree(tmp_path)
+@pytest.fixture(scope="session", name="test_vela_path")
+def fixture_test_vela_path(
+ tmp_path_factory: pytest.TempPathFactory,
+) -> Generator[Path, None, None]:
+ """Return test vela file path."""
+ tmp_path = tmp_path_factory.mktemp("vela_file")
+ yield tmp_path
+ shutil.rmtree(tmp_path)
+
+
+@pytest.fixture(scope="session", name="empty_vela_ini_file")
+def fixture_empty_vela_ini_file(
+ test_vela_path: Path,
+) -> Path:
+ """Return empty test vela file path."""
+ return test_vela_path / "empty_vela.ini"
+
+
+@pytest.fixture(scope="session", name="vela_ini_file")
+def fixture_vela_ini_file(
+ test_vela_path: Path,
+) -> Path:
+ """Return empty test vela file path."""
+ return test_vela_path / "vela.ini"
+
+
def get_test_keras_model() -> tf.keras.Model:
"""Return test Keras model."""
model = tf.keras.Sequential(
@@ -130,13 +156,8 @@ def fixture_test_models_path(
convert_to_tflite(keras_model, quantized=True, output_path=tflite_model_path)
# Vela-optimized TensorFlow Lite model (int8)
- tflite_vela_model = tmp_path / TEST_MODEL_TFLITE_VELA_FILE
target_config = EthosUConfiguration.load_profile("ethos-u55-256")
- optimize_model(
- tflite_model_path,
- target_config.compiler_options,
- tflite_vela_model,
- )
+ compile_model(tflite_model_path, target_config.compiler_options)
tf.saved_model.save(keras_model, str(tmp_path / TEST_MODEL_TF_SAVED_MODEL_FILE))