aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAnnie Tallund <annie.tallund@arm.com>2024-01-19 14:18:22 +0100
committerAnnie Tallund <annie.tallund@arm.com>2024-01-23 15:57:30 +0000
commitbe7bab89eb8ace6ab6a83687354beab156afb716 (patch)
tree55c54c0e25c5e9349abb2eb441052a3a60c315c3 /src
parent732d8ef0bba9d23c731611dbed25e2e24a8a30d2 (diff)
downloadmlia-be7bab89eb8ace6ab6a83687354beab156afb716.tar.gz
fix: Improve error handling for invalid file
If a file has the right extension, MLIA previously tried to load files with invalid content, resulting in confusing errors. This patch adds better reporting for that scenario Resolves: MLIA-1051 Signed-off-by: Annie Tallund <annie.tallund@arm.com> Change-Id: I3f1fd578906a73a58367428f78409866f5da7836
Diffstat (limited to 'src')
-rw-r--r--src/mlia/nn/tensorflow/config.py24
1 files changed, 19 insertions, 5 deletions
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index 0a17977..44fbaef 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Model configuration."""
from __future__ import annotations
@@ -59,7 +59,15 @@ class KerasModel(ModelConfiguration):
def get_keras_model(self) -> tf.keras.Model:
"""Return associated Keras model."""
- return tf.keras.models.load_model(self.model_path)
+ try:
+ keras_model = tf.keras.models.load_model(self.model_path)
+ except OSError as err:
+ raise RuntimeError(
+ f"Unable to load model content in {self.model_path}. "
+ f"Verify that it's a valid model file."
+ ) from err
+
+ return keras_model
def convert_to_tflite(
self, tflite_model_path: str | Path, quantized: bool = False
@@ -104,9 +112,15 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
if not num_threads:
num_threads = None
if not batch_size:
- self.interpreter = tf.lite.Interpreter(
- model_path=self.model_path, num_threads=num_threads
- )
+ try:
+ self.interpreter = tf.lite.Interpreter(
+ model_path=self.model_path, num_threads=num_threads
+ )
+ except ValueError as err:
+ raise RuntimeError(
+ f"Unable to load model content in {self.model_path}. "
+ f"Verify that it's a valid model file."
+ ) from err
else: # if a batch size is specified, modify the TFLite model to use this size
with tempfile.TemporaryDirectory() as tmp:
flatbuffer = load_fb(self.model_path)