aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_tensorflow_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_tensorflow_utils.py')
-rw-r--r--tests/test_nn_tensorflow_utils.py44
1 files changed, 2 insertions, 42 deletions
diff --git a/tests/test_nn_tensorflow_utils.py b/tests/test_nn_tensorflow_utils.py
index dab8b4e..e356a49 100644
--- a/tests/test_nn_tensorflow_utils.py
+++ b/tests/test_nn_tensorflow_utils.py
@@ -8,43 +8,13 @@ import numpy as np
import pytest
import tensorflow as tf
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
from mlia.nn.tensorflow.utils import check_tflite_datatypes
-from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import get_tf_tensor_shape
from mlia.nn.tensorflow.utils import get_tflite_model_type_map
from mlia.nn.tensorflow.utils import is_keras_model
from mlia.nn.tensorflow.utils import is_tflite_model
-from mlia.nn.tensorflow.utils import representative_dataset
from mlia.nn.tensorflow.utils import save_keras_model
-from mlia.nn.tensorflow.utils import save_tflite_model
-
-
-def test_generate_representative_dataset() -> None:
- """Test function for generating representative dataset."""
- dataset = representative_dataset([1, 3, 3], 5)
- data = list(dataset())
-
- assert len(data) == 5
- for elem in data:
- assert isinstance(elem, list)
- assert len(elem) == 1
-
- ndarray = elem[0]
- assert ndarray.dtype == np.float32
- assert isinstance(ndarray, np.ndarray)
-
-
-def test_convert_saved_model_to_tflite(test_tf_model: Path) -> None:
- """Test converting SavedModel to TensorFlow Lite."""
- result = convert_to_tflite(test_tf_model.as_posix())
- assert isinstance(result, bytes)
-
-
-def test_convert_keras_to_tflite(test_keras_model: Path) -> None:
- """Test converting Keras model to TensorFlow Lite."""
- keras_model = tf.keras.models.load_model(str(test_keras_model))
- result = convert_to_tflite(keras_model)
- assert isinstance(result, bytes)
def test_save_keras_model(tmp_path: Path, test_keras_model: Path) -> None:
@@ -62,23 +32,13 @@ def test_save_tflite_model(tmp_path: Path, test_keras_model: Path) -> None:
"""Test saving TensorFlow Lite model."""
keras_model = tf.keras.models.load_model(str(test_keras_model))
- tflite_model = convert_to_tflite(keras_model)
-
temp_file = tmp_path / "test_model_saving.tflite"
- save_tflite_model(tflite_model, temp_file)
+ convert_to_tflite(keras_model, output_path=temp_file)
interpreter = tf.lite.Interpreter(model_path=str(temp_file))
assert interpreter
-def test_convert_unknown_model_to_tflite() -> None:
- """Test that unknown model type cannot be converted to TensorFlow Lite."""
- with pytest.raises(
- ValueError, match="Unable to create TensorFlow Lite converter for 123"
- ):
- convert_to_tflite(123)
-
-
@pytest.mark.parametrize(
"model_path, expected_result",
[