From 3cd84481fa25e64c29e57396d4bf32d7a3ca490a Mon Sep 17 00:00:00 2001 From: Benjamin Klimczak Date: Wed, 19 Jul 2023 16:35:57 +0100 Subject: Bug-fixes and re-factoring for the rewrite module - Fix input shape of rewrite replacement: During and after training of the replacement model for a rewrite the Keras model is converted and saved in TensorFlow Lite format. If the input shape does not match the teacher model exactly, e.g. if the batch size is undefined, the TFLiteConverter adds extra operators during conversion. - Fix rewritten model output - Save the model output with the rewritten operator in the output dir - Log MAE and NRMSE of the rewrite - Remove 'verbose' flag from rewrite module and rely on the logging mechanism to control verbose output. - Re-factor utility classes for rewrites - Merge the two TFLiteModel classes - Move functionality to load/save TensorFlow Lite flatbuffers to nn/tensorflow/tflite_graph - Fix issue with unknown shape in datasets After upgrading to TensorFlow 2.12 the unknown shape of the TFRecordDataset is causing problems when training the replacement models for rewrites. By explicitly setting the right shape of the tensors we can work around the issue. - Adapt default parameters for rewrites. The training steps especially had to be increased significantly to be effective. Resolves: MLIA-895, MLIA-907, MLIA-946, MLIA-979 Signed-off-by: Benjamin Klimczak Change-Id: I887ad165aed0f2c6e5a0041f64cec5e6c5ab5c5c --- tests/conftest.py | 39 +++++++---- tests/test_nn_rewrite_core_graph_edit_join.py | 6 +- tests/test_nn_rewrite_core_graph_edit_record.py | 4 +- tests/test_nn_rewrite_core_rewrite.py | 6 +- tests/test_nn_rewrite_core_train.py | 76 ++++++++++++---------- tests/test_nn_rewrite_core_utils.py | 33 ---------- tests/test_nn_rewrite_core_utils_numpy_tfrecord.py | 25 +++++++ tests/test_nn_tensorflow_config.py | 40 +++++++++++- tests/test_nn_tensorflow_tflite_graph.py | 30 ++++++++- tests/utils/rewrite.py | 18 +++++ 10 files changed, 189 insertions(+), 88 deletions(-) delete mode 100644 tests/test_nn_rewrite_core_utils.py (limited to 'tests') diff --git a/tests/conftest.py b/tests/conftest.py index c42b8cb..bb2423f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ import shutil from pathlib import Path from typing import Callable from typing import Generator +from unittest.mock import MagicMock import numpy as np import pytest @@ -17,6 +18,7 @@ from mlia.nn.tensorflow.utils import convert_to_tflite from mlia.nn.tensorflow.utils import save_keras_model from mlia.nn.tensorflow.utils import save_tflite_model from mlia.target.ethos_u.config import EthosUConfiguration +from tests.utils.rewrite import TestTrainingParameters @pytest.fixture(scope="session", name="test_resources_path") @@ -168,16 +170,12 @@ def _write_tfrecord( writer.write({input_name: data_generator()}) -@pytest.fixture(scope="session", name="test_tfrecord") -def fixture_test_tfrecord( - tmp_path_factory: pytest.TempPathFactory, +def create_tfrecord( + tmp_path_factory: pytest.TempPathFactory, random_data: Callable ) -> Generator[Path, None, None]: """Create a tfrecord with random data matching fixture 'test_tflite_model'.""" tmp_path = tmp_path_factory.mktemp("tfrecords") - tfrecord_file = tmp_path / "test_int8.tfrecord" - - def random_data() -> np.ndarray: - return np.random.randint(low=-127, high=128, size=(1, 28, 28, 1), dtype=np.int8) + tfrecord_file = tmp_path / "test.tfrecord" _write_tfrecord(tfrecord_file, random_data) @@ -186,19 +184,36 @@ def fixture_test_tfrecord( 
shutil.rmtree(tmp_path) +@pytest.fixture(scope="session", name="test_tfrecord") +def fixture_test_tfrecord( + tmp_path_factory: pytest.TempPathFactory, +) -> Generator[Path, None, None]: + """Create a tfrecord with random data matching fixture 'test_tflite_model'.""" + + def random_data() -> np.ndarray: + return np.random.randint(low=-127, high=128, size=(1, 28, 28, 1), dtype=np.int8) + + yield from create_tfrecord(tmp_path_factory, random_data) + + @pytest.fixture(scope="session", name="test_tfrecord_fp32") def fixture_test_tfrecord_fp32( tmp_path_factory: pytest.TempPathFactory, ) -> Generator[Path, None, None]: """Create tfrecord with random data matching fixture 'test_tflite_model_fp32'.""" - tmp_path = tmp_path_factory.mktemp("tfrecords") - tfrecord_file = tmp_path / "test_fp32.tfrecord" def random_data() -> np.ndarray: return np.random.rand(1, 28, 28, 1).astype(np.float32) - _write_tfrecord(tfrecord_file, random_data) + yield from create_tfrecord(tmp_path_factory, random_data) - yield tfrecord_file - shutil.rmtree(tmp_path) +@pytest.fixture(scope="session", autouse=True) +def set_training_steps() -> Generator[None, None, None]: + """Speed up tests by using TestTrainingParameters.""" + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "mlia.nn.select._get_rewrite_train_params", + MagicMock(return_value=TestTrainingParameters()), + ) + yield diff --git a/tests/test_nn_rewrite_core_graph_edit_join.py b/tests/test_nn_rewrite_core_graph_edit_join.py index cb3e4e2..0cb121e 100644 --- a/tests/test_nn_rewrite_core_graph_edit_join.py +++ b/tests/test_nn_rewrite_core_graph_edit_join.py @@ -10,7 +10,7 @@ import pytest from mlia.nn.rewrite.core.graph_edit.cut import cut_model from mlia.nn.rewrite.core.graph_edit.join import append_relabel from mlia.nn.rewrite.core.graph_edit.join import join_models -from mlia.nn.rewrite.core.utils.utils import load +from mlia.nn.tensorflow.tflite_graph import load_fb from tests.utils.rewrite import models_are_equal @@ -49,8 +49,8 @@ def test_join_model(test_tflite_model: Path, tmp_path: Path) -> None: ) assert joined_file.is_file() - orig_model = load(str(test_tflite_model)) - joined_model = load(str(joined_file)) + orig_model = load_fb(str(test_tflite_model)) + joined_model = load_fb(str(joined_file)) assert models_are_equal(orig_model, joined_model) diff --git a/tests/test_nn_rewrite_core_graph_edit_record.py b/tests/test_nn_rewrite_core_graph_edit_record.py index cd728af..41b9c50 100644 --- a/tests/test_nn_rewrite_core_graph_edit_record.py +++ b/tests/test_nn_rewrite_core_graph_edit_record.py @@ -3,15 +3,13 @@ """Tests for module mlia.nn.rewrite.graph_edit.record.""" from pathlib import Path -import pytest import tensorflow as tf from mlia.nn.rewrite.core.graph_edit.record import record_model from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read -@pytest.mark.parametrize("batch_size", (None, 1, 2)) -def test_record_model( +def check_record_model( test_tflite_model: Path, tmp_path: Path, test_tfrecord: Path, diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py index b98971e..2542db2 100644 --- a/tests/test_nn_rewrite_core_rewrite.py +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -12,6 +12,7 @@ import pytest from mlia.nn.rewrite.core.rewrite import RewriteConfiguration from mlia.nn.rewrite.core.rewrite import Rewriter from mlia.nn.tensorflow.config import TFLiteModel +from tests.utils.rewrite import TestTrainingParameters @pytest.mark.parametrize( @@ -32,12 +33,14 @@ def 
test_rewrite_configuration( None, ) + assert config_obj.optimization_target in str(config_obj) + rewriter_obj = Rewriter(test_tflite_model_fp32, config_obj) assert rewriter_obj.optimizer_configuration.optimization_target == rewrite_name assert isinstance(rewriter_obj, Rewriter) -def test_rewriter( +def test_rewriting_optimizer( test_tflite_model_fp32: Path, test_tfrecord_fp32: Path, ) -> None: @@ -46,6 +49,7 @@ def test_rewriter( "fully_connected", ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"], test_tfrecord_fp32, + train_params=TestTrainingParameters(), ) test_obj = Rewriter(test_tflite_model_fp32, config_obj) diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py index 3c2ef3e..4493671 100644 --- a/tests/test_nn_rewrite_core_train.py +++ b/tests/test_nn_rewrite_core_train.py @@ -4,6 +4,7 @@ # pylint: disable=too-many-arguments from __future__ import annotations +from contextlib import ExitStack as does_not_raise from pathlib import Path from tempfile import TemporaryDirectory from typing import Any @@ -12,10 +13,13 @@ import numpy as np import pytest import tensorflow as tf -from mlia.nn.rewrite.core.train import augmentation_presets +from mlia.nn.rewrite.core.train import augment_fn_twins +from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS from mlia.nn.rewrite.core.train import LearningRateSchedule from mlia.nn.rewrite.core.train import mixup from mlia.nn.rewrite.core.train import train +from mlia.nn.rewrite.core.train import TrainingParameters +from tests.utils.rewrite import TestTrainingParameters def replace_fully_connected_with_conv( @@ -41,15 +45,8 @@ def replace_fully_connected_with_conv( def check_train( tflite_model: Path, tfrecord: Path, - batch_size: int = 1, - verbose: bool = False, - show_progress: bool = False, - augmentation_preset: tuple[float | None, float | None] = augmentation_presets[ - "none" - ], - lr_schedule: LearningRateSchedule = "cosine", + train_params: TrainingParameters = TestTrainingParameters(), use_unmodified_model: bool = False, - num_procs: int = 1, ) -> None: """Test the train() function.""" with TemporaryDirectory() as tmp_dir: @@ -62,14 +59,7 @@ def check_train( replace_fn=replace_fully_connected_with_conv, input_tensors=["sequential/flatten/Reshape"], output_tensors=["StatefulPartitionedCall:0"], - augment=augmentation_preset, - steps=32, - learning_rate=1e-3, - batch_size=batch_size, - verbose=verbose, - show_progress=show_progress, - learning_rate_schedule=lr_schedule, - num_procs=num_procs, + train_params=train_params, ) assert len(result) == 2 assert all(res >= 0.0 for res in result), f"Results out of bound: {result}" @@ -79,7 +69,6 @@ def check_train( @pytest.mark.parametrize( ( "batch_size", - "verbose", "show_progress", "augmentation_preset", "lr_schedule", @@ -87,14 +76,13 @@ def check_train( "num_procs", ), ( - (1, False, False, augmentation_presets["none"], "cosine", False, 2), - (32, True, True, augmentation_presets["gaussian"], "late", True, 1), - (2, False, False, augmentation_presets["mixup"], "constant", True, 0), + (1, False, AUGMENTATION_PRESETS["none"], "cosine", False, 2), + (32, True, AUGMENTATION_PRESETS["gaussian"], "late", True, 1), + (2, False, AUGMENTATION_PRESETS["mixup"], "constant", True, 0), ( 1, False, - False, - augmentation_presets["mix_gaussian_large"], + AUGMENTATION_PRESETS["mix_gaussian_large"], "cosine", False, 2, @@ -105,7 +93,6 @@ def test_train( test_tflite_model_fp32: Path, test_tfrecord_fp32: Path, batch_size: int, - verbose: bool, show_progress: bool, 
augmentation_preset: tuple[float | None, float | None], lr_schedule: LearningRateSchedule, @@ -116,13 +103,14 @@ def test_train( check_train( tflite_model=test_tflite_model_fp32, tfrecord=test_tfrecord_fp32, - batch_size=batch_size, - verbose=verbose, - show_progress=show_progress, - augmentation_preset=augmentation_preset, - lr_schedule=lr_schedule, + train_params=TestTrainingParameters( + batch_size=batch_size, + show_progress=show_progress, + augmentations=augmentation_preset, + learning_rate_schedule=lr_schedule, + num_procs=num_procs, + ), use_unmodified_model=use_unmodified_model, - num_procs=num_procs, ) @@ -131,11 +119,13 @@ def test_train_invalid_schedule( test_tfrecord_fp32: Path, ) -> None: """Test the train() function with an invalid schedule.""" - with pytest.raises(ValueError): + with pytest.raises(AssertionError): check_train( tflite_model=test_tflite_model_fp32, tfrecord=test_tfrecord_fp32, - lr_schedule="unknown_schedule", # type: ignore + train_params=TestTrainingParameters( + learning_rate_schedule="unknown_schedule", + ), ) @@ -144,11 +134,13 @@ def test_train_invalid_augmentation( test_tfrecord_fp32: Path, ) -> None: """Test the train() function with an invalid augmentation.""" - with pytest.raises(ValueError): + with pytest.raises(AssertionError): check_train( tflite_model=test_tflite_model_fp32, tfrecord=test_tfrecord_fp32, - augmentation_preset=(1.0, 2.0, 3.0), # type: ignore + train_params=TestTrainingParameters( + augmentations=(1.0, 2.0, 3.0), + ), ) @@ -159,3 +151,19 @@ def test_mixup() -> None: assert src.shape == dst.shape assert np.all(dst >= 0.0) assert np.all(dst <= 3.0) + + +@pytest.mark.parametrize( + "augmentations, expected_error", + [ + (AUGMENTATION_PRESETS["none"], does_not_raise()), + (AUGMENTATION_PRESETS["mix_gaussian_large"], does_not_raise()), + ((None,) * 3, pytest.raises(AssertionError)), + ], +) +def test_augment_fn_twins(augmentations: tuple, expected_error: Any) -> None: + """Test function augment_fn().""" + dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2, 3], "b": [4, 5, 6]}) + with expected_error: + fn_twins = augment_fn_twins(dataset, augmentations) # type: ignore + assert len(fn_twins) == 2 diff --git a/tests/test_nn_rewrite_core_utils.py b/tests/test_nn_rewrite_core_utils.py deleted file mode 100644 index d806a7b..0000000 --- a/tests/test_nn_rewrite_core_utils.py +++ /dev/null @@ -1,33 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Tests for module mlia.nn.rewrite.utils.""" -from pathlib import Path - -import pytest -import tensorflow as tf -from tensorflow.lite.python.schema_py_generated import ModelT - -from mlia.nn.rewrite.core.utils.utils import load -from mlia.nn.rewrite.core.utils.utils import save -from tests.utils.rewrite import models_are_equal - - -def test_load_save(test_tflite_model: Path, tmp_path: Path) -> None: - """Test the load/save functions for TensorFlow Lite models.""" - with pytest.raises(FileNotFoundError): - load("THIS_IS_NOT_A_REAL_FILE") - - model = load(test_tflite_model) - assert isinstance(model, ModelT) - assert model.subgraphs - - output_file = tmp_path / "test.tflite" - assert not output_file.is_file() - save(model, output_file) - assert output_file.is_file() - - model_copy = load(str(output_file)) - assert models_are_equal(model, model_copy) - - # Double check that the TensorFlow Lite Interpreter can still load the file. 
- tf.lite.Interpreter(model_path=str(output_file)) diff --git a/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py b/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py index 7fc8048..d030350 100644 --- a/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py +++ b/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py @@ -5,6 +5,10 @@ from __future__ import annotations from pathlib import Path +import pytest +import tensorflow as tf + +from mlia.nn.rewrite.core.utils.numpy_tfrecord import make_decode_fn from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count from mlia.nn.rewrite.core.utils.numpy_tfrecord import sample_tfrec @@ -16,3 +20,24 @@ def test_sample_tfrec(test_tfrecord: Path, tmp_path: Path) -> None: sample_tfrec(input_file=str(test_tfrecord), k=1, output_file=str(output_file)) assert output_file.is_file() assert numpytf_count(str(output_file)) == 1 + + +def test_make_decode_fn(test_tfrecord: Path) -> None: + """Test function make_decode_fn().""" + decode = make_decode_fn(str(test_tfrecord)) + dataset = tf.data.TFRecordDataset(str(test_tfrecord)) + features = decode(next(iter(dataset))) + assert isinstance(features, dict) + assert len(features) == 1 + key, val = next(iter(features.items())) + assert isinstance(key, str) + assert isinstance(val, tf.Tensor) + assert val.dtype == tf.int8 + + with pytest.raises(FileNotFoundError): + make_decode_fn(str(test_tfrecord) + "_") + + +def test_numpytf_count(test_tfrecord: Path) -> None: + """Test function numpytf_count().""" + assert numpytf_count(test_tfrecord) == 3 diff --git a/tests/test_nn_tensorflow_config.py b/tests/test_nn_tensorflow_config.py index 656619d..48aec0a 100644 --- a/tests/test_nn_tensorflow_config.py +++ b/tests/test_nn_tensorflow_config.py @@ -4,13 +4,28 @@ from contextlib import ExitStack as does_not_raise from pathlib import Path from typing import Any +from typing import Generator +import numpy as np import pytest +from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read from mlia.nn.tensorflow.config import get_model from mlia.nn.tensorflow.config import KerasModel +from mlia.nn.tensorflow.config import ModelConfiguration from mlia.nn.tensorflow.config import TFLiteModel from mlia.nn.tensorflow.config import TfModel +from tests.conftest import create_tfrecord + + +def test_model_configuration(test_keras_model: Path) -> None: + """Test ModelConfiguration class.""" + model = ModelConfiguration(model_path=test_keras_model) + assert test_keras_model.match(model.model_path) + with pytest.raises(NotImplementedError): + model.convert_to_keras("keras_model.h5") + with pytest.raises(NotImplementedError): + model.convert_to_tflite("model.tflite") def test_convert_keras_to_tflite(tmp_path: Path, test_keras_model: Path) -> None: @@ -38,7 +53,7 @@ def test_convert_tf_to_tflite(tmp_path: Path, test_tf_model: Path) -> None: @pytest.mark.parametrize( "model_path, expected_type, expected_error", [ - ("test.tflite", TFLiteModel, does_not_raise()), + ("test.tflite", TFLiteModel, pytest.raises(ValueError)), ("test.h5", KerasModel, does_not_raise()), ("test.hdf5", KerasModel, does_not_raise()), ( @@ -73,3 +88,26 @@ def test_get_model_dir( """Test TensorFlow Lite model type.""" model = get_model(str(test_models_path / model_path)) assert isinstance(model, expected_type) + + +@pytest.fixture(scope="session", name="test_tfrecord_fp32_batch_3") +def fixture_test_tfrecord_fp32_batch_3( + tmp_path_factory: pytest.TempPathFactory, +) -> Generator[Path, None, None]: + """Create tfrecord (same as test_tfrecord_fp32) but with batch 
size 3.""" + + def random_data() -> np.ndarray: + return np.random.rand(3, 28, 28, 1).astype(np.float32) + + yield from create_tfrecord(tmp_path_factory, random_data) + + +def test_tflite_model_call( + test_tflite_model_fp32: Path, test_tfrecord_fp32_batch_3: Path +) -> None: + """Test inference function of class TFLiteModel.""" + model = TFLiteModel(test_tflite_model_fp32, batch_size=2) + data = numpytf_read(test_tfrecord_fp32_batch_3) + for named_input in data.as_numpy_iterator(): + res = model(named_input) + assert res diff --git a/tests/test_nn_tensorflow_tflite_graph.py b/tests/test_nn_tensorflow_tflite_graph.py index cd1fad6..3512cdd 100644 --- a/tests/test_nn_tensorflow_tflite_graph.py +++ b/tests/test_nn_tensorflow_tflite_graph.py @@ -1,15 +1,22 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the tflite_graph module.""" import json from pathlib import Path +import pytest +import tensorflow as tf +from tensorflow.lite.python.schema_py_generated import ModelT + +from mlia.nn.tensorflow.tflite_graph import load_fb from mlia.nn.tensorflow.tflite_graph import Op from mlia.nn.tensorflow.tflite_graph import parse_subgraphs +from mlia.nn.tensorflow.tflite_graph import save_fb from mlia.nn.tensorflow.tflite_graph import TensorInfo from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION from mlia.nn.tensorflow.tflite_graph import TFL_OP from mlia.nn.tensorflow.tflite_graph import TFL_TYPE +from tests.utils.rewrite import models_are_equal def test_tensor_info() -> None: @@ -79,3 +86,24 @@ def test_parse_subgraphs(test_tflite_model: Path) -> None: assert TFL_OP[oper.type] in TFL_OP assert len(oper.inputs) > 0 assert len(oper.outputs) > 0 + + +def test_load_save(test_tflite_model: Path, tmp_path: Path) -> None: + """Test the load/save functions for TensorFlow Lite models.""" + with pytest.raises(FileNotFoundError): + load_fb("THIS_IS_NOT_A_REAL_FILE") + + model = load_fb(test_tflite_model) + assert isinstance(model, ModelT) + assert model.subgraphs + + output_file = tmp_path / "test.tflite" + assert not output_file.is_file() + save_fb(model, output_file) + assert output_file.is_file() + + model_copy = load_fb(str(output_file)) + assert models_are_equal(model, model_copy) + + # Double check that the TensorFlow Lite Interpreter can still load the file. + tf.lite.Interpreter(model_path=str(output_file)) diff --git a/tests/utils/rewrite.py b/tests/utils/rewrite.py index 4264b4b..739bb11 100644 --- a/tests/utils/rewrite.py +++ b/tests/utils/rewrite.py @@ -3,8 +3,12 @@ """Common test utils for the rewrite tests.""" from __future__ import annotations +from typing import Any + from tensorflow.lite.python.schema_py_generated import ModelT +from mlia.nn.rewrite.core.train import TrainingParameters + def models_are_equal(model1: ModelT, model2: ModelT) -> bool: """Check that the two models are equal.""" @@ -25,3 +29,17 @@ def models_are_equal(model1: ModelT, model2: ModelT) -> bool: return False # Tensor from graph1 not found in other graph.") return True + + +class TestTrainingParameters( + TrainingParameters +): # pylint: disable=too-few-public-methods + """ + TrainingParameter class for rewrites with different default values. + + To speed things up for the unit tests. 
+ """ + + def __init__(self, *args: Any, steps: int = 32, **kwargs: Any) -> None: + """Initialize TrainingParameters with different defaults.""" + super().__init__(*args, steps=steps, **kwargs) # type: ignore -- cgit v1.2.1
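The two shape-related fixes described in the commit message come down to a general TensorFlow pattern: give the replacement model a fully defined input shape (batch size included) before converting it to TensorFlow Lite, and explicitly re-assert the known element shape on a dataset whose static shape has been lost after decoding. The sketch below illustrates that pattern in plain TensorFlow; it is not the MLIA implementation, and the layer choice and the (1, 28, 28, 1) shape are assumptions borrowed from the test fixtures in this patch.

# Illustrative sketch only; model architecture and shapes are assumptions,
# not taken from the MLIA rewrite code.
import numpy as np
import tensorflow as tf

# 1) Build the replacement model with a fully defined input shape, batch size
#    included, so the TFLiteConverter does not add extra operators during
#    conversion to match an undefined batch dimension.
inputs = tf.keras.Input(shape=(28, 28, 1), batch_size=1)
outputs = tf.keras.layers.Conv2D(8, 3, padding="same")(inputs)
replacement = tf.keras.Model(inputs, outputs)
tflite_bytes = tf.lite.TFLiteConverter.from_keras_model(replacement).convert()

# 2) Re-assert the known element shape on a dataset whose static shape is
#    unknown (as happens after decoding records from a TFRecordDataset).
def with_known_shape(tensor: tf.Tensor) -> tf.Tensor:
    # tf.ensure_shape checks the runtime shape and sets it statically.
    return tf.ensure_shape(tensor, (1, 28, 28, 1))

samples = np.random.rand(3, 1, 28, 28, 1).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices(samples).map(with_known_shape)
print(dataset.element_spec)  # TensorSpec(shape=(1, 28, 28, 1), dtype=tf.float32)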