aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_tensorflow_optimizations_quantization.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_tensorflow_optimizations_quantization.py')
-rw-r--r--tests/test_nn_tensorflow_optimizations_quantization.py53
1 files changed, 53 insertions, 0 deletions
diff --git a/tests/test_nn_tensorflow_optimizations_quantization.py b/tests/test_nn_tensorflow_optimizations_quantization.py
new file mode 100644
index 0000000..5228cec
--- /dev/null
+++ b/tests/test_nn_tensorflow_optimizations_quantization.py
@@ -0,0 +1,53 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module optimizations/quantization."""
+from __future__ import annotations
+
+from itertools import chain
+from pathlib import Path
+from typing import Generator
+
+import numpy as np
+from numpy.core.numeric import isclose
+
+from mlia.nn.tensorflow.config import TFLiteModel
+from mlia.nn.tensorflow.optimizations.quantization import dequantize
+from mlia.nn.tensorflow.optimizations.quantization import is_quantized
+from mlia.nn.tensorflow.optimizations.quantization import QuantizationParameters
+from mlia.nn.tensorflow.optimizations.quantization import quantize
+
+
+def model_io_quant_params(model_path: Path) -> Generator:
+ """Generate QuantizationParameters for all model inputs and outputs."""
+ model = TFLiteModel(model_path=model_path)
+ for details in chain(model.input_details, model.output_details):
+ yield QuantizationParameters(**details["quantization_parameters"])
+
+
+def test_is_quantized(test_tflite_model: Path) -> None:
+ """Test function is_quantized() with a quantized model."""
+ for quant_params in model_io_quant_params(test_tflite_model):
+ assert is_quantized(quant_params)
+
+
+def test_is_not_quantized(test_tflite_model_fp32: Path) -> None:
+ """Test function is_quantized() with an unquantized model."""
+ for quant_params in model_io_quant_params(test_tflite_model_fp32):
+ assert not is_quantized(quant_params)
+
+
+def test_quantize() -> None:
+ """Test function quantize()."""
+ ref_dequant = np.array((0.0, 0.1, 0.2, 0.3))
+ ref_quant = np.array((0, 10, 20, 30), dtype=np.int8)
+ quant_params = QuantizationParameters(
+ scales=np.array([0.01]), zero_points=np.array([0.0]), quantized_dimension=0
+ )
+
+ quant = quantize(ref_dequant, quant_params)
+ assert quant.dtype == np.int8
+ assert np.all(quant == ref_quant)
+
+ dequant = dequantize(quant, quant_params)
+ assert dequant.dtype == np.float32
+ assert np.all(isclose(dequant, ref_dequant, atol=0.03))