diff options
Diffstat (limited to 'tests/test_nn_tensorflow_utils.py')
-rw-r--r-- | tests/test_nn_tensorflow_utils.py | 44 |
1 files changed, 2 insertions, 42 deletions
diff --git a/tests/test_nn_tensorflow_utils.py b/tests/test_nn_tensorflow_utils.py index dab8b4e..e356a49 100644 --- a/tests/test_nn_tensorflow_utils.py +++ b/tests/test_nn_tensorflow_utils.py @@ -8,43 +8,13 @@ import numpy as np import pytest import tensorflow as tf +from mlia.nn.tensorflow.tflite_convert import convert_to_tflite from mlia.nn.tensorflow.utils import check_tflite_datatypes -from mlia.nn.tensorflow.utils import convert_to_tflite from mlia.nn.tensorflow.utils import get_tf_tensor_shape from mlia.nn.tensorflow.utils import get_tflite_model_type_map from mlia.nn.tensorflow.utils import is_keras_model from mlia.nn.tensorflow.utils import is_tflite_model -from mlia.nn.tensorflow.utils import representative_dataset from mlia.nn.tensorflow.utils import save_keras_model -from mlia.nn.tensorflow.utils import save_tflite_model - - -def test_generate_representative_dataset() -> None: - """Test function for generating representative dataset.""" - dataset = representative_dataset([1, 3, 3], 5) - data = list(dataset()) - - assert len(data) == 5 - for elem in data: - assert isinstance(elem, list) - assert len(elem) == 1 - - ndarray = elem[0] - assert ndarray.dtype == np.float32 - assert isinstance(ndarray, np.ndarray) - - -def test_convert_saved_model_to_tflite(test_tf_model: Path) -> None: - """Test converting SavedModel to TensorFlow Lite.""" - result = convert_to_tflite(test_tf_model.as_posix()) - assert isinstance(result, bytes) - - -def test_convert_keras_to_tflite(test_keras_model: Path) -> None: - """Test converting Keras model to TensorFlow Lite.""" - keras_model = tf.keras.models.load_model(str(test_keras_model)) - result = convert_to_tflite(keras_model) - assert isinstance(result, bytes) def test_save_keras_model(tmp_path: Path, test_keras_model: Path) -> None: @@ -62,23 +32,13 @@ def test_save_tflite_model(tmp_path: Path, test_keras_model: Path) -> None: """Test saving TensorFlow Lite model.""" keras_model = tf.keras.models.load_model(str(test_keras_model)) - tflite_model = convert_to_tflite(keras_model) - temp_file = tmp_path / "test_model_saving.tflite" - save_tflite_model(tflite_model, temp_file) + convert_to_tflite(keras_model, output_path=temp_file) interpreter = tf.lite.Interpreter(model_path=str(temp_file)) assert interpreter -def test_convert_unknown_model_to_tflite() -> None: - """Test that unknown model type cannot be converted to TensorFlow Lite.""" - with pytest.raises( - ValueError, match="Unable to create TensorFlow Lite converter for 123" - ): - convert_to_tflite(123) - - @pytest.mark.parametrize( "model_path, expected_result", [ |