diff options
Diffstat (limited to 'src/mlia/nn/tensorflow/config.py')
-rw-r--r-- | src/mlia/nn/tensorflow/config.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py index 44fbaef..c6fae1c 100644 --- a/src/mlia/nn/tensorflow/config.py +++ b/src/mlia/nn/tensorflow/config.py @@ -15,6 +15,7 @@ from typing import List import numpy as np import tensorflow as tf +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from mlia.core.context import Context from mlia.nn.tensorflow.optimizations.quantization import dequantize @@ -30,6 +31,7 @@ from mlia.nn.tensorflow.utils import is_saved_model from mlia.nn.tensorflow.utils import is_tflite_model from mlia.utils.logging import log_action + logger = logging.getLogger(__name__) @@ -57,10 +59,10 @@ class KerasModel(ModelConfiguration): Supports all models supported by Keras API: saved model, H5, HDF5 """ - def get_keras_model(self) -> tf.keras.Model: + def get_keras_model(self) -> keras.Model: """Return associated Keras model.""" try: - keras_model = tf.keras.models.load_model(self.model_path) + keras_model = keras.models.load_model(self.model_path) except OSError as err: raise RuntimeError( f"Unable to load model content in {self.model_path}. " |