aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/config.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/tensorflow/config.py')
-rw-r--r--src/mlia/nn/tensorflow/config.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index 44fbaef..c6fae1c 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -15,6 +15,7 @@ from typing import List
import numpy as np
import tensorflow as tf
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from mlia.core.context import Context
from mlia.nn.tensorflow.optimizations.quantization import dequantize
@@ -30,6 +31,7 @@ from mlia.nn.tensorflow.utils import is_saved_model
from mlia.nn.tensorflow.utils import is_tflite_model
from mlia.utils.logging import log_action
+
logger = logging.getLogger(__name__)
@@ -57,10 +59,10 @@ class KerasModel(ModelConfiguration):
Supports all models supported by Keras API: saved model, H5, HDF5
"""
- def get_keras_model(self) -> tf.keras.Model:
+ def get_keras_model(self) -> keras.Model:
"""Return associated Keras model."""
try:
- keras_model = tf.keras.models.load_model(self.model_path)
+ keras_model = keras.models.load_model(self.model_path)
except OSError as err:
raise RuntimeError(
f"Unable to load model content in {self.model_path}. "