diff options
author | Annie Tallund <annie.tallund@arm.com> | 2024-01-19 14:18:22 +0100 |
---|---|---|
committer | Annie Tallund <annie.tallund@arm.com> | 2024-01-23 15:57:30 +0000 |
commit | be7bab89eb8ace6ab6a83687354beab156afb716 (patch) | |
tree | 55c54c0e25c5e9349abb2eb441052a3a60c315c3 /tests/test_nn_tensorflow_config.py | |
parent | 732d8ef0bba9d23c731611dbed25e2e24a8a30d2 (diff) | |
download | mlia-be7bab89eb8ace6ab6a83687354beab156afb716.tar.gz |
fix: Improve error handling for invalid file
If a file has the right extension, MLIA previously tried
to load files with invalid content, resulting in confusing
errors. This patch adds better reporting for that scenario
Resolves: MLIA-1051
Signed-off-by: Annie Tallund <annie.tallund@arm.com>
Change-Id: I3f1fd578906a73a58367428f78409866f5da7836
Diffstat (limited to 'tests/test_nn_tensorflow_config.py')
-rw-r--r-- | tests/test_nn_tensorflow_config.py | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/tests/test_nn_tensorflow_config.py b/tests/test_nn_tensorflow_config.py index fff3857..c781756 100644 --- a/tests/test_nn_tensorflow_config.py +++ b/tests/test_nn_tensorflow_config.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for config module.""" from contextlib import ExitStack as does_not_raise @@ -50,10 +50,19 @@ def test_convert_tf_to_tflite(tmp_path: Path, test_tf_model: Path) -> None: assert tflite_model_path.stat().st_size > 0 +def test_invalid_tflite_model(tmp_path: Path) -> None: + """Check that a RuntimeError is raised when a TFLite file is invalid.""" + model_path = tmp_path / "test.tflite" + model_path.write_text("Not a TFLite file!") + + with pytest.raises(RuntimeError): + TFLiteModel(model_path=model_path) + + @pytest.mark.parametrize( "model_path, expected_type, expected_error", [ - ("test.tflite", TFLiteModel, pytest.raises(ValueError)), + ("test.tflite", TFLiteModel, pytest.raises(RuntimeError)), ("test.h5", KerasModel, does_not_raise()), ("test.hdf5", KerasModel, does_not_raise()), ( |