diff options
-rw-r--r-- | src/mlia/nn/rewrite/core/extract.py | 61 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/cut.py | 16 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/record.py | 45 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/train.py | 67 | ||||
-rw-r--r-- | src/mlia/nn/tensorflow/config.py | 101 | ||||
-rw-r--r-- | src/mlia/nn/tensorflow/optimizations/quantization.py | 74 | ||||
-rw-r--r-- | src/mlia/nn/tensorflow/utils.py | 30 | ||||
-rw-r--r-- | tests/test_nn_rewrite_core_extract.py | 38 | ||||
-rw-r--r-- | tests/test_nn_rewrite_core_graph_edit_record.py | 63 | ||||
-rw-r--r-- | tests/test_nn_rewrite_core_train.py | 67 | ||||
-rw-r--r-- | tests/test_nn_tensorflow_config.py | 12 | ||||
-rw-r--r-- | tests/test_nn_tensorflow_optimizations_quantization.py | 53 | ||||
-rw-r--r-- | tests/test_nn_tensorflow_utils.py | 31 |
13 files changed, 598 insertions, 60 deletions
diff --git a/src/mlia/nn/rewrite/core/extract.py b/src/mlia/nn/rewrite/core/extract.py index f609955..4fcf735 100644 --- a/src/mlia/nn/rewrite/core/extract.py +++ b/src/mlia/nn/rewrite/core/extract.py @@ -2,19 +2,62 @@ # SPDX-License-Identifier: Apache-2.0 """Extract module.""" # pylint: disable=too-many-arguments, too-many-locals +from __future__ import annotations + import os +from functools import partial +from pathlib import Path import tensorflow as tf from tensorflow.lite.python.schema_py_generated import SubGraphT from mlia.nn.rewrite.core.graph_edit.cut import cut_model +from mlia.nn.rewrite.core.graph_edit.record import dequantized_path from mlia.nn.rewrite.core.graph_edit.record import record_model - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +def _get_path( + ext: str, name: str, dir_path: str | Path, model_is_quantized: bool = False +) -> Path: + """Create a file path for extracted files.""" + path = Path(dir_path, f"{name}{ext}") + return dequantized_path(path) if model_is_quantized else path + + +class TFLitePaths: # pylint: disable=too-few-public-methods + """Provide safe access to TensorFlow Lite file paths.""" + + _get_path_tflite = partial(_get_path, ".tflite") + + start = partial(_get_path_tflite, "start") + replace = partial(_get_path_tflite, "replace") + end = partial(_get_path_tflite, "end") + + +class TFRecordPaths: # pylint: disable=too-few-public-methods + """Provide safe access to tfrec file paths.""" + + _get_path_tfrec = partial(_get_path, ".tfrec") + + input = partial(_get_path_tfrec, "input") + output = partial(_get_path_tfrec, "output") + end = partial(_get_path_tfrec, "end") + + +class ExtractPaths: # pylint: disable=too-few-public-methods + """Get paths to extract files. + + This is meant to be the single source of truth regarding all file names + created by the extract() function in an output directory. + """ + + tflite = TFLitePaths + tfrec = TFRecordPaths + + def extract( output_path: str, model_file: str, @@ -26,6 +69,7 @@ def extract( show_progress: bool = False, num_procs: int = 1, num_threads: int = 0, + dequantize_output: bool = False, ) -> None: """Extract a model after cut and record.""" try: @@ -33,7 +77,7 @@ def extract( except FileExistsError: pass - start_file = os.path.join(output_path, "start.tflite") + start_file = ExtractPaths.tflite.start(output_path) cut_model( model_file, input_names=None, @@ -42,7 +86,7 @@ def extract( output_file=start_file, ) - input_tfrec = os.path.join(output_path, "input.tfrec") + input_tfrec = ExtractPaths.tfrec.input(output_path) record_model( input_filename, start_file, @@ -50,9 +94,10 @@ def extract( show_progress=show_progress, num_procs=num_procs, num_threads=num_threads, + dequantize_output=dequantize_output, ) - replace_file = os.path.join(output_path, "replace.tflite") + replace_file = ExtractPaths.tflite.replace(output_path) cut_model( model_file, input_names=input_names, @@ -61,7 +106,7 @@ def extract( output_file=replace_file, ) - end_file = os.path.join(output_path, "end.tflite") + end_file = ExtractPaths.tflite.end(output_path) cut_model( model_file, input_names=output_names, @@ -71,7 +116,7 @@ def extract( ) if not skip_outputs: - output_tfrec = os.path.join(output_path, "output.tfrec") + output_tfrec = ExtractPaths.tfrec.output(output_path) record_model( input_tfrec, replace_file, @@ -79,9 +124,10 @@ def extract( show_progress=show_progress, num_procs=num_procs, num_threads=num_threads, + dequantize_output=dequantize_output, ) - end_tfrec = os.path.join(output_path, "end.tfrec") + end_tfrec = ExtractPaths.tfrec.end(output_path) record_model( output_tfrec, end_file, @@ -89,4 +135,5 @@ def extract( show_progress=show_progress, num_procs=num_procs, num_threads=num_threads, + dequantize_output=dequantize_output, ) diff --git a/src/mlia/nn/rewrite/core/graph_edit/cut.py b/src/mlia/nn/rewrite/core/graph_edit/cut.py index 13a5268..53d5389 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/cut.py +++ b/src/mlia/nn/rewrite/core/graph_edit/cut.py @@ -1,9 +1,11 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Cut module.""" +from __future__ import annotations + import os from collections import defaultdict -from typing import Optional +from pathlib import Path import tensorflow as tf from tensorflow.lite.python.schema_py_generated import ModelT @@ -25,8 +27,8 @@ def tensors_by_name(subgraph: SubGraphT, names: list) -> list: def cut_subgraph( subgraph: SubGraphT, - input_tensor_names: Optional[list], - output_tensor_names: Optional[list], + input_tensor_names: list | None, + output_tensor_names: list | None, ) -> None: """Change the global inputs and outputs of a graph to the provided named tensors.""" if input_tensor_names is not None: @@ -131,11 +133,11 @@ def filter_relabel(src_subgraph: SubGraphT, relabel_filter: set) -> tuple: def cut_model( - model_file: str, - input_names: Optional[list], - output_names: Optional[list], + model_file: str | Path, + input_names: list | None, + output_names: list | None, subgraph_index: int, - output_file: str, + output_file: str | Path, ) -> None: """Cut subgraphs and simplify a given model.""" model = load_fb(model_file) diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py index 90f3db8..f85433d 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/record.py +++ b/src/mlia/nn/rewrite/core/graph_edit/record.py @@ -6,6 +6,7 @@ from __future__ import annotations import math import os +from contextlib import ExitStack from pathlib import Path import tensorflow as tf @@ -15,11 +16,22 @@ from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFWriter from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel +from mlia.nn.tensorflow.config import NameToTensorMap os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +DEQUANT_SUFFIX = "_dequant" + + +def dequantized_path(filename: str | Path) -> Path: + """Append the de-quantization suffix to the given filename.""" + path = Path(filename) + path = Path(path.parent, f"{path.stem}{DEQUANT_SUFFIX}{path.suffix}") + return path + + def record_model( input_filename: str | Path, model_filename: str | Path, @@ -28,11 +40,14 @@ def record_model( show_progress: bool = False, num_procs: int = 1, num_threads: int = 0, + dequantize_output: bool = False, ) -> None: """Model recorder. num_procs: 0 => detect real cores on system num_threads: 0 => TFLite impl. specific setting, usually 3 + + dequantize: True => de-quantize the recorded output before saving """ model = ParallelTFLiteModel(model_filename, num_procs, num_threads, batch_size) if not batch_size: @@ -51,22 +66,38 @@ def record_model( dataset = dataset.batch(batch_size, drop_remainder=False) total = int(math.ceil(total / batch_size)) - with NumpyTFWriter(output_filename) as writer: - for _, named_x in enumerate( - track(dataset.as_numpy_iterator(), total=total, disable=not show_progress) - ): - named_y = model(named_x) + with ExitStack() as stack: + writer = stack.enter_context(NumpyTFWriter(output_filename)) + writer_dequant = None + if dequantize_output: + dequant_path = dequantized_path(output_filename) + writer_dequant = stack.enter_context(NumpyTFWriter(dequant_path)) + + def write(writer: NumpyTFWriter, data: NameToTensorMap) -> None: + """Write the data using the given NumpyTFWriter instance.""" if batch_size > 1: for i in range(batch_size): # Expand the batches and recreate each dict as a # batch-size 1 item for the tfrec output recreated_dict = { k: v[i : i + 1] # noqa: E203 - for k, v in named_y.items() + for k, v in data.items() if i < v.shape[0] } if recreated_dict: writer.write(recreated_dict) else: - writer.write(named_y) + writer.write(data) + + for _, named_x in enumerate( + track(dataset.as_numpy_iterator(), total=total, disable=not show_progress) + ): + named_y = model(named_x) + write(writer, named_y) + + if dequantize_output: + assert writer_dequant + named_y_dequant = model.dequantize_outputs(named_y) + write(writer_dequant, named_y_dequant) + model.close() diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index 82af747..6345f07 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -22,9 +22,11 @@ from typing import Literal import numpy as np import tensorflow as tf +import tensorflow_model_optimization as tfmot from numpy.random import Generator from mlia.nn.rewrite.core.extract import extract +from mlia.nn.rewrite.core.extract import ExtractPaths from mlia.nn.rewrite.core.graph_edit.diff import diff_stats from mlia.nn.rewrite.core.graph_edit.join import join_models from mlia.nn.rewrite.core.graph_edit.record import record_model @@ -34,6 +36,7 @@ from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel from mlia.nn.tensorflow.config import TFLiteModel from mlia.nn.tensorflow.tflite_graph import load_fb from mlia.nn.tensorflow.tflite_graph import save_fb +from mlia.nn.tensorflow.utils import get_tflite_converter from mlia.utils.logging import log_action @@ -91,6 +94,7 @@ def train( input_tfrec, input_tensors, output_tensors, + dequantize_output=True, ) else: unmodified_model_dir = None @@ -106,6 +110,7 @@ def train( output_tensors, num_procs=train_params.num_procs, num_threads=train_params.num_threads, + dequantize_output=True, ) tflite_filenames = train_in_dir( @@ -160,7 +165,10 @@ def train( def eval_in_dir( - target_dir: str, new_part: str, num_procs: int = 1, num_threads: int = 0 + target_dir: str, + new_part: str, + num_procs: int = 1, + num_threads: int = 0, ) -> tuple: """Evaluate a model in a given directory.""" model_input_path = Path(target_dir, "input_orig.tfrec") @@ -168,12 +176,12 @@ def eval_in_dir( model_input = ( model_input_path if model_input_path.exists() - else Path(target_dir, "input.tfrec") + else ExtractPaths.tfrec.input(target_dir, False) ) output = ( model_output_path if model_output_path.exists() - else Path(target_dir, "output.tfrec") + else ExtractPaths.tfrec.output(target_dir, False) ) with tempfile.TemporaryDirectory() as tmp_dir: @@ -194,8 +202,8 @@ def join_in_dir(model_dir: str, new_part: str, output_model: str) -> None: """Join two models in a given directory.""" with tempfile.TemporaryDirectory() as tmp_dir: new_end = Path(tmp_dir, "new_end.tflite") - join_models(new_part, Path(model_dir, "end.tflite"), new_end) - join_models(Path(model_dir, "start.tflite"), new_end, output_model) + join_models(new_part, ExtractPaths.tflite.end(model_dir), new_end) + join_models(ExtractPaths.tflite.start(model_dir), new_end, output_model) def _get_io_tensors(model: TFLiteModel) -> tuple[str, str]: @@ -244,7 +252,9 @@ def set_up_data_pipeline( input_name, output_name = _get_io_tensors(teacher) - input_filename = Path(train_dir, "input.tfrec") + model_is_quantized = replace.is_tensor_quantized(name=input_name) + + input_filename = ExtractPaths.tfrec.input(train_dir, model_is_quantized) total = numpytf_count(str(input_filename)) dict_inputs = numpytf_read(str(input_filename)) @@ -264,13 +274,13 @@ def set_up_data_pipeline( if any(augmentations): # Map the teacher inputs here because the augmentation stage passes these # through a TFLite model to get the outputs - teacher_outputs = numpytf_read(str(Path(teacher_dir, "input.tfrec"))).map( - lambda d: tf.squeeze(d[input_name], axis=0) - ) + teacher_outputs = numpytf_read( + str(ExtractPaths.tfrec.input(teacher_dir, model_is_quantized)) + ).map(lambda d: tf.squeeze(d[input_name], axis=0)) else: - teacher_outputs = numpytf_read(str(Path(teacher_dir, "output.tfrec"))).map( - lambda d: tf.squeeze(d[output_name], axis=0) - ) + teacher_outputs = numpytf_read( + str(ExtractPaths.tfrec.output(teacher_dir, model_is_quantized)) + ).map(lambda d: tf.squeeze(d[output_name], axis=0)) dataset = tf.data.Dataset.zip((inputs, teacher_outputs)) if epochs > 1: @@ -285,7 +295,23 @@ def set_up_data_pipeline( ) -> tuple: """Return results of train and teach based on augmentations.""" augmented_train = augment_train({input_name: train})[input_name] - augmented_teach = teacher(augment_teacher({input_name: teach}))[output_name] + # If augmentation of the input is enabled, we need to re-generate + # the reference output by running the augmented data through the + # teacher model. + if model_is_quantized: + # If the input model is quantized we have to additionally + # - quantize the augmented data before running it through the + # (quantized) teacher model + # - de-quantize the output for the training. + augmented_teach = teacher.dequantize_outputs( + teacher( + teacher.quantize_inputs(augment_teacher({input_name: teach})) + ) + )[output_name] + else: + augmented_teach = teacher(augment_teacher({input_name: teach}))[ + output_name + ] return (augmented_train, augmented_teach) dataset = dataset.map( @@ -329,15 +355,20 @@ def train_in_dir( """ teacher_dir = baseline_dir if baseline_dir else train_dir teacher = ParallelTFLiteModel( - f"{teacher_dir}/replace.tflite", + ExtractPaths.tflite.replace(teacher_dir), train_params.num_procs, train_params.num_threads, batch_size=train_params.batch_size, ) - replace = TFLiteModel(f"{train_dir}/replace.tflite") + replace = TFLiteModel(ExtractPaths.tflite.replace(train_dir)) input_name, output_name = _get_io_tensors(teacher) + model_is_quantized = replace.is_tensor_quantized(name=input_name) + + if model_is_quantized: + replace.check_datatypes(np.int8) + dataset = set_up_data_pipeline( teacher, replace, @@ -354,6 +385,8 @@ def train_in_dir( optimizer = tf.keras.optimizers.Nadam(learning_rate=train_params.learning_rate) loss_fn = tf.keras.losses.MeanSquaredError() + if model_is_quantized: + model = tfmot.quantization.keras.quantize_model(model) model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"]) logger.info(model.summary()) @@ -432,6 +465,7 @@ def train_in_dir( replace.shape_from_name[input_name], output_name, replace.shape_from_name[output_name], + model_is_quantized, ) output_filenames.append(checkpoint_filename) @@ -446,6 +480,7 @@ def save_as_tflite( input_shape: list, output_name: str, output_shape: list, + model_is_quantized: bool = False, ) -> None: """Save Keras model as TFLite file.""" @@ -464,7 +499,7 @@ def save_as_tflite( keras_model.input.set_shape(orig_shape) with fixed_input(keras_model, input_shape) as fixed_model: - converter = tf.lite.TFLiteConverter.from_keras_model(fixed_model) + converter = get_tflite_converter(fixed_model, quantized=model_is_quantized) tflite_model = converter.convert() with open(filename, "wb") as file: diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py index c6a7c88..b94350a 100644 --- a/src/mlia/nn/tensorflow/config.py +++ b/src/mlia/nn/tensorflow/config.py @@ -7,13 +7,23 @@ import logging import tempfile from collections import defaultdict from pathlib import Path +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import List import numpy as np import tensorflow as tf from mlia.core.context import Context +from mlia.nn.tensorflow.optimizations.quantization import dequantize +from mlia.nn.tensorflow.optimizations.quantization import is_quantized +from mlia.nn.tensorflow.optimizations.quantization import QuantizationParameters +from mlia.nn.tensorflow.optimizations.quantization import quantize from mlia.nn.tensorflow.tflite_graph import load_fb from mlia.nn.tensorflow.tflite_graph import save_fb +from mlia.nn.tensorflow.utils import check_tflite_datatypes from mlia.nn.tensorflow.utils import convert_to_tflite from mlia.nn.tensorflow.utils import is_keras_model from mlia.nn.tensorflow.utils import is_saved_model @@ -71,6 +81,11 @@ class KerasModel(ModelConfiguration): return self +TFLiteIODetails = Dict[str, Dict[str, Any]] +TFLiteIODetailsList = List[TFLiteIODetails] +NameToTensorMap = Dict[str, np.ndarray] + + class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method """TensorFlow Lite model configuration.""" @@ -119,7 +134,22 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method self.shape_from_name = {d["name"]: d["shape"] for d in details} self.batch_size = next(iter(self.shape_from_name.values()))[0] - def __call__(self, named_input: dict) -> dict: + # Prepare quantization parameters for input and output + def named_quant_params( + details: TFLiteIODetailsList, + ) -> dict[str, QuantizationParameters]: + return { + str(detail["name"]): QuantizationParameters( + **detail["quantization_parameters"] + ) + for detail in details + if TFLiteModel._is_tensor_quantized(detail) + } + + self._quant_params_input = named_quant_params(self.input_details) + self._quant_params_output = named_quant_params(self.output_details) + + def __call__(self, named_input: dict) -> NameToTensorMap: """Execute the model on one or a batch of named inputs \ (a dict of name: numpy array).""" input_len = next(iter(named_input.values())).shape[0] @@ -150,11 +180,11 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method ) return {k: np.concatenate(v) for k, v in named_ys.items()} - def input_tensors(self) -> list: + def input_tensors(self) -> list[str]: """Return name from input details.""" return [d["name"] for d in self.input_details] - def output_tensors(self) -> list: + def output_tensors(self) -> list[str]: """Return name from output details.""" return [d["name"] for d in self.output_details] @@ -164,6 +194,71 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method """Convert model to TensorFlow Lite format.""" return self + def _tensor_details( + self, name: str | None = None, idx: int | None = None + ) -> TFLiteIODetails: + """Get the details of the tensor by name or index.""" + if idx is not None: + details = self.interpreter.get_tensor_details()[idx] + assert details["index"] == idx + elif name is not None: + for details_ in self.interpreter.get_tensor_details(): + if name == details_["name"]: + details = details_ + break + else: + raise NameError( + f"Tensor '{name}' not found in model {self.model_path}." + ) + else: + raise ValueError("Either tensor name or index needs to be passed.") + + assert isinstance(details, dict) + return cast(TFLiteIODetails, details) + + @staticmethod + def _is_tensor_quantized(details: TFLiteIODetails) -> bool: + """Use tensor details to check if the corresponding tensor is quantized.""" + quant_params = QuantizationParameters(**details["quantization_parameters"]) + return is_quantized(quant_params) + + def is_tensor_quantized( + self, + name: str | None = None, + idx: int | None = None, + ) -> bool: + """Check if the given tensor (identified by name or index) is quantized.""" + details = self._tensor_details(name, idx) + return self._is_tensor_quantized(details) + + def check_datatypes(self, *allowed_types: type) -> None: + """Check if the model only has the given allowed datatypes.""" + check_tflite_datatypes(self.model_path, *allowed_types) + + @staticmethod + def _quant_dequant( + tensors: NameToTensorMap, + quant_params: dict[str, QuantizationParameters], + func: Callable, + ) -> NameToTensorMap: + """Quantize/de-quantize tensor using the given parameters and function.""" + return { + name: (func(tensor, quant_params[name]) if name in quant_params else tensor) + for name, tensor in tensors.items() + } + + def dequantize_outputs(self, outputs: NameToTensorMap) -> NameToTensorMap: + """De-quantize the given model outputs.""" + dequant_outputs = self._quant_dequant( + outputs, self._quant_params_output, dequantize + ) + return dequant_outputs + + def quantize_inputs(self, inputs: NameToTensorMap) -> NameToTensorMap: + """Quantize the given model inputs.""" + quant_inputs = self._quant_dequant(inputs, self._quant_params_input, quantize) + return quant_inputs + class TfModel(ModelConfiguration): # pylint: disable=abstract-method """TensorFlow model configuration. diff --git a/src/mlia/nn/tensorflow/optimizations/quantization.py b/src/mlia/nn/tensorflow/optimizations/quantization.py new file mode 100644 index 0000000..4e3c2c2 --- /dev/null +++ b/src/mlia/nn/tensorflow/optimizations/quantization.py @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Contains functionality for quantization and de-quantization.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import cast + +import numpy as np +import tensorflow as tf + + +@dataclass +class QuantizationParameters: + """Collection of TensorFlow Lite quantization parameters. + + Can be directly initialized from TensorFlow Lite tensor details, e.g. + + ``` + QuantizationParameters( + **interpreter.get_input_details()[0]["quantization_parameters"] + ) + ``` + """ + + scales: np.ndarray + zero_points: np.ndarray + quantized_dimension: int + + +def is_quantized(quant_params: QuantizationParameters) -> bool: + """Check if the quantization parameters describe a quantized tensor.""" + return quant_params.scales.size > 0 + + +def dequantize( + quantized_tensor: np.ndarray | tf.Tensor, quant_params: QuantizationParameters +) -> np.ndarray: + """De-quantize the input tensor using the given quantization parameters.""" + assert isinstance(quantized_tensor, (tf.Tensor, np.ndarray)) + assert ( + not isinstance(quantized_tensor, tf.Tensor) + or quantized_tensor.dtype.is_quantized + ) and ( + not isinstance(quantized_tensor, np.ndarray) + or issubclass(quantized_tensor.dtype.type, np.integer) + ), ( + f"Input tensor for de-quantization is of type {quantized_tensor.dtype}, " + "but it must be int." + ) + + dequantized_tensor = np.subtract( + quantized_tensor, quant_params.zero_points, dtype=np.float32 + ) + dequantized_tensor = np.multiply( + dequantized_tensor, quant_params.scales, dtype=np.float32 + ) + return dequantized_tensor + + +def quantize( + tensor: np.ndarray | tf.Tensor, quant_params: QuantizationParameters +) -> np.ndarray: + """Quantize the given float input tensor to int8.""" + assert isinstance(tensor, (tf.Tensor, np.ndarray)) + assert (not isinstance(tensor, tf.Tensor) or tensor.dtype.is_floating) and ( + not isinstance(tensor, np.ndarray) or issubclass(tensor.dtype.type, np.floating) + ), f"Input tensor for quantization is of type {tensor.dtype}, but it must be float." + + quantized_tensor = (tensor / quant_params.scales) + quant_params.zero_points + quantized_tensor = np.clip( # type: ignore + quantized_tensor, -128, 127, dtype=np.int8, casting="unsafe" + ) + return cast(np.ndarray, quantized_tensor) diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py index 77ac529..b8d45c6 100644 --- a/src/mlia/nn/tensorflow/utils.py +++ b/src/mlia/nn/tensorflow/utils.py @@ -123,3 +123,33 @@ def get_tflite_converter( converter.inference_output_type = tf.int8 return converter + + +def get_tflite_model_type_map(model_filename: str | Path) -> dict[str, type]: + """Get type map from tflite model.""" + model_type_map: dict[str, Any] = {} + interpreter = tf.lite.Interpreter(str(model_filename)) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + model_type_map = { + input_detail["name"]: input_detail["dtype"] for input_detail in input_details + } + return model_type_map + + +def check_tflite_datatypes(model_filename: str | Path, *allowed_types: type) -> None: + """Check if the model only has the given allowed datatypes.""" + type_map = get_tflite_model_type_map(model_filename) + types = set(type_map.values()) + allowed = set(allowed_types) + unexpected = types - allowed + + def cls_to_str(types: set[type]) -> list[str]: + return [t.__name__ for t in types] + + if len(unexpected) > 0: + raise TypeError( + f"Model {model_filename} has " + f"unexpected data types: {cls_to_str(unexpected)}. " + f"Only {cls_to_str(allowed)} are allowed." + ) diff --git a/tests/test_nn_rewrite_core_extract.py b/tests/test_nn_rewrite_core_extract.py new file mode 100644 index 0000000..09eca77 --- /dev/null +++ b/tests/test_nn_rewrite_core_extract.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module mlia.nn.rewrite.core.extract.""" +from __future__ import annotations + +from pathlib import Path +from typing import Any +from typing import Iterable + +import pytest + +from mlia.nn.rewrite.core.extract import ExtractPaths +from mlia.nn.rewrite.core.graph_edit.record import DEQUANT_SUFFIX + + +@pytest.mark.parametrize("dir_path", ("/dev/null", Path("/dev/null"))) +@pytest.mark.parametrize("model_is_quantized", (False, True)) +@pytest.mark.parametrize( + ("obj", "func_names", "suffix"), + ( + (ExtractPaths.tflite, ("start", "replace", "end"), ".tflite"), + (ExtractPaths.tfrec, ("input", "output", "end"), ".tfrec"), + ), +) +def test_extract_paths( + dir_path: str | Path, + model_is_quantized: bool, + obj: Any, + func_names: Iterable[str], + suffix: str, +) -> None: + """Test class ExtractPaths.""" + for func_name in func_names: + func = getattr(obj, func_name) + path = func(dir_path, model_is_quantized) + assert path == Path(dir_path, path.relative_to(dir_path)) + assert path.suffix == suffix + assert not model_is_quantized or path.stem.endswith(DEQUANT_SUFFIX) diff --git a/tests/test_nn_rewrite_core_graph_edit_record.py b/tests/test_nn_rewrite_core_graph_edit_record.py index 41b9c50..422b53e 100644 --- a/tests/test_nn_rewrite_core_graph_edit_record.py +++ b/tests/test_nn_rewrite_core_graph_edit_record.py @@ -3,43 +3,57 @@ """Tests for module mlia.nn.rewrite.graph_edit.record.""" from pathlib import Path +import numpy as np +import pytest import tensorflow as tf +from mlia.nn.rewrite.core.extract import ExtractPaths from mlia.nn.rewrite.core.graph_edit.record import record_model from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read +def data_matches_outputs( + name: str, + tensor: tf.Tensor, + model_outputs: list, + dequantized_output: bool, +) -> bool: + """Check that the name and the tensor match any of the model outputs.""" + for model_output in model_outputs: + if model_output["name"] == name: + # If the name is a match, tensor shape and type have to match! + tensor_shape = tensor.shape.as_list() + tensor_type = tensor.dtype.as_numpy_dtype + return all( + ( + tensor_shape == model_output["shape"].tolist(), + tensor_type == np.float32 + if dequantized_output + else model_output["dtype"], + ) + ) + return False + + def check_record_model( test_tflite_model: Path, tmp_path: Path, test_tfrecord: Path, batch_size: int, + dequantize_output: bool, ) -> None: """Test the function record_model().""" - output_file = tmp_path / "out.tfrecord" + output_file = ExtractPaths.tfrec.output(tmp_path) record_model( input_filename=str(test_tfrecord), model_filename=str(test_tflite_model), output_filename=str(output_file), batch_size=batch_size, + dequantize_output=dequantize_output, ) + output_file = ExtractPaths.tfrec.output(tmp_path, dequantize_output) assert output_file.is_file() - def data_matches_outputs(name: str, tensor: tf.Tensor, model_outputs: list) -> bool: - """Check that the name and the tensor match any of the model outputs.""" - for model_output in model_outputs: - if model_output["name"] == name: - # If the name is a match, tensor shape and type have to match! - tensor_shape = tensor.shape.as_list() - tensor_type = tensor.dtype.as_numpy_dtype - return all( - ( - tensor_shape == model_output["shape"].tolist(), - tensor_type == model_output["dtype"], - ) - ) - return False - # Now load model and the data and make sure that the written data matches # any of the model outputs interpreter = tf.lite.Interpreter(str(test_tflite_model)) @@ -47,4 +61,19 @@ def check_record_model( dataset = numpytf_read(str(output_file)) for data in dataset: for name, tensor in data.items(): - assert data_matches_outputs(name, tensor, model_outputs) + assert data_matches_outputs(name, tensor, model_outputs, dequantize_output) + + +@pytest.mark.parametrize("batch_size", (None, 1, 2)) +@pytest.mark.parametrize("dequantize_output", (True, False)) +def test_record_model( + test_tflite_model: Path, + tmp_path: Path, + test_tfrecord: Path, + batch_size: int, + dequantize_output: bool, +) -> None: + """Test the function record_model().""" + check_record_model( + test_tflite_model, tmp_path, test_tfrecord, batch_size, dequantize_output + ) diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py index b001a09..ef52320 100644 --- a/tests/test_nn_rewrite_core_train.py +++ b/tests/test_nn_rewrite_core_train.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 -"""Tests for module mlia.nn.rewrite.train.""" +"""Tests for module mlia.nn.rewrite.core.train.""" # pylint: disable=too-many-arguments from __future__ import annotations @@ -47,10 +47,11 @@ def check_train( tfrecord: Path, train_params: TrainingParameters = TestTrainingParameters(), use_unmodified_model: bool = False, + quantized: bool = False, ) -> None: """Test the train() function.""" with TemporaryDirectory() as tmp_dir: - output_file = Path(tmp_dir, "out.tfrecord") + output_file = Path(tmp_dir, "out.tflite") result = train( source_model=str(tflite_model), unmodified_model=str(tflite_model) if use_unmodified_model else None, @@ -65,6 +66,17 @@ def check_train( assert all(res >= 0.0 for res in result[0]), f"Results out of bound: {result}" assert output_file.is_file() + if quantized: + interpreter = tf.lite.Interpreter(model_path=str(output_file)) + interpreter.allocate_tensors() + # Check that the quantization parameters are non-zero + assert all(interpreter.get_output_details()[0]["quantization"]) + assert all(interpreter.get_input_details()[0]["quantization"]) + dtypes = [] + for tensor_detail in interpreter.get_tensor_details(): + dtypes.append(tensor_detail["dtype"]) + assert all(np.issubdtype(dtype, np.integer) for dtype in dtypes) + @pytest.mark.parametrize( ( @@ -89,7 +101,7 @@ def check_train( ), ), ) -def test_train( +def test_train_fp32( test_tflite_model_fp32: Path, test_tfrecord_fp32: Path, batch_size: int, @@ -114,6 +126,55 @@ def test_train( ) +@pytest.mark.parametrize( + ( + "batch_size", + "show_progress", + "augmentation_preset", + "lr_schedule", + "use_unmodified_model", + "num_procs", + ), + ( + (1, False, AUGMENTATION_PRESETS["none"], "cosine", False, 2), + (32, True, AUGMENTATION_PRESETS["gaussian"], "late", True, 1), + (2, False, AUGMENTATION_PRESETS["mixup"], "constant", True, 0), + ( + 1, + False, + AUGMENTATION_PRESETS["mix_gaussian_large"], + "cosine", + False, + 2, + ), + ), +) +def test_train_int8( + test_tflite_model: Path, + test_tfrecord: Path, + batch_size: int, + show_progress: bool, + augmentation_preset: tuple[float | None, float | None], + lr_schedule: LearningRateSchedule, + use_unmodified_model: bool, + num_procs: int, +) -> None: + """Test the train() function with valid parameters.""" + check_train( + tflite_model=test_tflite_model, + tfrecord=test_tfrecord, + train_params=TestTrainingParameters( + batch_size=batch_size, + show_progress=show_progress, + augmentations=augmentation_preset, + learning_rate_schedule=lr_schedule, + num_procs=num_procs, + ), + use_unmodified_model=use_unmodified_model, + quantized=True, + ) + + def test_train_invalid_schedule( test_tflite_model_fp32: Path, test_tfrecord_fp32: Path, diff --git a/tests/test_nn_tensorflow_config.py b/tests/test_nn_tensorflow_config.py index 48aec0a..fff3857 100644 --- a/tests/test_nn_tensorflow_config.py +++ b/tests/test_nn_tensorflow_config.py @@ -111,3 +111,15 @@ def test_tflite_model_call( for named_input in data.as_numpy_iterator(): res = model(named_input) assert res + + +def test_tflite_model_is_tensor_quantized(test_tflite_model: Path) -> None: + """Test function TFLiteModel.is_tensor_quantized().""" + model = TFLiteModel(test_tflite_model) + input_details = model.input_details[0] + assert model.is_tensor_quantized(name=input_details["name"]) + assert model.is_tensor_quantized(idx=input_details["index"]) + with pytest.raises(ValueError): + assert model.is_tensor_quantized() + with pytest.raises(NameError): + assert model.is_tensor_quantized(name="NAME_DOES_NOT_EXIST") diff --git a/tests/test_nn_tensorflow_optimizations_quantization.py b/tests/test_nn_tensorflow_optimizations_quantization.py new file mode 100644 index 0000000..5228cec --- /dev/null +++ b/tests/test_nn_tensorflow_optimizations_quantization.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module optimizations/quantization.""" +from __future__ import annotations + +from itertools import chain +from pathlib import Path +from typing import Generator + +import numpy as np +from numpy.core.numeric import isclose + +from mlia.nn.tensorflow.config import TFLiteModel +from mlia.nn.tensorflow.optimizations.quantization import dequantize +from mlia.nn.tensorflow.optimizations.quantization import is_quantized +from mlia.nn.tensorflow.optimizations.quantization import QuantizationParameters +from mlia.nn.tensorflow.optimizations.quantization import quantize + + +def model_io_quant_params(model_path: Path) -> Generator: + """Generate QuantizationParameters for all model inputs and outputs.""" + model = TFLiteModel(model_path=model_path) + for details in chain(model.input_details, model.output_details): + yield QuantizationParameters(**details["quantization_parameters"]) + + +def test_is_quantized(test_tflite_model: Path) -> None: + """Test function is_quantized() with a quantized model.""" + for quant_params in model_io_quant_params(test_tflite_model): + assert is_quantized(quant_params) + + +def test_is_not_quantized(test_tflite_model_fp32: Path) -> None: + """Test function is_quantized() with an unquantized model.""" + for quant_params in model_io_quant_params(test_tflite_model_fp32): + assert not is_quantized(quant_params) + + +def test_quantize() -> None: + """Test function quantize().""" + ref_dequant = np.array((0.0, 0.1, 0.2, 0.3)) + ref_quant = np.array((0, 10, 20, 30), dtype=np.int8) + quant_params = QuantizationParameters( + scales=np.array([0.01]), zero_points=np.array([0.0]), quantized_dimension=0 + ) + + quant = quantize(ref_dequant, quant_params) + assert quant.dtype == np.int8 + assert np.all(quant == ref_quant) + + dequant = dequantize(quant, quant_params) + assert dequant.dtype == np.float32 + assert np.all(isclose(dequant, ref_dequant, atol=0.03)) diff --git a/tests/test_nn_tensorflow_utils.py b/tests/test_nn_tensorflow_utils.py index 14b06c4..dab8b4e 100644 --- a/tests/test_nn_tensorflow_utils.py +++ b/tests/test_nn_tensorflow_utils.py @@ -1,14 +1,17 @@ # SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Test for module utils/test_utils.""" +import re from pathlib import Path import numpy as np import pytest import tensorflow as tf +from mlia.nn.tensorflow.utils import check_tflite_datatypes from mlia.nn.tensorflow.utils import convert_to_tflite from mlia.nn.tensorflow.utils import get_tf_tensor_shape +from mlia.nn.tensorflow.utils import get_tflite_model_type_map from mlia.nn.tensorflow.utils import is_keras_model from mlia.nn.tensorflow.utils import is_tflite_model from mlia.nn.tensorflow.utils import representative_dataset @@ -109,3 +112,31 @@ def test_is_keras_model(model_path: Path, expected_result: bool) -> None: def test_get_tf_tensor_shape(test_tf_model: Path) -> None: """Test get_tf_tensor_shape with test model.""" assert get_tf_tensor_shape(str(test_tf_model)) == [1, 28, 28, 1] + + +def test_tflite_model_type_map( + test_tflite_model_fp32: Path, test_tflite_model: Path +) -> None: + """Test the model type map function.""" + assert get_tflite_model_type_map(test_tflite_model_fp32) == { + "serving_default_input:0": np.float32 + } + assert get_tflite_model_type_map(test_tflite_model) == { + "serving_default_input:0": np.int8 + } + + +def test_check_tflite_datatypes( + test_tflite_model_fp32: Path, test_tflite_model: Path +) -> None: + """Test the model type map function.""" + check_tflite_datatypes(test_tflite_model_fp32, np.float32) + check_tflite_datatypes(test_tflite_model, np.int8) + + with pytest.raises( + Exception, + match=re.escape( + "unexpected data types: ['float32']. Only ['int8'] are allowed" + ), + ): + check_tflite_datatypes(test_tflite_model_fp32, np.int8) |