diff options
Diffstat (limited to 'tests/test_nn_tensorflow_tflite_graph.py')
-rw-r--r-- | tests/test_nn_tensorflow_tflite_graph.py | 30 |
1 files changed, 29 insertions, 1 deletions
diff --git a/tests/test_nn_tensorflow_tflite_graph.py b/tests/test_nn_tensorflow_tflite_graph.py index cd1fad6..3512cdd 100644 --- a/tests/test_nn_tensorflow_tflite_graph.py +++ b/tests/test_nn_tensorflow_tflite_graph.py @@ -1,15 +1,22 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the tflite_graph module.""" import json from pathlib import Path +import pytest +import tensorflow as tf +from tensorflow.lite.python.schema_py_generated import ModelT + +from mlia.nn.tensorflow.tflite_graph import load_fb from mlia.nn.tensorflow.tflite_graph import Op from mlia.nn.tensorflow.tflite_graph import parse_subgraphs +from mlia.nn.tensorflow.tflite_graph import save_fb from mlia.nn.tensorflow.tflite_graph import TensorInfo from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION from mlia.nn.tensorflow.tflite_graph import TFL_OP from mlia.nn.tensorflow.tflite_graph import TFL_TYPE +from tests.utils.rewrite import models_are_equal def test_tensor_info() -> None: @@ -79,3 +86,24 @@ def test_parse_subgraphs(test_tflite_model: Path) -> None: assert TFL_OP[oper.type] in TFL_OP assert len(oper.inputs) > 0 assert len(oper.outputs) > 0 + + +def test_load_save(test_tflite_model: Path, tmp_path: Path) -> None: + """Test the load/save functions for TensorFlow Lite models.""" + with pytest.raises(FileNotFoundError): + load_fb("THIS_IS_NOT_A_REAL_FILE") + + model = load_fb(test_tflite_model) + assert isinstance(model, ModelT) + assert model.subgraphs + + output_file = tmp_path / "test.tflite" + assert not output_file.is_file() + save_fb(model, output_file) + assert output_file.is_file() + + model_copy = load_fb(str(output_file)) + assert models_are_equal(model, model_copy) + + # Double check that the TensorFlow Lite Interpreter can still load the file. + tf.lite.Interpreter(model_path=str(output_file)) |