author     Annie Tallund <annie.tallund@arm.com>  2024-01-19 14:18:22 +0100
committer  Annie Tallund <annie.tallund@arm.com>  2024-01-23 15:57:30 +0000
commit     be7bab89eb8ace6ab6a83687354beab156afb716 (patch)
tree       55c54c0e25c5e9349abb2eb441052a3a60c315c3
parent     732d8ef0bba9d23c731611dbed25e2e24a8a30d2 (diff)
download   mlia-be7bab89eb8ace6ab6a83687354beab156afb716.tar.gz
fix: Improve error handling for invalid file
If a file had the right extension, MLIA previously tried to load it even when
its content was invalid, resulting in confusing errors. This patch adds better
reporting for that scenario.

Resolves: MLIA-1051
Signed-off-by: Annie Tallund <annie.tallund@arm.com>
Change-Id: I3f1fd578906a73a58367428f78409866f5da7836
-rw-r--r--  src/mlia/nn/tensorflow/config.py    24
-rw-r--r--  tests/test_nn_tensorflow_config.py  13
2 files changed, 30 insertions, 7 deletions
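
For illustration (not part of the commit): a minimal sketch of the behaviour this
patch introduces. Constructing a TFLiteModel from a file that has the right
extension but invalid content now fails with a RuntimeError that names the file,
instead of a raw ValueError from the TFLite interpreter; the temporary file and
its content below are made up for the example.

from pathlib import Path
from tempfile import TemporaryDirectory

from mlia.nn.tensorflow.config import TFLiteModel

with TemporaryDirectory() as tmp:
    # Valid .tflite extension, invalid content.
    model_path = Path(tmp) / "broken.tflite"
    model_path.write_text("Not a TFLite file!")

    try:
        TFLiteModel(model_path=model_path)
    except RuntimeError as err:
        # "Unable to load model content in .../broken.tflite. Verify that
        # it's a valid model file."
        print(err)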
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index 0a17977..44fbaef 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
 # SPDX-License-Identifier: Apache-2.0
 """Model configuration."""
 from __future__ import annotations
@@ -59,7 +59,15 @@ class KerasModel(ModelConfiguration):
 
     def get_keras_model(self) -> tf.keras.Model:
         """Return associated Keras model."""
-        return tf.keras.models.load_model(self.model_path)
+        try:
+            keras_model = tf.keras.models.load_model(self.model_path)
+        except OSError as err:
+            raise RuntimeError(
+                f"Unable to load model content in {self.model_path}. "
+                f"Verify that it's a valid model file."
+            ) from err
+
+        return keras_model
 
     def convert_to_tflite(
         self, tflite_model_path: str | Path, quantized: bool = False
@@ -104,9 +112,15 @@ class TFLiteModel(ModelConfiguration):  # pylint: disable=abstract-method
         if not num_threads:
             num_threads = None
         if not batch_size:
-            self.interpreter = tf.lite.Interpreter(
-                model_path=self.model_path, num_threads=num_threads
-            )
+            try:
+                self.interpreter = tf.lite.Interpreter(
+                    model_path=self.model_path, num_threads=num_threads
+                )
+            except ValueError as err:
+                raise RuntimeError(
+                    f"Unable to load model content in {self.model_path}. "
+                    f"Verify that it's a valid model file."
+                ) from err
         else:  # if a batch size is specified, modify the TFLite model to use this size
             with tempfile.TemporaryDirectory() as tmp:
                 flatbuffer = load_fb(self.model_path)
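
The Keras path above is wrapped the same way: an OSError from
tf.keras.models.load_model() is re-raised as a RuntimeError that names the file.
A hedged sketch of how that surfaces to a caller, assuming KerasModel takes
model_path the same way TFLiteModel does (the commit only adds a test for the
TFLite case):

from pathlib import Path
from tempfile import TemporaryDirectory

from mlia.nn.tensorflow.config import KerasModel

with TemporaryDirectory() as tmp:
    # .h5 extension, invalid content (made up for the example).
    model_path = Path(tmp) / "broken.h5"
    model_path.write_text("Not a Keras model!")

    try:
        # load_model() typically raises OSError on a non-HDF5 file; the patch
        # converts that into the clearer RuntimeError caught below.
        KerasModel(model_path=model_path).get_keras_model()
    except RuntimeError as err:
        print(err)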
diff --git a/tests/test_nn_tensorflow_config.py b/tests/test_nn_tensorflow_config.py
index fff3857..c781756 100644
--- a/tests/test_nn_tensorflow_config.py
+++ b/tests/test_nn_tensorflow_config.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
 # SPDX-License-Identifier: Apache-2.0
 """Tests for config module."""
 from contextlib import ExitStack as does_not_raise
@@ -50,10 +50,19 @@ def test_convert_tf_to_tflite(tmp_path: Path, test_tf_model: Path) -> None:
     assert tflite_model_path.stat().st_size > 0
 
 
+def test_invalid_tflite_model(tmp_path: Path) -> None:
+    """Check that a RuntimeError is raised when a TFLite file is invalid."""
+    model_path = tmp_path / "test.tflite"
+    model_path.write_text("Not a TFLite file!")
+
+    with pytest.raises(RuntimeError):
+        TFLiteModel(model_path=model_path)
+
+
 @pytest.mark.parametrize(
     "model_path, expected_type, expected_error",
     [
-        ("test.tflite", TFLiteModel, pytest.raises(ValueError)),
+        ("test.tflite", TFLiteModel, pytest.raises(RuntimeError)),
         ("test.h5", KerasModel, does_not_raise()),
         ("test.hdf5", KerasModel, does_not_raise()),
         (