diff options
author | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-07-19 16:35:57 +0100 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-10-11 16:06:17 +0100 |
commit | 3cd84481fa25e64c29e57396d4bf32d7a3ca490a (patch) | |
tree | ad81fb520a965bd3a3c7c983833b7cd48f9b8dea /tests/test_nn_tensorflow_tflite_graph.py | |
parent | f3e6597dd50ec70f043d692b773f2d9fd31519ae (diff) | |
download | mlia-3cd84481fa25e64c29e57396d4bf32d7a3ca490a.tar.gz |
Bug-fixes and re-factoring for the rewrite module
- Fix input shape of rewrite replacement:
During and after training of the replacement model for a rewrite the
Keras model is converted and saved in TensorFlow Lite format. If the
input shape does not match the teacher model exactly, e.g. if the
batch size is undefined, the TFLiteConverter adds extra operators
during conversion.
- Fix rewritten model output
- Save the model output with the rewritten operator in the output dir
- Log MAE and NRMSE of the rewrite
- Remove 'verbose' flag from rewrite module and rely on the logging
mechanism to control verbose output.
- Re-factor utility classes for rewrites
- Merge the two TFLiteModel classes
- Move functionality to load/save TensorFlow Lite flatbuffers to
nn/tensorflow/tflite_graph
- Fix issue with unknown shape in datasets
After upgrading to TensorFlow 2.12 the unknown shape of the
TFRecordDataset is causing problems when training the replacement models
for rewrites. By explicitly setting the right shape of the tensors we
can work around the issue.
- Adapt default parameters for rewrites. The training steps especially
had to be increased significantly to be effective.
Resolves: MLIA-895, MLIA-907, MLIA-946, MLIA-979
Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Change-Id: I887ad165aed0f2c6e5a0041f64cec5e6c5ab5c5c
Diffstat (limited to 'tests/test_nn_tensorflow_tflite_graph.py')
-rw-r--r-- | tests/test_nn_tensorflow_tflite_graph.py | 30 |
1 files changed, 29 insertions, 1 deletions
diff --git a/tests/test_nn_tensorflow_tflite_graph.py b/tests/test_nn_tensorflow_tflite_graph.py index cd1fad6..3512cdd 100644 --- a/tests/test_nn_tensorflow_tflite_graph.py +++ b/tests/test_nn_tensorflow_tflite_graph.py @@ -1,15 +1,22 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the tflite_graph module.""" import json from pathlib import Path +import pytest +import tensorflow as tf +from tensorflow.lite.python.schema_py_generated import ModelT + +from mlia.nn.tensorflow.tflite_graph import load_fb from mlia.nn.tensorflow.tflite_graph import Op from mlia.nn.tensorflow.tflite_graph import parse_subgraphs +from mlia.nn.tensorflow.tflite_graph import save_fb from mlia.nn.tensorflow.tflite_graph import TensorInfo from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION from mlia.nn.tensorflow.tflite_graph import TFL_OP from mlia.nn.tensorflow.tflite_graph import TFL_TYPE +from tests.utils.rewrite import models_are_equal def test_tensor_info() -> None: @@ -79,3 +86,24 @@ def test_parse_subgraphs(test_tflite_model: Path) -> None: assert TFL_OP[oper.type] in TFL_OP assert len(oper.inputs) > 0 assert len(oper.outputs) > 0 + + +def test_load_save(test_tflite_model: Path, tmp_path: Path) -> None: + """Test the load/save functions for TensorFlow Lite models.""" + with pytest.raises(FileNotFoundError): + load_fb("THIS_IS_NOT_A_REAL_FILE") + + model = load_fb(test_tflite_model) + assert isinstance(model, ModelT) + assert model.subgraphs + + output_file = tmp_path / "test.tflite" + assert not output_file.is_file() + save_fb(model, output_file) + assert output_file.is_file() + + model_copy = load_fb(str(output_file)) + assert models_are_equal(model, model_copy) + + # Double check that the TensorFlow Lite Interpreter can still load the file. + tf.lite.Interpreter(model_path=str(output_file)) |