aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_tensorflow_utils.py
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2022-10-24 15:08:08 +0100
committerDmitrii Agibov <dmitrii.agibov@arm.com>2022-10-26 17:08:13 +0100
commit58a65fee574c00329cf92b387a6d2513dcbf6100 (patch)
tree47e3185f78b4298ab029785ddee68456e44cac10 /tests/test_nn_tensorflow_utils.py
parent9d34cb72d45a6d0a2ec1063ebf32536c1efdba75 (diff)
downloadmlia-58a65fee574c00329cf92b387a6d2513dcbf6100.tar.gz
MLIA-433 Add TensorFlow Lite compatibility check
- Add ability to intercept low level TensorFlow output - Produce advice for the models that could not be converted to the TensorFlow Lite format - Refactor utility functions for TensorFlow Lite conversion - Add TensorFlow Lite compatibility checker Change-Id: I47d120d2619ced7b143bc92c5184515b81c0220d
Diffstat (limited to 'tests/test_nn_tensorflow_utils.py')
-rw-r--r--tests/test_nn_tensorflow_utils.py44
1 files changed, 40 insertions, 4 deletions
diff --git a/tests/test_nn_tensorflow_utils.py b/tests/test_nn_tensorflow_utils.py
index 199c7db..5131171 100644
--- a/tests/test_nn_tensorflow_utils.py
+++ b/tests/test_nn_tensorflow_utils.py
@@ -3,6 +3,7 @@
"""Test for module utils/test_utils."""
from pathlib import Path
+import numpy as np
import pytest
import tensorflow as tf
@@ -10,16 +11,43 @@ from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import get_tf_tensor_shape
from mlia.nn.tensorflow.utils import is_keras_model
from mlia.nn.tensorflow.utils import is_tflite_model
+from mlia.nn.tensorflow.utils import representative_dataset
from mlia.nn.tensorflow.utils import save_keras_model
from mlia.nn.tensorflow.utils import save_tflite_model
-def test_convert_to_tflite(test_keras_model: Path) -> None:
+def test_generate_representative_dataset() -> None:
+ """Test function for generating representative dataset."""
+ dataset = representative_dataset([1, 3, 3], 5)
+ data = list(dataset())
+
+ assert len(data) == 5
+ for elem in data:
+ assert isinstance(elem, list)
+ assert len(elem) == 1
+
+ ndarray = elem[0]
+ assert ndarray.dtype == np.float32
+ assert isinstance(ndarray, np.ndarray)
+
+
+def test_generate_representative_dataset_wrong_shape() -> None:
+ """Test that only shape with batch size=1 is supported."""
+ with pytest.raises(Exception, match="Only the input batch_size=1 is supported!"):
+ representative_dataset([2, 3, 3], 5)
+
+
+def test_convert_saved_model_to_tflite(test_tf_model: Path) -> None:
+ """Test converting SavedModel to TensorFlow Lite."""
+ result = convert_to_tflite(test_tf_model.as_posix())
+ assert isinstance(result, bytes)
+
+
+def test_convert_keras_to_tflite(test_keras_model: Path) -> None:
"""Test converting Keras model to TensorFlow Lite."""
keras_model = tf.keras.models.load_model(str(test_keras_model))
- tflite_model = convert_to_tflite(keras_model)
-
- assert tflite_model
+ result = convert_to_tflite(keras_model)
+ assert isinstance(result, bytes)
def test_save_keras_model(tmp_path: Path, test_keras_model: Path) -> None:
@@ -46,6 +74,14 @@ def test_save_tflite_model(tmp_path: Path, test_keras_model: Path) -> None:
assert interpreter
+def test_convert_unknown_model_to_tflite() -> None:
+ """Test that unknown model type cannot be converted to TensorFlow Lite."""
+ with pytest.raises(
+ ValueError, match="Unable to create TensorFlow Lite converter for 123"
+ ):
+ convert_to_tflite(123)
+
+
@pytest.mark.parametrize(
"model_path, expected_result",
[