diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_nn_rewrite_core_extract.py | 38 | ||||
-rw-r--r-- | tests/test_nn_rewrite_core_graph_edit_record.py | 63 | ||||
-rw-r--r-- | tests/test_nn_rewrite_core_train.py | 67 | ||||
-rw-r--r-- | tests/test_nn_tensorflow_config.py | 12 | ||||
-rw-r--r-- | tests/test_nn_tensorflow_optimizations_quantization.py | 53 | ||||
-rw-r--r-- | tests/test_nn_tensorflow_utils.py | 31 |
6 files changed, 244 insertions, 20 deletions
diff --git a/tests/test_nn_rewrite_core_extract.py b/tests/test_nn_rewrite_core_extract.py new file mode 100644 index 0000000..09eca77 --- /dev/null +++ b/tests/test_nn_rewrite_core_extract.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module mlia.nn.rewrite.core.extract.""" +from __future__ import annotations + +from pathlib import Path +from typing import Any +from typing import Iterable + +import pytest + +from mlia.nn.rewrite.core.extract import ExtractPaths +from mlia.nn.rewrite.core.graph_edit.record import DEQUANT_SUFFIX + + +@pytest.mark.parametrize("dir_path", ("/dev/null", Path("/dev/null"))) +@pytest.mark.parametrize("model_is_quantized", (False, True)) +@pytest.mark.parametrize( + ("obj", "func_names", "suffix"), + ( + (ExtractPaths.tflite, ("start", "replace", "end"), ".tflite"), + (ExtractPaths.tfrec, ("input", "output", "end"), ".tfrec"), + ), +) +def test_extract_paths( + dir_path: str | Path, + model_is_quantized: bool, + obj: Any, + func_names: Iterable[str], + suffix: str, +) -> None: + """Test class ExtractPaths.""" + for func_name in func_names: + func = getattr(obj, func_name) + path = func(dir_path, model_is_quantized) + assert path == Path(dir_path, path.relative_to(dir_path)) + assert path.suffix == suffix + assert not model_is_quantized or path.stem.endswith(DEQUANT_SUFFIX) diff --git a/tests/test_nn_rewrite_core_graph_edit_record.py b/tests/test_nn_rewrite_core_graph_edit_record.py index 41b9c50..422b53e 100644 --- a/tests/test_nn_rewrite_core_graph_edit_record.py +++ b/tests/test_nn_rewrite_core_graph_edit_record.py @@ -3,43 +3,57 @@ """Tests for module mlia.nn.rewrite.graph_edit.record.""" from pathlib import Path +import numpy as np +import pytest import tensorflow as tf +from mlia.nn.rewrite.core.extract import ExtractPaths from mlia.nn.rewrite.core.graph_edit.record import record_model from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read +def data_matches_outputs( + name: str, + tensor: tf.Tensor, + model_outputs: list, + dequantized_output: bool, +) -> bool: + """Check that the name and the tensor match any of the model outputs.""" + for model_output in model_outputs: + if model_output["name"] == name: + # If the name is a match, tensor shape and type have to match! + tensor_shape = tensor.shape.as_list() + tensor_type = tensor.dtype.as_numpy_dtype + return all( + ( + tensor_shape == model_output["shape"].tolist(), + tensor_type == np.float32 + if dequantized_output + else model_output["dtype"], + ) + ) + return False + + def check_record_model( test_tflite_model: Path, tmp_path: Path, test_tfrecord: Path, batch_size: int, + dequantize_output: bool, ) -> None: """Test the function record_model().""" - output_file = tmp_path / "out.tfrecord" + output_file = ExtractPaths.tfrec.output(tmp_path) record_model( input_filename=str(test_tfrecord), model_filename=str(test_tflite_model), output_filename=str(output_file), batch_size=batch_size, + dequantize_output=dequantize_output, ) + output_file = ExtractPaths.tfrec.output(tmp_path, dequantize_output) assert output_file.is_file() - def data_matches_outputs(name: str, tensor: tf.Tensor, model_outputs: list) -> bool: - """Check that the name and the tensor match any of the model outputs.""" - for model_output in model_outputs: - if model_output["name"] == name: - # If the name is a match, tensor shape and type have to match! - tensor_shape = tensor.shape.as_list() - tensor_type = tensor.dtype.as_numpy_dtype - return all( - ( - tensor_shape == model_output["shape"].tolist(), - tensor_type == model_output["dtype"], - ) - ) - return False - # Now load model and the data and make sure that the written data matches # any of the model outputs interpreter = tf.lite.Interpreter(str(test_tflite_model)) @@ -47,4 +61,19 @@ def check_record_model( dataset = numpytf_read(str(output_file)) for data in dataset: for name, tensor in data.items(): - assert data_matches_outputs(name, tensor, model_outputs) + assert data_matches_outputs(name, tensor, model_outputs, dequantize_output) + + +@pytest.mark.parametrize("batch_size", (None, 1, 2)) +@pytest.mark.parametrize("dequantize_output", (True, False)) +def test_record_model( + test_tflite_model: Path, + tmp_path: Path, + test_tfrecord: Path, + batch_size: int, + dequantize_output: bool, +) -> None: + """Test the function record_model().""" + check_record_model( + test_tflite_model, tmp_path, test_tfrecord, batch_size, dequantize_output + ) diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py index b001a09..ef52320 100644 --- a/tests/test_nn_rewrite_core_train.py +++ b/tests/test_nn_rewrite_core_train.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 -"""Tests for module mlia.nn.rewrite.train.""" +"""Tests for module mlia.nn.rewrite.core.train.""" # pylint: disable=too-many-arguments from __future__ import annotations @@ -47,10 +47,11 @@ def check_train( tfrecord: Path, train_params: TrainingParameters = TestTrainingParameters(), use_unmodified_model: bool = False, + quantized: bool = False, ) -> None: """Test the train() function.""" with TemporaryDirectory() as tmp_dir: - output_file = Path(tmp_dir, "out.tfrecord") + output_file = Path(tmp_dir, "out.tflite") result = train( source_model=str(tflite_model), unmodified_model=str(tflite_model) if use_unmodified_model else None, @@ -65,6 +66,17 @@ def check_train( assert all(res >= 0.0 for res in result[0]), f"Results out of bound: {result}" assert output_file.is_file() + if quantized: + interpreter = tf.lite.Interpreter(model_path=str(output_file)) + interpreter.allocate_tensors() + # Check that the quantization parameters are non-zero + assert all(interpreter.get_output_details()[0]["quantization"]) + assert all(interpreter.get_input_details()[0]["quantization"]) + dtypes = [] + for tensor_detail in interpreter.get_tensor_details(): + dtypes.append(tensor_detail["dtype"]) + assert all(np.issubdtype(dtype, np.integer) for dtype in dtypes) + @pytest.mark.parametrize( ( @@ -89,7 +101,7 @@ def check_train( ), ), ) -def test_train( +def test_train_fp32( test_tflite_model_fp32: Path, test_tfrecord_fp32: Path, batch_size: int, @@ -114,6 +126,55 @@ def test_train( ) +@pytest.mark.parametrize( + ( + "batch_size", + "show_progress", + "augmentation_preset", + "lr_schedule", + "use_unmodified_model", + "num_procs", + ), + ( + (1, False, AUGMENTATION_PRESETS["none"], "cosine", False, 2), + (32, True, AUGMENTATION_PRESETS["gaussian"], "late", True, 1), + (2, False, AUGMENTATION_PRESETS["mixup"], "constant", True, 0), + ( + 1, + False, + AUGMENTATION_PRESETS["mix_gaussian_large"], + "cosine", + False, + 2, + ), + ), +) +def test_train_int8( + test_tflite_model: Path, + test_tfrecord: Path, + batch_size: int, + show_progress: bool, + augmentation_preset: tuple[float | None, float | None], + lr_schedule: LearningRateSchedule, + use_unmodified_model: bool, + num_procs: int, +) -> None: + """Test the train() function with valid parameters.""" + check_train( + tflite_model=test_tflite_model, + tfrecord=test_tfrecord, + train_params=TestTrainingParameters( + batch_size=batch_size, + show_progress=show_progress, + augmentations=augmentation_preset, + learning_rate_schedule=lr_schedule, + num_procs=num_procs, + ), + use_unmodified_model=use_unmodified_model, + quantized=True, + ) + + def test_train_invalid_schedule( test_tflite_model_fp32: Path, test_tfrecord_fp32: Path, diff --git a/tests/test_nn_tensorflow_config.py b/tests/test_nn_tensorflow_config.py index 48aec0a..fff3857 100644 --- a/tests/test_nn_tensorflow_config.py +++ b/tests/test_nn_tensorflow_config.py @@ -111,3 +111,15 @@ def test_tflite_model_call( for named_input in data.as_numpy_iterator(): res = model(named_input) assert res + + +def test_tflite_model_is_tensor_quantized(test_tflite_model: Path) -> None: + """Test function TFLiteModel.is_tensor_quantized().""" + model = TFLiteModel(test_tflite_model) + input_details = model.input_details[0] + assert model.is_tensor_quantized(name=input_details["name"]) + assert model.is_tensor_quantized(idx=input_details["index"]) + with pytest.raises(ValueError): + assert model.is_tensor_quantized() + with pytest.raises(NameError): + assert model.is_tensor_quantized(name="NAME_DOES_NOT_EXIST") diff --git a/tests/test_nn_tensorflow_optimizations_quantization.py b/tests/test_nn_tensorflow_optimizations_quantization.py new file mode 100644 index 0000000..5228cec --- /dev/null +++ b/tests/test_nn_tensorflow_optimizations_quantization.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module optimizations/quantization.""" +from __future__ import annotations + +from itertools import chain +from pathlib import Path +from typing import Generator + +import numpy as np +from numpy.core.numeric import isclose + +from mlia.nn.tensorflow.config import TFLiteModel +from mlia.nn.tensorflow.optimizations.quantization import dequantize +from mlia.nn.tensorflow.optimizations.quantization import is_quantized +from mlia.nn.tensorflow.optimizations.quantization import QuantizationParameters +from mlia.nn.tensorflow.optimizations.quantization import quantize + + +def model_io_quant_params(model_path: Path) -> Generator: + """Generate QuantizationParameters for all model inputs and outputs.""" + model = TFLiteModel(model_path=model_path) + for details in chain(model.input_details, model.output_details): + yield QuantizationParameters(**details["quantization_parameters"]) + + +def test_is_quantized(test_tflite_model: Path) -> None: + """Test function is_quantized() with a quantized model.""" + for quant_params in model_io_quant_params(test_tflite_model): + assert is_quantized(quant_params) + + +def test_is_not_quantized(test_tflite_model_fp32: Path) -> None: + """Test function is_quantized() with an unquantized model.""" + for quant_params in model_io_quant_params(test_tflite_model_fp32): + assert not is_quantized(quant_params) + + +def test_quantize() -> None: + """Test function quantize().""" + ref_dequant = np.array((0.0, 0.1, 0.2, 0.3)) + ref_quant = np.array((0, 10, 20, 30), dtype=np.int8) + quant_params = QuantizationParameters( + scales=np.array([0.01]), zero_points=np.array([0.0]), quantized_dimension=0 + ) + + quant = quantize(ref_dequant, quant_params) + assert quant.dtype == np.int8 + assert np.all(quant == ref_quant) + + dequant = dequantize(quant, quant_params) + assert dequant.dtype == np.float32 + assert np.all(isclose(dequant, ref_dequant, atol=0.03)) diff --git a/tests/test_nn_tensorflow_utils.py b/tests/test_nn_tensorflow_utils.py index 14b06c4..dab8b4e 100644 --- a/tests/test_nn_tensorflow_utils.py +++ b/tests/test_nn_tensorflow_utils.py @@ -1,14 +1,17 @@ # SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Test for module utils/test_utils.""" +import re from pathlib import Path import numpy as np import pytest import tensorflow as tf +from mlia.nn.tensorflow.utils import check_tflite_datatypes from mlia.nn.tensorflow.utils import convert_to_tflite from mlia.nn.tensorflow.utils import get_tf_tensor_shape +from mlia.nn.tensorflow.utils import get_tflite_model_type_map from mlia.nn.tensorflow.utils import is_keras_model from mlia.nn.tensorflow.utils import is_tflite_model from mlia.nn.tensorflow.utils import representative_dataset @@ -109,3 +112,31 @@ def test_is_keras_model(model_path: Path, expected_result: bool) -> None: def test_get_tf_tensor_shape(test_tf_model: Path) -> None: """Test get_tf_tensor_shape with test model.""" assert get_tf_tensor_shape(str(test_tf_model)) == [1, 28, 28, 1] + + +def test_tflite_model_type_map( + test_tflite_model_fp32: Path, test_tflite_model: Path +) -> None: + """Test the model type map function.""" + assert get_tflite_model_type_map(test_tflite_model_fp32) == { + "serving_default_input:0": np.float32 + } + assert get_tflite_model_type_map(test_tflite_model) == { + "serving_default_input:0": np.int8 + } + + +def test_check_tflite_datatypes( + test_tflite_model_fp32: Path, test_tflite_model: Path +) -> None: + """Test the model type map function.""" + check_tflite_datatypes(test_tflite_model_fp32, np.float32) + check_tflite_datatypes(test_tflite_model, np.int8) + + with pytest.raises( + Exception, + match=re.escape( + "unexpected data types: ['float32']. Only ['int8'] are allowed" + ), + ): + check_tflite_datatypes(test_tflite_model_fp32, np.int8) |