diff options
Diffstat (limited to 'tests/conftest.py')
-rw-r--r-- | tests/conftest.py | 19 |
1 files changed, 10 insertions, 9 deletions
diff --git a/tests/conftest.py b/tests/conftest.py index 53bfb0c..3d0b832 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ import _pytest import numpy as np import pytest import tensorflow as tf +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from mlia.backend.vela.compiler import compile_model from mlia.core.context import ExecutionContext @@ -103,21 +104,21 @@ def fixture_vela_ini_file( return test_vela_path / "vela.ini" -def get_test_keras_model() -> tf.keras.Model: +def get_test_keras_model() -> keras.Model: """Return test Keras model.""" - model = tf.keras.Sequential( + model = keras.Sequential( [ - tf.keras.Input(shape=(28, 28, 1), batch_size=1, name="input"), - tf.keras.layers.Reshape((28, 28, 1)), - tf.keras.layers.Conv2D( + keras.Input(shape=(28, 28, 1), batch_size=1, name="input"), + keras.layers.Reshape((28, 28, 1)), + keras.layers.Conv2D( filters=12, kernel_size=(3, 3), activation="relu", name="conv1" ), - tf.keras.layers.Conv2D( + keras.layers.Conv2D( filters=12, kernel_size=(3, 3), activation="relu", name="conv2" ), - tf.keras.layers.MaxPool2D(2, 2), - tf.keras.layers.Flatten(), - tf.keras.layers.Dense(10, name="output"), + keras.layers.MaxPool2D(2, 2), + keras.layers.Flatten(), + keras.layers.Dense(10, name="output"), ] ) |