aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_tensorflow_tflite_graph.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_tensorflow_tflite_graph.py')
-rw-r--r--tests/test_nn_tensorflow_tflite_graph.py30
1 files changed, 29 insertions, 1 deletions
diff --git a/tests/test_nn_tensorflow_tflite_graph.py b/tests/test_nn_tensorflow_tflite_graph.py
index cd1fad6..3512cdd 100644
--- a/tests/test_nn_tensorflow_tflite_graph.py
+++ b/tests/test_nn_tensorflow_tflite_graph.py
@@ -1,15 +1,22 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the tflite_graph module."""
import json
from pathlib import Path
+import pytest
+import tensorflow as tf
+from tensorflow.lite.python.schema_py_generated import ModelT
+
+from mlia.nn.tensorflow.tflite_graph import load_fb
from mlia.nn.tensorflow.tflite_graph import Op
from mlia.nn.tensorflow.tflite_graph import parse_subgraphs
+from mlia.nn.tensorflow.tflite_graph import save_fb
from mlia.nn.tensorflow.tflite_graph import TensorInfo
from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION
from mlia.nn.tensorflow.tflite_graph import TFL_OP
from mlia.nn.tensorflow.tflite_graph import TFL_TYPE
+from tests.utils.rewrite import models_are_equal
def test_tensor_info() -> None:
@@ -79,3 +86,24 @@ def test_parse_subgraphs(test_tflite_model: Path) -> None:
assert TFL_OP[oper.type] in TFL_OP
assert len(oper.inputs) > 0
assert len(oper.outputs) > 0
+
+
+def test_load_save(test_tflite_model: Path, tmp_path: Path) -> None:
+ """Test the load/save functions for TensorFlow Lite models."""
+ with pytest.raises(FileNotFoundError):
+ load_fb("THIS_IS_NOT_A_REAL_FILE")
+
+ model = load_fb(test_tflite_model)
+ assert isinstance(model, ModelT)
+ assert model.subgraphs
+
+ output_file = tmp_path / "test.tflite"
+ assert not output_file.is_file()
+ save_fb(model, output_file)
+ assert output_file.is_file()
+
+ model_copy = load_fb(str(output_file))
+ assert models_are_equal(model, model_copy)
+
+ # Double check that the TensorFlow Lite Interpreter can still load the file.
+ tf.lite.Interpreter(model_path=str(output_file))