aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/conftest.py39
-rw-r--r--tests/test_nn_rewrite_core_graph_edit_join.py6
-rw-r--r--tests/test_nn_rewrite_core_graph_edit_record.py4
-rw-r--r--tests/test_nn_rewrite_core_rewrite.py6
-rw-r--r--tests/test_nn_rewrite_core_train.py76
-rw-r--r--tests/test_nn_rewrite_core_utils.py33
-rw-r--r--tests/test_nn_rewrite_core_utils_numpy_tfrecord.py25
-rw-r--r--tests/test_nn_tensorflow_config.py40
-rw-r--r--tests/test_nn_tensorflow_tflite_graph.py30
-rw-r--r--tests/utils/rewrite.py18
10 files changed, 189 insertions, 88 deletions
diff --git a/tests/conftest.py b/tests/conftest.py
index c42b8cb..bb2423f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,6 +5,7 @@ import shutil
from pathlib import Path
from typing import Callable
from typing import Generator
+from unittest.mock import MagicMock
import numpy as np
import pytest
@@ -17,6 +18,7 @@ from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import save_keras_model
from mlia.nn.tensorflow.utils import save_tflite_model
from mlia.target.ethos_u.config import EthosUConfiguration
+from tests.utils.rewrite import TestTrainingParameters
@pytest.fixture(scope="session", name="test_resources_path")
@@ -168,16 +170,12 @@ def _write_tfrecord(
writer.write({input_name: data_generator()})
-@pytest.fixture(scope="session", name="test_tfrecord")
-def fixture_test_tfrecord(
- tmp_path_factory: pytest.TempPathFactory,
+def create_tfrecord(
+ tmp_path_factory: pytest.TempPathFactory, random_data: Callable
) -> Generator[Path, None, None]:
"""Create a tfrecord with random data matching fixture 'test_tflite_model'."""
tmp_path = tmp_path_factory.mktemp("tfrecords")
- tfrecord_file = tmp_path / "test_int8.tfrecord"
-
- def random_data() -> np.ndarray:
- return np.random.randint(low=-127, high=128, size=(1, 28, 28, 1), dtype=np.int8)
+ tfrecord_file = tmp_path / "test.tfrecord"
_write_tfrecord(tfrecord_file, random_data)
@@ -186,19 +184,36 @@ def fixture_test_tfrecord(
shutil.rmtree(tmp_path)
+@pytest.fixture(scope="session", name="test_tfrecord")
+def fixture_test_tfrecord(
+ tmp_path_factory: pytest.TempPathFactory,
+) -> Generator[Path, None, None]:
+ """Create a tfrecord with random data matching fixture 'test_tflite_model'."""
+
+ def random_data() -> np.ndarray:
+ return np.random.randint(low=-127, high=128, size=(1, 28, 28, 1), dtype=np.int8)
+
+ yield from create_tfrecord(tmp_path_factory, random_data)
+
+
@pytest.fixture(scope="session", name="test_tfrecord_fp32")
def fixture_test_tfrecord_fp32(
tmp_path_factory: pytest.TempPathFactory,
) -> Generator[Path, None, None]:
"""Create tfrecord with random data matching fixture 'test_tflite_model_fp32'."""
- tmp_path = tmp_path_factory.mktemp("tfrecords")
- tfrecord_file = tmp_path / "test_fp32.tfrecord"
def random_data() -> np.ndarray:
return np.random.rand(1, 28, 28, 1).astype(np.float32)
- _write_tfrecord(tfrecord_file, random_data)
+ yield from create_tfrecord(tmp_path_factory, random_data)
- yield tfrecord_file
- shutil.rmtree(tmp_path)
+@pytest.fixture(scope="session", autouse=True)
+def set_training_steps() -> Generator[None, None, None]:
+ """Speed up tests by using TestTrainingParameters."""
+ with pytest.MonkeyPatch.context() as monkeypatch:
+ monkeypatch.setattr(
+ "mlia.nn.select._get_rewrite_train_params",
+ MagicMock(return_value=TestTrainingParameters()),
+ )
+ yield
diff --git a/tests/test_nn_rewrite_core_graph_edit_join.py b/tests/test_nn_rewrite_core_graph_edit_join.py
index cb3e4e2..0cb121e 100644
--- a/tests/test_nn_rewrite_core_graph_edit_join.py
+++ b/tests/test_nn_rewrite_core_graph_edit_join.py
@@ -10,7 +10,7 @@ import pytest
from mlia.nn.rewrite.core.graph_edit.cut import cut_model
from mlia.nn.rewrite.core.graph_edit.join import append_relabel
from mlia.nn.rewrite.core.graph_edit.join import join_models
-from mlia.nn.rewrite.core.utils.utils import load
+from mlia.nn.tensorflow.tflite_graph import load_fb
from tests.utils.rewrite import models_are_equal
@@ -49,8 +49,8 @@ def test_join_model(test_tflite_model: Path, tmp_path: Path) -> None:
)
assert joined_file.is_file()
- orig_model = load(str(test_tflite_model))
- joined_model = load(str(joined_file))
+ orig_model = load_fb(str(test_tflite_model))
+ joined_model = load_fb(str(joined_file))
assert models_are_equal(orig_model, joined_model)
diff --git a/tests/test_nn_rewrite_core_graph_edit_record.py b/tests/test_nn_rewrite_core_graph_edit_record.py
index cd728af..41b9c50 100644
--- a/tests/test_nn_rewrite_core_graph_edit_record.py
+++ b/tests/test_nn_rewrite_core_graph_edit_record.py
@@ -3,15 +3,13 @@
"""Tests for module mlia.nn.rewrite.graph_edit.record."""
from pathlib import Path
-import pytest
import tensorflow as tf
from mlia.nn.rewrite.core.graph_edit.record import record_model
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
-@pytest.mark.parametrize("batch_size", (None, 1, 2))
-def test_record_model(
+def check_record_model(
test_tflite_model: Path,
tmp_path: Path,
test_tfrecord: Path,
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index b98971e..2542db2 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -12,6 +12,7 @@ import pytest
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import Rewriter
from mlia.nn.tensorflow.config import TFLiteModel
+from tests.utils.rewrite import TestTrainingParameters
@pytest.mark.parametrize(
@@ -32,12 +33,14 @@ def test_rewrite_configuration(
None,
)
+ assert config_obj.optimization_target in str(config_obj)
+
rewriter_obj = Rewriter(test_tflite_model_fp32, config_obj)
assert rewriter_obj.optimizer_configuration.optimization_target == rewrite_name
assert isinstance(rewriter_obj, Rewriter)
-def test_rewriter(
+def test_rewriting_optimizer(
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,
) -> None:
@@ -46,6 +49,7 @@ def test_rewriter(
"fully_connected",
["sequential/flatten/Reshape", "StatefulPartitionedCall:0"],
test_tfrecord_fp32,
+ train_params=TestTrainingParameters(),
)
test_obj = Rewriter(test_tflite_model_fp32, config_obj)
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index 3c2ef3e..4493671 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -4,6 +4,7 @@
# pylint: disable=too-many-arguments
from __future__ import annotations
+from contextlib import ExitStack as does_not_raise
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any
@@ -12,10 +13,13 @@ import numpy as np
import pytest
import tensorflow as tf
-from mlia.nn.rewrite.core.train import augmentation_presets
+from mlia.nn.rewrite.core.train import augment_fn_twins
+from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS
from mlia.nn.rewrite.core.train import LearningRateSchedule
from mlia.nn.rewrite.core.train import mixup
from mlia.nn.rewrite.core.train import train
+from mlia.nn.rewrite.core.train import TrainingParameters
+from tests.utils.rewrite import TestTrainingParameters
def replace_fully_connected_with_conv(
@@ -41,15 +45,8 @@ def replace_fully_connected_with_conv(
def check_train(
tflite_model: Path,
tfrecord: Path,
- batch_size: int = 1,
- verbose: bool = False,
- show_progress: bool = False,
- augmentation_preset: tuple[float | None, float | None] = augmentation_presets[
- "none"
- ],
- lr_schedule: LearningRateSchedule = "cosine",
+ train_params: TrainingParameters = TestTrainingParameters(),
use_unmodified_model: bool = False,
- num_procs: int = 1,
) -> None:
"""Test the train() function."""
with TemporaryDirectory() as tmp_dir:
@@ -62,14 +59,7 @@ def check_train(
replace_fn=replace_fully_connected_with_conv,
input_tensors=["sequential/flatten/Reshape"],
output_tensors=["StatefulPartitionedCall:0"],
- augment=augmentation_preset,
- steps=32,
- learning_rate=1e-3,
- batch_size=batch_size,
- verbose=verbose,
- show_progress=show_progress,
- learning_rate_schedule=lr_schedule,
- num_procs=num_procs,
+ train_params=train_params,
)
assert len(result) == 2
assert all(res >= 0.0 for res in result), f"Results out of bound: {result}"
@@ -79,7 +69,6 @@ def check_train(
@pytest.mark.parametrize(
(
"batch_size",
- "verbose",
"show_progress",
"augmentation_preset",
"lr_schedule",
@@ -87,14 +76,13 @@ def check_train(
"num_procs",
),
(
- (1, False, False, augmentation_presets["none"], "cosine", False, 2),
- (32, True, True, augmentation_presets["gaussian"], "late", True, 1),
- (2, False, False, augmentation_presets["mixup"], "constant", True, 0),
+ (1, False, AUGMENTATION_PRESETS["none"], "cosine", False, 2),
+ (32, True, AUGMENTATION_PRESETS["gaussian"], "late", True, 1),
+ (2, False, AUGMENTATION_PRESETS["mixup"], "constant", True, 0),
(
1,
False,
- False,
- augmentation_presets["mix_gaussian_large"],
+ AUGMENTATION_PRESETS["mix_gaussian_large"],
"cosine",
False,
2,
@@ -105,7 +93,6 @@ def test_train(
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,
batch_size: int,
- verbose: bool,
show_progress: bool,
augmentation_preset: tuple[float | None, float | None],
lr_schedule: LearningRateSchedule,
@@ -116,13 +103,14 @@ def test_train(
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- batch_size=batch_size,
- verbose=verbose,
- show_progress=show_progress,
- augmentation_preset=augmentation_preset,
- lr_schedule=lr_schedule,
+ train_params=TestTrainingParameters(
+ batch_size=batch_size,
+ show_progress=show_progress,
+ augmentations=augmentation_preset,
+ learning_rate_schedule=lr_schedule,
+ num_procs=num_procs,
+ ),
use_unmodified_model=use_unmodified_model,
- num_procs=num_procs,
)
@@ -131,11 +119,13 @@ def test_train_invalid_schedule(
test_tfrecord_fp32: Path,
) -> None:
"""Test the train() function with an invalid schedule."""
- with pytest.raises(ValueError):
+ with pytest.raises(AssertionError):
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- lr_schedule="unknown_schedule", # type: ignore
+ train_params=TestTrainingParameters(
+ learning_rate_schedule="unknown_schedule",
+ ),
)
@@ -144,11 +134,13 @@ def test_train_invalid_augmentation(
test_tfrecord_fp32: Path,
) -> None:
"""Test the train() function with an invalid augmentation."""
- with pytest.raises(ValueError):
+ with pytest.raises(AssertionError):
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- augmentation_preset=(1.0, 2.0, 3.0), # type: ignore
+ train_params=TestTrainingParameters(
+ augmentations=(1.0, 2.0, 3.0),
+ ),
)
@@ -159,3 +151,19 @@ def test_mixup() -> None:
assert src.shape == dst.shape
assert np.all(dst >= 0.0)
assert np.all(dst <= 3.0)
+
+
+@pytest.mark.parametrize(
+ "augmentations, expected_error",
+ [
+ (AUGMENTATION_PRESETS["none"], does_not_raise()),
+ (AUGMENTATION_PRESETS["mix_gaussian_large"], does_not_raise()),
+ ((None,) * 3, pytest.raises(AssertionError)),
+ ],
+)
+def test_augment_fn_twins(augmentations: tuple, expected_error: Any) -> None:
+ """Test function augment_fn()."""
+ dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2, 3], "b": [4, 5, 6]})
+ with expected_error:
+ fn_twins = augment_fn_twins(dataset, augmentations) # type: ignore
+ assert len(fn_twins) == 2
diff --git a/tests/test_nn_rewrite_core_utils.py b/tests/test_nn_rewrite_core_utils.py
deleted file mode 100644
index d806a7b..0000000
--- a/tests/test_nn_rewrite_core_utils.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Tests for module mlia.nn.rewrite.utils."""
-from pathlib import Path
-
-import pytest
-import tensorflow as tf
-from tensorflow.lite.python.schema_py_generated import ModelT
-
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
-from tests.utils.rewrite import models_are_equal
-
-
-def test_load_save(test_tflite_model: Path, tmp_path: Path) -> None:
- """Test the load/save functions for TensorFlow Lite models."""
- with pytest.raises(FileNotFoundError):
- load("THIS_IS_NOT_A_REAL_FILE")
-
- model = load(test_tflite_model)
- assert isinstance(model, ModelT)
- assert model.subgraphs
-
- output_file = tmp_path / "test.tflite"
- assert not output_file.is_file()
- save(model, output_file)
- assert output_file.is_file()
-
- model_copy = load(str(output_file))
- assert models_are_equal(model, model_copy)
-
- # Double check that the TensorFlow Lite Interpreter can still load the file.
- tf.lite.Interpreter(model_path=str(output_file))
diff --git a/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py b/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py
index 7fc8048..d030350 100644
--- a/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py
+++ b/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py
@@ -5,6 +5,10 @@ from __future__ import annotations
from pathlib import Path
+import pytest
+import tensorflow as tf
+
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import make_decode_fn
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import sample_tfrec
@@ -16,3 +20,24 @@ def test_sample_tfrec(test_tfrecord: Path, tmp_path: Path) -> None:
sample_tfrec(input_file=str(test_tfrecord), k=1, output_file=str(output_file))
assert output_file.is_file()
assert numpytf_count(str(output_file)) == 1
+
+
+def test_make_decode_fn(test_tfrecord: Path) -> None:
+ """Test function make_decode_fn()."""
+ decode = make_decode_fn(str(test_tfrecord))
+ dataset = tf.data.TFRecordDataset(str(test_tfrecord))
+ features = decode(next(iter(dataset)))
+ assert isinstance(features, dict)
+ assert len(features) == 1
+ key, val = next(iter(features.items()))
+ assert isinstance(key, str)
+ assert isinstance(val, tf.Tensor)
+ assert val.dtype == tf.int8
+
+ with pytest.raises(FileNotFoundError):
+ make_decode_fn(str(test_tfrecord) + "_")
+
+
+def test_numpytf_count(test_tfrecord: Path) -> None:
+ """Test function numpytf_count()."""
+ assert numpytf_count(test_tfrecord) == 3
diff --git a/tests/test_nn_tensorflow_config.py b/tests/test_nn_tensorflow_config.py
index 656619d..48aec0a 100644
--- a/tests/test_nn_tensorflow_config.py
+++ b/tests/test_nn_tensorflow_config.py
@@ -4,13 +4,28 @@
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
+from typing import Generator
+import numpy as np
import pytest
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
from mlia.nn.tensorflow.config import get_model
from mlia.nn.tensorflow.config import KerasModel
+from mlia.nn.tensorflow.config import ModelConfiguration
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.nn.tensorflow.config import TfModel
+from tests.conftest import create_tfrecord
+
+
+def test_model_configuration(test_keras_model: Path) -> None:
+ """Test ModelConfiguration class."""
+ model = ModelConfiguration(model_path=test_keras_model)
+ assert test_keras_model.match(model.model_path)
+ with pytest.raises(NotImplementedError):
+ model.convert_to_keras("keras_model.h5")
+ with pytest.raises(NotImplementedError):
+ model.convert_to_tflite("model.tflite")
def test_convert_keras_to_tflite(tmp_path: Path, test_keras_model: Path) -> None:
@@ -38,7 +53,7 @@ def test_convert_tf_to_tflite(tmp_path: Path, test_tf_model: Path) -> None:
@pytest.mark.parametrize(
"model_path, expected_type, expected_error",
[
- ("test.tflite", TFLiteModel, does_not_raise()),
+ ("test.tflite", TFLiteModel, pytest.raises(ValueError)),
("test.h5", KerasModel, does_not_raise()),
("test.hdf5", KerasModel, does_not_raise()),
(
@@ -73,3 +88,26 @@ def test_get_model_dir(
"""Test TensorFlow Lite model type."""
model = get_model(str(test_models_path / model_path))
assert isinstance(model, expected_type)
+
+
+@pytest.fixture(scope="session", name="test_tfrecord_fp32_batch_3")
+def fixture_test_tfrecord_fp32_batch_3(
+ tmp_path_factory: pytest.TempPathFactory,
+) -> Generator[Path, None, None]:
+ """Create tfrecord (same as test_tfrecord_fp32) but with batch size 3."""
+
+ def random_data() -> np.ndarray:
+ return np.random.rand(3, 28, 28, 1).astype(np.float32)
+
+ yield from create_tfrecord(tmp_path_factory, random_data)
+
+
+def test_tflite_model_call(
+ test_tflite_model_fp32: Path, test_tfrecord_fp32_batch_3: Path
+) -> None:
+ """Test inference function of class TFLiteModel."""
+ model = TFLiteModel(test_tflite_model_fp32, batch_size=2)
+ data = numpytf_read(test_tfrecord_fp32_batch_3)
+ for named_input in data.as_numpy_iterator():
+ res = model(named_input)
+ assert res
diff --git a/tests/test_nn_tensorflow_tflite_graph.py b/tests/test_nn_tensorflow_tflite_graph.py
index cd1fad6..3512cdd 100644
--- a/tests/test_nn_tensorflow_tflite_graph.py
+++ b/tests/test_nn_tensorflow_tflite_graph.py
@@ -1,15 +1,22 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the tflite_graph module."""
import json
from pathlib import Path
+import pytest
+import tensorflow as tf
+from tensorflow.lite.python.schema_py_generated import ModelT
+
+from mlia.nn.tensorflow.tflite_graph import load_fb
from mlia.nn.tensorflow.tflite_graph import Op
from mlia.nn.tensorflow.tflite_graph import parse_subgraphs
+from mlia.nn.tensorflow.tflite_graph import save_fb
from mlia.nn.tensorflow.tflite_graph import TensorInfo
from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION
from mlia.nn.tensorflow.tflite_graph import TFL_OP
from mlia.nn.tensorflow.tflite_graph import TFL_TYPE
+from tests.utils.rewrite import models_are_equal
def test_tensor_info() -> None:
@@ -79,3 +86,24 @@ def test_parse_subgraphs(test_tflite_model: Path) -> None:
assert TFL_OP[oper.type] in TFL_OP
assert len(oper.inputs) > 0
assert len(oper.outputs) > 0
+
+
+def test_load_save(test_tflite_model: Path, tmp_path: Path) -> None:
+ """Test the load/save functions for TensorFlow Lite models."""
+ with pytest.raises(FileNotFoundError):
+ load_fb("THIS_IS_NOT_A_REAL_FILE")
+
+ model = load_fb(test_tflite_model)
+ assert isinstance(model, ModelT)
+ assert model.subgraphs
+
+ output_file = tmp_path / "test.tflite"
+ assert not output_file.is_file()
+ save_fb(model, output_file)
+ assert output_file.is_file()
+
+ model_copy = load_fb(str(output_file))
+ assert models_are_equal(model, model_copy)
+
+ # Double check that the TensorFlow Lite Interpreter can still load the file.
+ tf.lite.Interpreter(model_path=str(output_file))
diff --git a/tests/utils/rewrite.py b/tests/utils/rewrite.py
index 4264b4b..739bb11 100644
--- a/tests/utils/rewrite.py
+++ b/tests/utils/rewrite.py
@@ -3,8 +3,12 @@
"""Common test utils for the rewrite tests."""
from __future__ import annotations
+from typing import Any
+
from tensorflow.lite.python.schema_py_generated import ModelT
+from mlia.nn.rewrite.core.train import TrainingParameters
+
def models_are_equal(model1: ModelT, model2: ModelT) -> bool:
"""Check that the two models are equal."""
@@ -25,3 +29,17 @@ def models_are_equal(model1: ModelT, model2: ModelT) -> bool:
return False # Tensor from graph1 not found in other graph.")
return True
+
+
+class TestTrainingParameters(
+ TrainingParameters
+): # pylint: disable=too-few-public-methods
+ """
+ TrainingParameter class for rewrites with different default values.
+
+ To speed things up for the unit tests.
+ """
+
+ def __init__(self, *args: Any, steps: int = 32, **kwargs: Any) -> None:
+ """Initialize TrainingParameters with different defaults."""
+ super().__init__(*args, steps=steps, **kwargs) # type: ignore