From be7bab89eb8ace6ab6a83687354beab156afb716 Mon Sep 17 00:00:00 2001 From: Annie Tallund Date: Fri, 19 Jan 2024 14:18:22 +0100 Subject: fix: Improve error handling for invalid file If a file has the right extension, MLIA previously tried to load files with invalid content, resulting in confusing errors. This patch adds better reporting for that scenario Resolves: MLIA-1051 Signed-off-by: Annie Tallund Change-Id: I3f1fd578906a73a58367428f78409866f5da7836 --- src/mlia/nn/tensorflow/config.py | 24 +++++++++++++++++++----- tests/test_nn_tensorflow_config.py | 13 +++++++++++-- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py index 0a17977..44fbaef 100644 --- a/src/mlia/nn/tensorflow/config.py +++ b/src/mlia/nn/tensorflow/config.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Model configuration.""" from __future__ import annotations @@ -59,7 +59,15 @@ class KerasModel(ModelConfiguration): def get_keras_model(self) -> tf.keras.Model: """Return associated Keras model.""" - return tf.keras.models.load_model(self.model_path) + try: + keras_model = tf.keras.models.load_model(self.model_path) + except OSError as err: + raise RuntimeError( + f"Unable to load model content in {self.model_path}. " + f"Verify that it's a valid model file." + ) from err + + return keras_model def convert_to_tflite( self, tflite_model_path: str | Path, quantized: bool = False @@ -104,9 +112,15 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method if not num_threads: num_threads = None if not batch_size: - self.interpreter = tf.lite.Interpreter( - model_path=self.model_path, num_threads=num_threads - ) + try: + self.interpreter = tf.lite.Interpreter( + model_path=self.model_path, num_threads=num_threads + ) + except ValueError as err: + raise RuntimeError( + f"Unable to load model content in {self.model_path}. " + f"Verify that it's a valid model file." + ) from err else: # if a batch size is specified, modify the TFLite model to use this size with tempfile.TemporaryDirectory() as tmp: flatbuffer = load_fb(self.model_path) diff --git a/tests/test_nn_tensorflow_config.py b/tests/test_nn_tensorflow_config.py index fff3857..c781756 100644 --- a/tests/test_nn_tensorflow_config.py +++ b/tests/test_nn_tensorflow_config.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for config module.""" from contextlib import ExitStack as does_not_raise @@ -50,10 +50,19 @@ def test_convert_tf_to_tflite(tmp_path: Path, test_tf_model: Path) -> None: assert tflite_model_path.stat().st_size > 0 +def test_invalid_tflite_model(tmp_path: Path) -> None: + """Check that a RuntimeError is raised when a TFLite file is invalid.""" + model_path = tmp_path / "test.tflite" + model_path.write_text("Not a TFLite file!") + + with pytest.raises(RuntimeError): + TFLiteModel(model_path=model_path) + + @pytest.mark.parametrize( "model_path, expected_type, expected_error", [ - ("test.tflite", TFLiteModel, pytest.raises(ValueError)), + ("test.tflite", TFLiteModel, pytest.raises(RuntimeError)), ("test.h5", KerasModel, does_not_raise()), ("test.hdf5", KerasModel, does_not_raise()), ( -- cgit v1.2.1