diff options
author | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-07-19 16:35:57 +0100 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-10-11 16:06:17 +0100 |
commit | 3cd84481fa25e64c29e57396d4bf32d7a3ca490a (patch) | |
tree | ad81fb520a965bd3a3c7c983833b7cd48f9b8dea /tests/test_nn_tensorflow_config.py | |
parent | f3e6597dd50ec70f043d692b773f2d9fd31519ae (diff) | |
download | mlia-3cd84481fa25e64c29e57396d4bf32d7a3ca490a.tar.gz |
Bug-fixes and re-factoring for the rewrite module
- Fix input shape of rewrite replacement:
During and after training of the replacement model for a rewrite the
Keras model is converted and saved in TensorFlow Lite format. If the
input shape does not match the teacher model exactly, e.g. if the
batch size is undefined, the TFLiteConverter adds extra operators
during conversion.
- Fix rewritten model output
- Save the model output with the rewritten operator in the output dir
- Log MAE and NRMSE of the rewrite
- Remove 'verbose' flag from rewrite module and rely on the logging
mechanism to control verbose output.
- Re-factor utility classes for rewrites
- Merge the two TFLiteModel classes
- Move functionality to load/save TensorFlow Lite flatbuffers to
nn/tensorflow/tflite_graph
- Fix issue with unknown shape in datasets
After upgrading to TensorFlow 2.12 the unknown shape of the
TFRecordDataset is causing problems when training the replacement models
for rewrites. By explicitly setting the right shape of the tensors we
can work around the issue.
- Adapt default parameters for rewrites. The training steps especially
had to be increased significantly to be effective.
Resolves: MLIA-895, MLIA-907, MLIA-946, MLIA-979
Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Change-Id: I887ad165aed0f2c6e5a0041f64cec5e6c5ab5c5c
Diffstat (limited to 'tests/test_nn_tensorflow_config.py')
-rw-r--r-- | tests/test_nn_tensorflow_config.py | 40 |
1 file changed, 39 insertions, 1 deletion
diff --git a/tests/test_nn_tensorflow_config.py b/tests/test_nn_tensorflow_config.py
index 656619d..48aec0a 100644
--- a/tests/test_nn_tensorflow_config.py
+++ b/tests/test_nn_tensorflow_config.py
@@ -4,13 +4,28 @@
 from contextlib import ExitStack as does_not_raise
 from pathlib import Path
 from typing import Any
+from typing import Generator
 
+import numpy as np
 import pytest
 
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
 from mlia.nn.tensorflow.config import get_model
 from mlia.nn.tensorflow.config import KerasModel
+from mlia.nn.tensorflow.config import ModelConfiguration
 from mlia.nn.tensorflow.config import TFLiteModel
 from mlia.nn.tensorflow.config import TfModel
+from tests.conftest import create_tfrecord
+
+
+def test_model_configuration(test_keras_model: Path) -> None:
+    """Test ModelConfiguration class."""
+    model = ModelConfiguration(model_path=test_keras_model)
+    assert test_keras_model.match(model.model_path)
+    with pytest.raises(NotImplementedError):
+        model.convert_to_keras("keras_model.h5")
+    with pytest.raises(NotImplementedError):
+        model.convert_to_tflite("model.tflite")
 
 
 def test_convert_keras_to_tflite(tmp_path: Path, test_keras_model: Path) -> None:
@@ -38,7 +53,7 @@ def test_convert_tf_to_tflite(tmp_path: Path, test_tf_model: Path) -> None:
 @pytest.mark.parametrize(
     "model_path, expected_type, expected_error",
     [
-        ("test.tflite", TFLiteModel, does_not_raise()),
+        ("test.tflite", TFLiteModel, pytest.raises(ValueError)),
         ("test.h5", KerasModel, does_not_raise()),
         ("test.hdf5", KerasModel, does_not_raise()),
         (
@@ -73,3 +88,26 @@ def test_get_model_dir(
     """Test TensorFlow Lite model type."""
     model = get_model(str(test_models_path / model_path))
     assert isinstance(model, expected_type)
+
+
+@pytest.fixture(scope="session", name="test_tfrecord_fp32_batch_3")
+def fixture_test_tfrecord_fp32_batch_3(
+    tmp_path_factory: pytest.TempPathFactory,
+) -> Generator[Path, None, None]:
+    """Create tfrecord (same as test_tfrecord_fp32) but with batch size 3."""
+
+    def random_data() -> np.ndarray:
+        return np.random.rand(3, 28, 28, 1).astype(np.float32)
+
+    yield from create_tfrecord(tmp_path_factory, random_data)
+
+
+def test_tflite_model_call(
+    test_tflite_model_fp32: Path, test_tfrecord_fp32_batch_3: Path
+) -> None:
+    """Test inference function of class TFLiteModel."""
+    model = TFLiteModel(test_tflite_model_fp32, batch_size=2)
+    data = numpytf_read(test_tfrecord_fp32_batch_3)
+    for named_input in data.as_numpy_iterator():
+        res = model(named_input)
+        assert res