diff options
Diffstat (limited to 'tests/conftest.py')
-rw-r--r-- | tests/conftest.py | 42 |
1 files changed, 37 insertions, 5 deletions
diff --git a/tests/conftest.py b/tests/conftest.py index 3d0b832..a64f320 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -126,9 +126,28 @@ def get_test_keras_model() -> keras.Model: return model +def get_test_keras_model_no_activation() -> keras.Model: + """Return test Keras model.""" + model = keras.Sequential( + [ + keras.Input(shape=(28, 28, 1), batch_size=1, name="input"), + keras.layers.Reshape((28, 28, 1)), + keras.layers.Conv2D(filters=12, kernel_size=(3, 3), name="conv1"), + keras.layers.Conv2D(filters=12, kernel_size=(3, 3), name="conv2"), + keras.layers.MaxPool2D(2, 2), + keras.layers.Flatten(), + keras.layers.Dense(10, name="output"), + ] + ) + + model.compile(optimizer="sgd", loss="mean_squared_error") + return model + + TEST_MODEL_KERAS_FILE = "test_model.h5" TEST_MODEL_TFLITE_FP32_FILE = "test_model_fp32.tflite" TEST_MODEL_TFLITE_INT8_FILE = "test_model_int8.tflite" +TEST_MODEL_TFLITE_NO_ACT_FILE = "test_model_no_act.tflite" TEST_MODEL_TFLITE_VELA_FILE = "test_model_vela.tflite" TEST_MODEL_TF_SAVED_MODEL_FILE = "tf_model_test_model" TEST_MODEL_INVALID_FILE = "invalid.tflite" @@ -153,6 +172,13 @@ def fixture_test_models_path( keras_model, quantized=False, output_path=tmp_path / TEST_MODEL_TFLITE_FP32_FILE ) + # Un-quantized TensorFlow Lite model with ReLU activation (fp32) + convert_to_tflite( + get_test_keras_model_no_activation(), + quantized=False, + output_path=tmp_path / TEST_MODEL_TFLITE_NO_ACT_FILE, + ) + # Quantized TensorFlow Lite model (int8) tflite_model_path = tmp_path / TEST_MODEL_TFLITE_INT8_FILE convert_to_tflite(keras_model, quantized=True, output_path=tflite_model_path) @@ -195,6 +221,12 @@ def fixture_test_tflite_vela_model(test_models_path: Path) -> Path: return test_models_path / TEST_MODEL_TFLITE_VELA_FILE +@pytest.fixture(scope="session", name="test_tflite_no_act_model") +def fixture_test_tflite_no_act_model(test_models_path: Path) -> Path: + """Return test TensorFlow Lite model with relu activation.""" + return test_models_path / TEST_MODEL_TFLITE_NO_ACT_FILE + + @pytest.fixture(scope="session", name="test_tf_model") def fixture_test_tf_model(test_models_path: Path) -> Path: """Return test TensorFlow Lite model.""" @@ -257,17 +289,17 @@ def fixture_test_tfrecord_fp32( yield from create_tfrecord(tmp_path_factory, random_data) -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="function", autouse=True) def set_training_steps( request: _pytest.fixtures.SubRequest, ) -> Generator[None, None, None]: """Speed up tests by using MockTrainingParameters.""" - if "set_training_steps" == request.fixturename: - yield - else: + if "skip_set_training_steps" not in request.keywords: with pytest.MonkeyPatch.context() as monkeypatch: monkeypatch.setattr( "mlia.nn.select._get_rewrite_params", - MagicMock(return_value=[MockTrainingParameters(), None, None]), + MagicMock(return_value=MockTrainingParameters()), ) yield + else: + yield |