From 62768232c5fe4ed6b87136c336b65e13d030e9d4 Mon Sep 17 00:00:00 2001 From: Benjamin Klimczak Date: Mon, 20 Mar 2023 18:07:54 +0000 Subject: MLIA-843 Add unit tests for module mlia.nn.rewrite Note: The unit tests mostly call the main functions from the respective modules only. Change-Id: Ib2ce5c53d0c3eb222b8b8be42fba33ac8e007574 Signed-off-by: Benjamin Klimczak --- src/mlia/nn/rewrite/__init__.py | 2 +- src/mlia/nn/rewrite/core/__init__.py | 2 +- src/mlia/nn/rewrite/core/graph_edit/__init__.py | 2 +- src/mlia/nn/rewrite/core/graph_edit/diff.py | 30 ---- src/mlia/nn/rewrite/core/graph_edit/record.py | 33 +++-- src/mlia/nn/rewrite/core/train.py | 88 +----------- src/mlia/nn/rewrite/core/utils/__init__.py | 2 +- src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py | 3 - src/mlia/nn/rewrite/core/utils/parallel.py | 4 +- src/mlia/nn/rewrite/core/utils/utils.py | 2 +- tests/conftest.py | 95 +++++++++++-- tests/test_backend_vela_compat.py | 3 +- tests/test_nn_rewrite_core_graph_edit_cut.py | 29 ++++ tests/test_nn_rewrite_core_graph_edit_join.py | 50 +++++++ tests/test_nn_rewrite_core_graph_edit_record.py | 52 +++++++ tests/test_nn_rewrite_core_train.py | 157 +++++++++++++++++++++ tests/test_nn_rewrite_core_utils.py | 33 +++++ tests/test_nn_rewrite_core_utils_numpy_tfrecord.py | 18 +++ tests/utils/rewrite.py | 27 ++++ 19 files changed, 484 insertions(+), 148 deletions(-) create mode 100644 tests/test_nn_rewrite_core_graph_edit_cut.py create mode 100644 tests/test_nn_rewrite_core_graph_edit_join.py create mode 100644 tests/test_nn_rewrite_core_graph_edit_record.py create mode 100644 tests/test_nn_rewrite_core_train.py create mode 100644 tests/test_nn_rewrite_core_utils.py create mode 100644 tests/test_nn_rewrite_core_utils_numpy_tfrecord.py create mode 100644 tests/utils/rewrite.py diff --git a/src/mlia/nn/rewrite/__init__.py b/src/mlia/nn/rewrite/__init__.py index 48b1622..8c1f750 100644 --- a/src/mlia/nn/rewrite/__init__.py +++ b/src/mlia/nn/rewrite/__init__.py @@ -1,2 +1,2 
@@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file +# SPDX-License-Identifier: Apache-2.0 diff --git a/src/mlia/nn/rewrite/core/__init__.py b/src/mlia/nn/rewrite/core/__init__.py index 48b1622..8c1f750 100644 --- a/src/mlia/nn/rewrite/core/__init__.py +++ b/src/mlia/nn/rewrite/core/__init__.py @@ -1,2 +1,2 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file +# SPDX-License-Identifier: Apache-2.0 diff --git a/src/mlia/nn/rewrite/core/graph_edit/__init__.py b/src/mlia/nn/rewrite/core/graph_edit/__init__.py index 48b1622..8c1f750 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/__init__.py +++ b/src/mlia/nn/rewrite/core/graph_edit/__init__.py @@ -1,2 +1,2 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file +# SPDX-License-Identifier: Apache-2.0 diff --git a/src/mlia/nn/rewrite/core/graph_edit/diff.py b/src/mlia/nn/rewrite/core/graph_edit/diff.py index b6c9616..0829f0a 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/diff.py +++ b/src/mlia/nn/rewrite/core/graph_edit/diff.py @@ -13,36 +13,6 @@ from tensorflow.lite.python import interpreter as interpreter_wrapper from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFReader, NumpyTFWriter -def diff(file1, file2): - results = [] - - dataset1 = NumpyTFReader(file1) - dataset2 = NumpyTFReader(file2) - - for i, (x1, x2) in enumerate(zip(dataset1, dataset2)): - assert x1.keys() == x2.keys(), ( - "At input %d the files have different sets of tensors.\n%s: %s\n%s: %s\n" - % ( - i, - file1, - ", ".join(x1.keys()), - file2, - ", ".join(x2.keys()), - ) - ) - results.append({}) - for k in x1.keys(): - v1 = x1[k].numpy().astype(np.double) - v2 = x2[k].numpy().astype(np.double) - mae = abs(v1 - v2).mean() - results[-1][k] = mae - - total = 
sum(sum(x.values()) for x in results) - count = sum(len(x.values()) for x in results) - mean = total / count - return results, mean - - def diff_stats(file1, file2, per_tensor_and_channel=False): dataset1 = NumpyTFReader(file1) dataset2 = NumpyTFReader(file2) diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py index 03cd3f9..ae13313 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/record.py +++ b/src/mlia/nn/rewrite/core/graph_edit/record.py @@ -37,7 +37,6 @@ def record_model( total = numpytf_count(input_filename) dataset = NumpyTFReader(input_filename) - writer = NumpyTFWriter(output_filename) if batch_size > 1: # Collapse batch-size 1 items into batch-size n. I regret using batch-size 1 items in tfrecs now. @@ -47,16 +46,22 @@ def record_model( dataset = dataset.batch(batch_size, drop_remainder=False) total = int(math.ceil(total / batch_size)) - for j, named_x in enumerate( - tqdm(dataset.as_numpy_iterator(), total=total, disable=not show_progress) - ): - named_y = model(named_x) - if batch_size > 1: - for i in range(batch_size): - # Expand the batches and recreate each dict as a batch-size 1 item for the tfrec output - d = {k: v[i : i + 1] for k, v in named_y.items() if i < v.shape[0]} - if d: - writer.write(d) - else: - writer.write(named_y) - model.close() + with NumpyTFWriter(output_filename) as writer: + for _, named_x in enumerate( + tqdm(dataset.as_numpy_iterator(), total=total, disable=not show_progress) + ): + named_y = model(named_x) + if batch_size > 1: + for i in range(batch_size): + # Expand the batches and recreate each dict as a + # batch-size 1 item for the tfrec output + recreated_dict = { + k: v[i : i + 1] # noqa: E203 + for k, v in named_y.items() + if i < v.shape[0] + } + if recreated_dict: + writer.write(recreated_dict) + else: + writer.write(named_y) + model.close() diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index a929b14..096daf4 100644 --- 
a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -40,85 +40,7 @@ augmentation_presets = { "mix_gaussian_small": (1.6, 0.3), } - -class SequentialTrainer: - def __init__( - self, - source_model, - output_model, - input_tfrec, - augment="gaussian", - steps=6000, - lr=1e-3, - batch_size=32, - show_progress=True, - eval_fn=None, - num_procs=1, - num_threads=0, - ): - self.source_model = source_model - self.output_model = output_model - self.input_tfrec = input_tfrec - self.default_augment = augment - self.default_steps = steps - self.default_lr = lr - self.default_batch_size = batch_size - self.show_progress = show_progress - self.num_procs = num_procs - self.num_threads = num_threads - self.first_replace = True - self.eval_fn = eval_fn - - def replace( - self, - model_fn, - input_tensors, - output_tensors, - augment=None, - steps=None, - lr=None, - batch_size=None, - ): - augment = self.default_augment if augment is None else augment - steps = self.default_steps if steps is None else steps - lr = self.default_lr if lr is None else lr - batch_size = self.default_batch_size if batch_size is None else batch_size - - if isinstance(augment, str): - augment = augmentation_presets[augment] - - if self.first_replace: - source_model = self.source_model - unmodified_model = None - else: - source_model = self.output_model - unmodified_model = self.source_model - - mae, nrmse = train( - source_model, - unmodified_model, - self.output_model, - self.input_tfrec, - model_fn, - input_tensors, - output_tensors, - augment, - steps, - lr, - batch_size, - False, - self.show_progress, - None, - 0, - self.num_procs, - self.num_threads, - ) - - self.first_replace = False - if self.eval_fn: - return self.eval_fn(mae, nrmse, self.output_model) - else: - return mae, nrmse +learning_rate_schedules = {"cosine", "late", "constant"} def train( @@ -135,6 +57,7 @@ def train( batch_size, verbose, show_progress, + learning_rate_schedule="cosine", checkpoint_at=None, 
checkpoint_decay_steps=0, num_procs=1, @@ -183,6 +106,7 @@ def train( show_progress=show_progress, num_procs=num_procs, num_threads=num_threads, + schedule=learning_rate_schedule, ) for i, filename in enumerate(tflite_filenames): @@ -363,9 +287,9 @@ def train_in_dir( elif schedule == "constant": callbacks = [] else: - assert False, ( - 'LR schedule "%s" not implemented - expected "cosine", "constant" or "late"' - % schedule + assert schedule not in learning_rate_schedules + raise ValueError( + f'LR schedule "{schedule}" not implemented - expected one of {learning_rate_schedules}.' ) output_filenames = [] diff --git a/src/mlia/nn/rewrite/core/utils/__init__.py b/src/mlia/nn/rewrite/core/utils/__init__.py index 48b1622..8c1f750 100644 --- a/src/mlia/nn/rewrite/core/utils/__init__.py +++ b/src/mlia/nn/rewrite/core/utils/__init__.py @@ -1,2 +1,2 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file +# SPDX-License-Identifier: Apache-2.0 diff --git a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py index ac3e875..2141003 100644 --- a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py +++ b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py @@ -63,9 +63,6 @@ class NumpyTFWriter: def __exit__(self, type, value, traceback): self.close() - def __del__(self): - self.close() - def write(self, array_dict): type_map = {n: str(a.dtype.name) for n, a in array_dict.items()} self.type_map.update(type_map) diff --git a/src/mlia/nn/rewrite/core/utils/parallel.py b/src/mlia/nn/rewrite/core/utils/parallel.py index 5affc03..b1a2914 100644 --- a/src/mlia/nn/rewrite/core/utils/parallel.py +++ b/src/mlia/nn/rewrite/core/utils/parallel.py @@ -3,10 +3,10 @@ import math import os from collections import defaultdict +from multiprocessing import cpu_count from multiprocessing import Pool import numpy as np -from psutil import cpu_count 
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" import tensorflow as tf @@ -25,7 +25,7 @@ class ParallelTFLiteModel(TFLiteModel): self.pool = None self.filename = filename if not num_procs: - self.num_procs = cpu_count(logical=False) + self.num_procs = cpu_count() else: self.num_procs = int(num_procs) diff --git a/src/mlia/nn/rewrite/core/utils/utils.py b/src/mlia/nn/rewrite/core/utils/utils.py index ed6c81d..d1ed322 100644 --- a/src/mlia/nn/rewrite/core/utils/utils.py +++ b/src/mlia/nn/rewrite/core/utils/utils.py @@ -8,7 +8,7 @@ from tensorflow.lite.python import schema_py_generated as schema_fb def load(input_tflite_file): if not os.path.exists(input_tflite_file): - raise RuntimeError("TFLite file not found at %r\n" % input_tflite_file) + raise FileNotFoundError("TFLite file not found at %r\n" % input_tflite_file) with open(input_tflite_file, "rb") as file_handle: file_data = bytearray(file_handle.read()) model_obj = schema_fb.Model.GetRootAsModel(file_data, 0) diff --git a/tests/conftest.py b/tests/conftest.py index 30889ca..c42b8cb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,13 +3,16 @@ """Pytest conf module.""" import shutil from pathlib import Path +from typing import Callable from typing import Generator +import numpy as np import pytest import tensorflow as tf from mlia.backend.vela.compiler import optimize_model from mlia.core.context import ExecutionContext +from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFWriter from mlia.nn.tensorflow.utils import convert_to_tflite from mlia.nn.tensorflow.utils import save_keras_model from mlia.nn.tensorflow.utils import save_tflite_model @@ -68,6 +71,14 @@ def get_test_keras_model() -> tf.keras.Model: return model +TEST_MODEL_KERAS_FILE = "test_model.h5" +TEST_MODEL_TFLITE_FP32_FILE = "test_model_fp32.tflite" +TEST_MODEL_TFLITE_INT8_FILE = "test_model_int8.tflite" +TEST_MODEL_TFLITE_VELA_FILE = "test_model_vela.tflite" +TEST_MODEL_TF_SAVED_MODEL_FILE = "tf_model_test_model" 
+TEST_MODEL_INVALID_FILE = "invalid.tflite" + + @pytest.fixture(scope="session", name="test_models_path") def fixture_test_models_path( tmp_path_factory: pytest.TempPathFactory, @@ -75,15 +86,23 @@ def fixture_test_models_path( """Provide path to the test models.""" tmp_path = tmp_path_factory.mktemp("models") + # Keras Model keras_model = get_test_keras_model() - save_keras_model(keras_model, tmp_path / "test_model.h5") + save_keras_model(keras_model, tmp_path / TEST_MODEL_KERAS_FILE) + + # Un-quantized TensorFlow Lite model (fp32) + save_tflite_model( + convert_to_tflite(keras_model, quantized=False), + tmp_path / TEST_MODEL_TFLITE_FP32_FILE, + ) + # Quantized TensorFlow Lite model (int8) tflite_model = convert_to_tflite(keras_model, quantized=True) - tflite_model_path = tmp_path / "test_model.tflite" + tflite_model_path = tmp_path / TEST_MODEL_TFLITE_INT8_FILE save_tflite_model(tflite_model, tflite_model_path) - tflite_vela_model = tmp_path / "test_model_vela.tflite" - + # Vela-optimized TensorFlow Lite model (int8) + tflite_vela_model = tmp_path / TEST_MODEL_TFLITE_VELA_FILE target_config = EthosUConfiguration.load_profile("ethos-u55-256") optimize_model( tflite_model_path, @@ -91,9 +110,9 @@ def fixture_test_models_path( tflite_vela_model, ) - tf.saved_model.save(keras_model, str(tmp_path / "tf_model_test_model")) + tf.saved_model.save(keras_model, str(tmp_path / TEST_MODEL_TF_SAVED_MODEL_FILE)) - invalid_tflite_model = tmp_path / "invalid.tflite" + invalid_tflite_model = tmp_path / TEST_MODEL_INVALID_FILE invalid_tflite_model.touch() yield tmp_path @@ -104,28 +123,82 @@ def fixture_test_models_path( @pytest.fixture(scope="session", name="test_keras_model") def fixture_test_keras_model(test_models_path: Path) -> Path: """Return test Keras model.""" - return test_models_path / "test_model.h5" + return test_models_path / TEST_MODEL_KERAS_FILE @pytest.fixture(scope="session", name="test_tflite_model") def fixture_test_tflite_model(test_models_path: Path) -> Path: 
"""Return test TensorFlow Lite model.""" - return test_models_path / "test_model.tflite" + return test_models_path / TEST_MODEL_TFLITE_INT8_FILE + + +@pytest.fixture(scope="session", name="test_tflite_model_fp32") +def fixture_test_tflite_model_fp32(test_models_path: Path) -> Path: + """Return test TensorFlow Lite model.""" + return test_models_path / TEST_MODEL_TFLITE_FP32_FILE @pytest.fixture(scope="session", name="test_tflite_vela_model") def fixture_test_tflite_vela_model(test_models_path: Path) -> Path: """Return test Vela-optimized TensorFlow Lite model.""" - return test_models_path / "test_model_vela.tflite" + return test_models_path / TEST_MODEL_TFLITE_VELA_FILE @pytest.fixture(scope="session", name="test_tf_model") def fixture_test_tf_model(test_models_path: Path) -> Path: """Return test TensorFlow Lite model.""" - return test_models_path / "tf_model_test_model" + return test_models_path / TEST_MODEL_TF_SAVED_MODEL_FILE @pytest.fixture(scope="session", name="test_tflite_invalid_model") def fixture_test_tflite_invalid_model(test_models_path: Path) -> Path: """Return test invalid TensorFlow Lite model.""" - return test_models_path / "invalid.tflite" + return test_models_path / TEST_MODEL_INVALID_FILE + + +def _write_tfrecord( + tfrecord_file: Path, + data_generator: Callable, + input_name: str = "serving_default_input:0", + num_records: int = 3, +) -> None: + """Write data to a tfrecord.""" + with NumpyTFWriter(str(tfrecord_file)) as writer: + for _ in range(num_records): + writer.write({input_name: data_generator()}) + + +@pytest.fixture(scope="session", name="test_tfrecord") +def fixture_test_tfrecord( + tmp_path_factory: pytest.TempPathFactory, +) -> Generator[Path, None, None]: + """Create a tfrecord with random data matching fixture 'test_tflite_model'.""" + tmp_path = tmp_path_factory.mktemp("tfrecords") + tfrecord_file = tmp_path / "test_int8.tfrecord" + + def random_data() -> np.ndarray: + return np.random.randint(low=-127, high=128, size=(1, 28, 28, 
1), dtype=np.int8) + + _write_tfrecord(tfrecord_file, random_data) + + yield tfrecord_file + + shutil.rmtree(tmp_path) + + +@pytest.fixture(scope="session", name="test_tfrecord_fp32") +def fixture_test_tfrecord_fp32( + tmp_path_factory: pytest.TempPathFactory, +) -> Generator[Path, None, None]: + """Create tfrecord with random data matching fixture 'test_tflite_model_fp32'.""" + tmp_path = tmp_path_factory.mktemp("tfrecords") + tfrecord_file = tmp_path / "test_fp32.tfrecord" + + def random_data() -> np.ndarray: + return np.random.rand(1, 28, 28, 1).astype(np.float32) + + _write_tfrecord(tfrecord_file, random_data) + + yield tfrecord_file + + shutil.rmtree(tmp_path) diff --git a/tests/test_backend_vela_compat.py b/tests/test_backend_vela_compat.py index 0c39dd6..bea6a1b 100644 --- a/tests/test_backend_vela_compat.py +++ b/tests/test_backend_vela_compat.py @@ -12,13 +12,14 @@ from mlia.backend.vela.compat import Operators from mlia.backend.vela.compat import supported_operators from mlia.target.ethos_u.config import EthosUConfiguration from mlia.utils.filesystem import working_directory +from tests.conftest import TEST_MODEL_TFLITE_INT8_FILE @pytest.mark.parametrize( "model, expected_ops", [ ( - "test_model.tflite", + TEST_MODEL_TFLITE_INT8_FILE, Operators( ops=[ Operator( diff --git a/tests/test_nn_rewrite_core_graph_edit_cut.py b/tests/test_nn_rewrite_core_graph_edit_cut.py new file mode 100644 index 0000000..914fdfd --- /dev/null +++ b/tests/test_nn_rewrite_core_graph_edit_cut.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. 
+# SPDX-License-Identifier: Apache-2.0 +"""Tests for module mlia.nn.rewrite.graph_edit.cut.""" +from pathlib import Path + +import numpy as np +import tensorflow as tf + +from mlia.nn.rewrite.core.graph_edit.cut import cut_model + + +def test_cut_model(test_tflite_model: Path, tmp_path: Path) -> None: + """Test the function cut_model().""" + output_file = tmp_path / "out.tflite" + cut_model( + model_file=test_tflite_model, + input_names=["serving_default_input:0"], + output_names=["sequential/flatten/Reshape"], + subgraph_index=0, + output_file=output_file, + ) + assert output_file.is_file() + + interpreter = tf.lite.Interpreter(model_path=str(output_file)) + output_details = interpreter.get_output_details() + assert len(output_details) == 1 + out = output_details[0] + assert "Reshape" in out["name"] + assert np.prod(out["shape"]) == 1728 diff --git a/tests/test_nn_rewrite_core_graph_edit_join.py b/tests/test_nn_rewrite_core_graph_edit_join.py new file mode 100644 index 0000000..cbbbeba --- /dev/null +++ b/tests/test_nn_rewrite_core_graph_edit_join.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. 
+# SPDX-License-Identifier: Apache-2.0 +"""Tests for module mlia.nn.rewrite.graph_edit.join.""" +from pathlib import Path + +from mlia.nn.rewrite.core.graph_edit.cut import cut_model +from mlia.nn.rewrite.core.graph_edit.join import join_models +from mlia.nn.rewrite.core.utils.utils import load +from tests.utils.rewrite import models_are_equal + + +def test_join_model(test_tflite_model: Path, tmp_path: Path) -> None: + """Test the function join_models().""" + # Cut model into two parts first + first_file = tmp_path / "first_part.tflite" + second_file = tmp_path / "second_part.tflite" + cut_model( + model_file=str(test_tflite_model), + input_names=["serving_default_input:0"], + output_names=["sequential/flatten/Reshape"], + subgraph_index=0, + output_file=str(first_file), + ) + cut_model( + model_file=str(test_tflite_model), + input_names=["sequential/flatten/Reshape"], + output_names=["StatefulPartitionedCall:0"], + subgraph_index=0, + output_file=str(second_file), + ) + assert first_file.is_file() + assert second_file.is_file() + + joined_file = tmp_path / "joined.tflite" + + # Now re-join the cut model and check the result is the same as the original + for in_src, in_dst in ((first_file, second_file), (second_file, first_file)): + join_models( + input_src=str(in_src), + input_dst=str(in_dst), + output_file=str(joined_file), + subgraph_src=0, + subgraph_dst=0, + ) + assert joined_file.is_file() + + orig_model = load(str(test_tflite_model)) + joined_model = load(str(joined_file)) + + assert models_are_equal(orig_model, joined_model) diff --git a/tests/test_nn_rewrite_core_graph_edit_record.py b/tests/test_nn_rewrite_core_graph_edit_record.py new file mode 100644 index 0000000..39aeef5 --- /dev/null +++ b/tests/test_nn_rewrite_core_graph_edit_record.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. 
+# SPDX-License-Identifier: Apache-2.0 +"""Tests for module mlia.nn.rewrite.graph_edit.record.""" +from pathlib import Path + +import pytest +import tensorflow as tf + +from mlia.nn.rewrite.core.graph_edit.record import record_model +from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFReader + + +@pytest.mark.parametrize("batch_size", (None, 1, 2)) +def test_record_model( + test_tflite_model: Path, + tmp_path: Path, + test_tfrecord: Path, + batch_size: int, +) -> None: + """Test the function record_model().""" + output_file = tmp_path / "out.tfrecord" + record_model( + input_filename=str(test_tfrecord), + model_filename=str(test_tflite_model), + output_filename=str(output_file), + batch_size=batch_size, + ) + assert output_file.is_file() + + def data_matches_outputs(name: str, tensor: tf.Tensor, model_outputs: list) -> bool: + """Check that the name and the tensor match any of the model outputs.""" + for model_output in model_outputs: + if model_output["name"] == name: + # If the name is a match, tensor shape and type have to match! + tensor_shape = tensor.shape.as_list() + tensor_type = tensor.dtype.as_numpy_dtype + return all( + ( + tensor_shape == model_output["shape"].tolist(), + tensor_type == model_output["dtype"], + ) + ) + return False + + # Now load model and the data and make sure that the written data matches + # any of the model outputs + interpreter = tf.lite.Interpreter(str(test_tflite_model)) + model_outputs = interpreter.get_output_details() + dataset = NumpyTFReader(str(output_file)) + for data in dataset: + for name, tensor in data.items(): + assert data_matches_outputs(name, tensor, model_outputs) diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py new file mode 100644 index 0000000..d2bc1e0 --- /dev/null +++ b/tests/test_nn_rewrite_core_train.py @@ -0,0 +1,157 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. 
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module mlia.nn.rewrite.core.train."""
+# pylint: disable=too-many-arguments
+from __future__ import annotations
+
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import numpy as np
+import pytest
+import tensorflow as tf
+
+from mlia.nn.rewrite.core.train import augmentation_presets
+from mlia.nn.rewrite.core.train import mixup
+from mlia.nn.rewrite.core.train import train
+
+
+def replace_fully_connected_with_conv(input_shape, output_shape) -> tf.keras.Model:
+    """Get a 1x1-convolution replacement model for the fully connected layer."""
+    for name, shape in {
+        "Input": input_shape,
+        "Output": output_shape,
+    }.items():
+        if len(shape) != 1:
+            raise RuntimeError(f"{name}: shape (N,) expected, but it is {shape}.")
+
+    model = tf.keras.Sequential(name="RewriteModel")
+    model.add(tf.keras.Input(input_shape))
+    model.add(tf.keras.layers.Reshape((1, 1, input_shape[0])))
+    model.add(tf.keras.layers.Conv2D(filters=output_shape[0], kernel_size=(1, 1)))
+    model.add(tf.keras.layers.Reshape(output_shape))
+
+    return model
+
+
+def check_train(
+    tflite_model: Path,
+    tfrecord: Path,
+    batch_size: int = 1,
+    verbose: bool = False,
+    show_progress: bool = False,
+    augmentation_preset: tuple[float | None, float | None] = augmentation_presets[
+        "none"
+    ],
+    lr_schedule: str = "cosine",
+    use_unmodified_model: bool = False,
+    num_procs: int = 1,
+) -> None:
+    """Test the train() function."""
+    with TemporaryDirectory() as tmp_dir:
+        output_file = Path(tmp_dir, "out.tfrecord")
+        result = train(
+            source_model=str(tflite_model),
+            unmodified_model=str(tflite_model) if use_unmodified_model else None,
+            output_model=str(output_file),
+            input_tfrec=str(tfrecord),
+            replace_fn=replace_fully_connected_with_conv,
+            input_tensors=["sequential/flatten/Reshape"],
+            output_tensors=["StatefulPartitionedCall:0"],
+            augment=augmentation_preset,
+            steps=32,
+            lr=1e-3,
+            batch_size=batch_size,
+            verbose=verbose,
show_progress=show_progress, + learning_rate_schedule=lr_schedule, + num_procs=num_procs, + ) + assert len(result) == 2 + assert all(res >= 0.0 for res in result), f"Results out of bound: {result}" + assert output_file.is_file() + + +@pytest.mark.parametrize( + ( + "batch_size", + "verbose", + "show_progress", + "augmentation_preset", + "lr_schedule", + "use_unmodified_model", + "num_procs", + ), + ( + (1, False, False, augmentation_presets["none"], "cosine", False, 2), + (32, True, True, augmentation_presets["gaussian"], "late", True, 1), + (2, False, False, augmentation_presets["mixup"], "constant", True, 0), + ( + 1, + False, + False, + augmentation_presets["mix_gaussian_large"], + "cosine", + False, + 2, + ), + ), +) +def test_train( + test_tflite_model_fp32: Path, + test_tfrecord_fp32: Path, + batch_size: int, + verbose: bool, + show_progress: bool, + augmentation_preset: tuple[float | None, float | None], + lr_schedule: str, + use_unmodified_model: bool, + num_procs: int, +) -> None: + """Test the train() function with valid parameters.""" + check_train( + tflite_model=test_tflite_model_fp32, + tfrecord=test_tfrecord_fp32, + batch_size=batch_size, + verbose=verbose, + show_progress=show_progress, + augmentation_preset=augmentation_preset, + lr_schedule=lr_schedule, + use_unmodified_model=use_unmodified_model, + num_procs=num_procs, + ) + + +def test_train_invalid_schedule( + test_tflite_model_fp32: Path, + test_tfrecord_fp32: Path, +) -> None: + """Test the train() function with an invalid schedule.""" + with pytest.raises(ValueError): + check_train( + tflite_model=test_tflite_model_fp32, + tfrecord=test_tfrecord_fp32, + lr_schedule="unknown_schedule", + ) + + +def test_train_invalid_augmentation( + test_tflite_model_fp32: Path, + test_tfrecord_fp32: Path, +) -> None: + """Test the train() function with an invalid augmentation.""" + with pytest.raises(ValueError): + check_train( + tflite_model=test_tflite_model_fp32, + tfrecord=test_tfrecord_fp32, + 
augmentation_preset=(1.0, 2.0, 3.0), # type: ignore + ) + + +def test_mixup() -> None: + """Test the mixup() function.""" + src = np.array((1, 2, 3)) + dst = mixup(rng=np.random.default_rng(123), batch=src) + assert src.shape == dst.shape + assert np.all(dst >= 0.0) + assert np.all(dst <= 3.0) diff --git a/tests/test_nn_rewrite_core_utils.py b/tests/test_nn_rewrite_core_utils.py new file mode 100644 index 0000000..d806a7b --- /dev/null +++ b/tests/test_nn_rewrite_core_utils.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module mlia.nn.rewrite.utils.""" +from pathlib import Path + +import pytest +import tensorflow as tf +from tensorflow.lite.python.schema_py_generated import ModelT + +from mlia.nn.rewrite.core.utils.utils import load +from mlia.nn.rewrite.core.utils.utils import save +from tests.utils.rewrite import models_are_equal + + +def test_load_save(test_tflite_model: Path, tmp_path: Path) -> None: + """Test the load/save functions for TensorFlow Lite models.""" + with pytest.raises(FileNotFoundError): + load("THIS_IS_NOT_A_REAL_FILE") + + model = load(test_tflite_model) + assert isinstance(model, ModelT) + assert model.subgraphs + + output_file = tmp_path / "test.tflite" + assert not output_file.is_file() + save(model, output_file) + assert output_file.is_file() + + model_copy = load(str(output_file)) + assert models_are_equal(model, model_copy) + + # Double check that the TensorFlow Lite Interpreter can still load the file. + tf.lite.Interpreter(model_path=str(output_file)) diff --git a/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py b/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py new file mode 100644 index 0000000..7fc8048 --- /dev/null +++ b/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. 
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module mlia.nn.rewrite.core.utils.numpy_tfrecord."""
+from __future__ import annotations
+
+from pathlib import Path
+
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import sample_tfrec
+
+
+def test_sample_tfrec(test_tfrecord: Path, tmp_path: Path) -> None:
+    """Test function sample_tfrec()."""
+    output_file = tmp_path / "output.tfrecord"
+    # Sample 1 sample from test_tfrecord
+    sample_tfrec(input_file=str(test_tfrecord), k=1, output_file=str(output_file))
+    assert output_file.is_file()
+    assert numpytf_count(str(output_file)) == 1
diff --git a/tests/utils/rewrite.py b/tests/utils/rewrite.py
new file mode 100644
index 0000000..4264b4b
--- /dev/null
+++ b/tests/utils/rewrite.py
@@ -0,0 +1,27 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Common test utils for the rewrite tests."""
+from __future__ import annotations
+
+from tensorflow.lite.python.schema_py_generated import ModelT
+
+
+def models_are_equal(model1: ModelT, model2: ModelT) -> bool:
+    """Check subgraph names and per-tensor name, shape and type of two models."""
+    if len(model1.subgraphs) != len(model2.subgraphs):
+        return False
+
+    for graph1, graph2 in zip(model1.subgraphs, model2.subgraphs):
+        if graph1.name != graph2.name or len(graph1.tensors) != len(graph2.tensors):
+            return False
+        for tensor1 in graph1.tensors:
+            for tensor2 in graph2.tensors:
+                if tensor1.name == tensor2.name:
+                    if (
+                        tensor1.shape == tensor2.shape
+                    ).all() and tensor1.type == tensor2.type:
+                        break
+            else:
+                return False  # No tensor in graph2 matches tensor1 (for-else).
+
+    return True
-- 
cgit v1.2.1