Diffstat (limited to 'src/mlia/nn/tensorflow/config.py')
 src/mlia/nn/tensorflow/config.py | 24 +++++++++++++++++++-----
 1 file changed, 19 insertions(+), 5 deletions(-)
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index 0a17977..44fbaef 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Model configuration."""
from __future__ import annotations
@@ -59,7 +59,15 @@ class KerasModel(ModelConfiguration):
     def get_keras_model(self) -> tf.keras.Model:
         """Return associated Keras model."""
-        return tf.keras.models.load_model(self.model_path)
+        try:
+            keras_model = tf.keras.models.load_model(self.model_path)
+        except OSError as err:
+            raise RuntimeError(
+                f"Unable to load model content in {self.model_path}. "
+                f"Verify that it's a valid model file."
+            ) from err
+
+        return keras_model
 
     def convert_to_tflite(
         self, tflite_model_path: str | Path, quantized: bool = False
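
How a caller might observe the new behaviour of get_keras_model — a minimal sketch, assuming KerasModel is constructed from a model path (as ModelConfiguration suggests) and using a hypothetical file name:

    from mlia.nn.tensorflow.config import KerasModel

    # Hypothetical path; the constructor argument is assumed here, not shown in this diff.
    model = KerasModel("models/classifier.h5")

    try:
        keras_model = model.get_keras_model()
    except RuntimeError as err:
        # With this change, a corrupt or non-Keras file is reported as a
        # RuntimeError; the original OSError is kept as the chained cause.
        print(err)
        print(err.__cause__)
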
@@ -104,9 +112,15 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
         if not num_threads:
             num_threads = None
         if not batch_size:
-            self.interpreter = tf.lite.Interpreter(
-                model_path=self.model_path, num_threads=num_threads
-            )
+            try:
+                self.interpreter = tf.lite.Interpreter(
+                    model_path=self.model_path, num_threads=num_threads
+                )
+            except ValueError as err:
+                raise RuntimeError(
+                    f"Unable to load model content in {self.model_path}. "
+                    f"Verify that it's a valid model file."
+                ) from err
         else:  # if a batch size is specified, modify the TFLite model to use this size
             with tempfile.TemporaryDirectory() as tmp:
                 flatbuffer = load_fb(self.model_path)
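
Both hunks apply the same error-handling pattern: catch the low-level exception and re-raise a RuntimeError with "from err", so the root cause stays attached to the traceback. A self-contained sketch of that pattern (function name and message are illustrative, not part of this change):

    def load_content(path: str) -> bytes:
        """Load raw file content, wrapping low-level errors in a RuntimeError."""
        try:
            with open(path, "rb") as file:
                return file.read()
        except OSError as err:
            # "from err" chains the original exception, so both the high-level
            # message and the underlying cause appear in the traceback.
            raise RuntimeError(f"Unable to load content in {path}.") from err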