aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_tensorflow_config.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_tensorflow_config.py')
-rw-r--r--tests/test_nn_tensorflow_config.py13
1 files changed, 11 insertions, 2 deletions
diff --git a/tests/test_nn_tensorflow_config.py b/tests/test_nn_tensorflow_config.py
index fff3857..c781756 100644
--- a/tests/test_nn_tensorflow_config.py
+++ b/tests/test_nn_tensorflow_config.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for config module."""
from contextlib import ExitStack as does_not_raise
@@ -50,10 +50,19 @@ def test_convert_tf_to_tflite(tmp_path: Path, test_tf_model: Path) -> None:
assert tflite_model_path.stat().st_size > 0
+def test_invalid_tflite_model(tmp_path: Path) -> None:
+ """Check that a RuntimeError is raised when a TFLite file is invalid."""
+ model_path = tmp_path / "test.tflite"
+ model_path.write_text("Not a TFLite file!")
+
+ with pytest.raises(RuntimeError):
+ TFLiteModel(model_path=model_path)
+
+
@pytest.mark.parametrize(
"model_path, expected_type, expected_error",
[
- ("test.tflite", TFLiteModel, pytest.raises(ValueError)),
+ ("test.tflite", TFLiteModel, pytest.raises(RuntimeError)),
("test.h5", KerasModel, does_not_raise()),
("test.hdf5", KerasModel, does_not_raise()),
(