diff options
author | Diego Russo <diego.russo@arm.com> | 2022-05-30 13:34:14 +0100 |
---|---|---|
committer | Diego Russo <diego.russo@arm.com> | 2022-05-30 13:34:14 +0100 |
commit | 0efca3cadbad5517a59884576ddb90cfe7ac30f8 (patch) | |
tree | abed6cb6fbf3c439fc8d947f505b6a53d5daeb1e /tests/conftest.py | |
parent | 0777092695c143c3a54680b5748287d40c914c35 (diff) | |
download | mlia-0efca3cadbad5517a59884576ddb90cfe7ac30f8.tar.gz |
Add MLIA codebase0.3.0-rc.1
Add MLIA codebase including sources and tests.
Change-Id: Id41707559bd721edd114793618d12ccd188d8dbd
Diffstat (limited to 'tests/conftest.py')
-rw-r--r-- | tests/conftest.py | 95 |
1 files changed, 95 insertions, 0 deletions
diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..5c6156c --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Pytest conf module.""" +import shutil +from pathlib import Path +from typing import Generator + +import pytest +import tensorflow as tf + +from mlia.devices.ethosu.config import EthosUConfiguration +from mlia.nn.tensorflow.utils import convert_to_tflite +from mlia.nn.tensorflow.utils import save_keras_model +from mlia.nn.tensorflow.utils import save_tflite_model +from mlia.tools.vela_wrapper import optimize_model + + +def get_test_keras_model() -> tf.keras.Model: + """Return test Keras model.""" + model = tf.keras.Sequential( + [ + tf.keras.Input(shape=(28, 28, 1), batch_size=1, name="input"), + tf.keras.layers.Reshape((28, 28, 1)), + tf.keras.layers.Conv2D( + filters=12, kernel_size=(3, 3), activation="relu", name="conv1" + ), + tf.keras.layers.Conv2D( + filters=12, kernel_size=(3, 3), activation="relu", name="conv2" + ), + tf.keras.layers.MaxPool2D(2, 2), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(10, name="output"), + ] + ) + + model.compile(optimizer="sgd", loss="mean_squared_error") + return model + + +@pytest.fixture(scope="session", name="test_models_path") +def fixture_test_models_path( + tmp_path_factory: pytest.TempPathFactory, +) -> Generator[Path, None, None]: + """Provide path to the test models.""" + tmp_path = tmp_path_factory.mktemp("models") + + keras_model = get_test_keras_model() + save_keras_model(keras_model, tmp_path / "test_model.h5") + + tflite_model = convert_to_tflite(keras_model, quantized=True) + tflite_model_path = tmp_path / "test_model.tflite" + save_tflite_model(tflite_model, tflite_model_path) + + tflite_vela_model = tmp_path / "test_model_vela.tflite" + device = EthosUConfiguration("ethos-u55-256") + optimize_model(tflite_model_path, device.compiler_options, tflite_vela_model) + + tf.saved_model.save(keras_model, str(tmp_path / "tf_model_test_model")) + + invalid_tflite_model = tmp_path / "invalid.tflite" + invalid_tflite_model.touch() + + yield tmp_path + + shutil.rmtree(tmp_path) + + +@pytest.fixture(scope="session", name="test_keras_model") +def fixture_test_keras_model(test_models_path: Path) -> Path: + """Return test Keras model.""" + return test_models_path / "test_model.h5" + + +@pytest.fixture(scope="session", name="test_tflite_model") +def fixture_test_tflite_model(test_models_path: Path) -> Path: + """Return test TFLite model.""" + return test_models_path / "test_model.tflite" + + +@pytest.fixture(scope="session", name="test_tflite_vela_model") +def fixture_test_tflite_vela_model(test_models_path: Path) -> Path: + """Return test Vela-optimized TFLite model.""" + return test_models_path / "test_model_vela.tflite" + + +@pytest.fixture(scope="session", name="test_tf_model") +def fixture_test_tf_model(test_models_path: Path) -> Path: + """Return test TFLite model.""" + return test_models_path / "tf_model_test_model" + + +@pytest.fixture(scope="session", name="test_tflite_invalid_model") +def fixture_test_tflite_invalid_model(test_models_path: Path) -> Path: + """Return test invalid TFLite model.""" + return test_models_path / "invalid.tflite" |