aboutsummaryrefslogtreecommitdiff
path: root/tests/conftest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/conftest.py')
-rw-r--r--tests/conftest.py19
1 files changed, 10 insertions, 9 deletions
diff --git a/tests/conftest.py b/tests/conftest.py
index 53bfb0c..3d0b832 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -11,6 +11,7 @@ import _pytest
import numpy as np
import pytest
import tensorflow as tf
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from mlia.backend.vela.compiler import compile_model
from mlia.core.context import ExecutionContext
@@ -103,21 +104,21 @@ def fixture_vela_ini_file(
return test_vela_path / "vela.ini"
-def get_test_keras_model() -> tf.keras.Model:
+def get_test_keras_model() -> keras.Model:
"""Return test Keras model."""
- model = tf.keras.Sequential(
+ model = keras.Sequential(
[
- tf.keras.Input(shape=(28, 28, 1), batch_size=1, name="input"),
- tf.keras.layers.Reshape((28, 28, 1)),
- tf.keras.layers.Conv2D(
+ keras.Input(shape=(28, 28, 1), batch_size=1, name="input"),
+ keras.layers.Reshape((28, 28, 1)),
+ keras.layers.Conv2D(
filters=12, kernel_size=(3, 3), activation="relu", name="conv1"
),
- tf.keras.layers.Conv2D(
+ keras.layers.Conv2D(
filters=12, kernel_size=(3, 3), activation="relu", name="conv2"
),
- tf.keras.layers.MaxPool2D(2, 2),
- tf.keras.layers.Flatten(),
- tf.keras.layers.Dense(10, name="output"),
+ keras.layers.MaxPool2D(2, 2),
+ keras.layers.Flatten(),
+ keras.layers.Dense(10, name="output"),
]
)