aboutsummaryrefslogtreecommitdiff
path: root/tests/conftest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/conftest.py')
-rw-r--r--tests/conftest.py42
1 files changed, 37 insertions, 5 deletions
diff --git a/tests/conftest.py b/tests/conftest.py
index 3d0b832..a64f320 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -126,9 +126,28 @@ def get_test_keras_model() -> keras.Model:
return model
+def get_test_keras_model_no_activation() -> keras.Model:
+ """Return test Keras model."""
+ model = keras.Sequential(
+ [
+ keras.Input(shape=(28, 28, 1), batch_size=1, name="input"),
+ keras.layers.Reshape((28, 28, 1)),
+ keras.layers.Conv2D(filters=12, kernel_size=(3, 3), name="conv1"),
+ keras.layers.Conv2D(filters=12, kernel_size=(3, 3), name="conv2"),
+ keras.layers.MaxPool2D(2, 2),
+ keras.layers.Flatten(),
+ keras.layers.Dense(10, name="output"),
+ ]
+ )
+
+ model.compile(optimizer="sgd", loss="mean_squared_error")
+ return model
+
+
TEST_MODEL_KERAS_FILE = "test_model.h5"
TEST_MODEL_TFLITE_FP32_FILE = "test_model_fp32.tflite"
TEST_MODEL_TFLITE_INT8_FILE = "test_model_int8.tflite"
+TEST_MODEL_TFLITE_NO_ACT_FILE = "test_model_no_act.tflite"
TEST_MODEL_TFLITE_VELA_FILE = "test_model_vela.tflite"
TEST_MODEL_TF_SAVED_MODEL_FILE = "tf_model_test_model"
TEST_MODEL_INVALID_FILE = "invalid.tflite"
@@ -153,6 +172,13 @@ def fixture_test_models_path(
keras_model, quantized=False, output_path=tmp_path / TEST_MODEL_TFLITE_FP32_FILE
)
+ # Un-quantized TensorFlow Lite model with ReLU activation (fp32)
+ convert_to_tflite(
+ get_test_keras_model_no_activation(),
+ quantized=False,
+ output_path=tmp_path / TEST_MODEL_TFLITE_NO_ACT_FILE,
+ )
+
# Quantized TensorFlow Lite model (int8)
tflite_model_path = tmp_path / TEST_MODEL_TFLITE_INT8_FILE
convert_to_tflite(keras_model, quantized=True, output_path=tflite_model_path)
@@ -195,6 +221,12 @@ def fixture_test_tflite_vela_model(test_models_path: Path) -> Path:
return test_models_path / TEST_MODEL_TFLITE_VELA_FILE
+@pytest.fixture(scope="session", name="test_tflite_no_act_model")
+def fixture_test_tflite_no_act_model(test_models_path: Path) -> Path:
+ """Return test TensorFlow Lite model with relu activation."""
+ return test_models_path / TEST_MODEL_TFLITE_NO_ACT_FILE
+
+
@pytest.fixture(scope="session", name="test_tf_model")
def fixture_test_tf_model(test_models_path: Path) -> Path:
"""Return test TensorFlow Lite model."""
@@ -257,17 +289,17 @@ def fixture_test_tfrecord_fp32(
yield from create_tfrecord(tmp_path_factory, random_data)
-@pytest.fixture(scope="session", autouse=True)
+@pytest.fixture(scope="function", autouse=True)
def set_training_steps(
request: _pytest.fixtures.SubRequest,
) -> Generator[None, None, None]:
"""Speed up tests by using MockTrainingParameters."""
- if "set_training_steps" == request.fixturename:
- yield
- else:
+ if "skip_set_training_steps" not in request.keywords:
with pytest.MonkeyPatch.context() as monkeypatch:
monkeypatch.setattr(
"mlia.nn.select._get_rewrite_params",
- MagicMock(return_value=[MockTrainingParameters(), None, None]),
+ MagicMock(return_value=MockTrainingParameters()),
)
yield
+ else:
+ yield