From ecc4264b93d4a89fa2cb40518b225d8371b7ffad Mon Sep 17 00:00:00 2001 From: Benjamin Klimczak Date: Wed, 12 Jul 2023 15:18:26 +0100 Subject: Enable rewrites for quantized input models If the input model for rewriting is quantized: - Record de-quantized TFRecords - enable writing de-quantized calibration data for the training - re-generate augmented training data, if needed - Use quantization-aware training (QAT) to train the replacement models - Check if replacement model is quantized: If source model is quantized, we make sure rewrite's output model is quantized too. Right now, only int8 is supported so raising an error if any other datatype is present in the output. Resolves: MLIA-907, MLIA-908, MLIA-927 Signed-off-by: Benjamin Klimczak Change-Id: Icb4070a9e6f1fdb5ce36120d73823986e89ac955 --- src/mlia/nn/rewrite/core/extract.py | 61 +++++++++++++++++++++--- src/mlia/nn/rewrite/core/graph_edit/cut.py | 16 ++++--- src/mlia/nn/rewrite/core/graph_edit/record.py | 45 +++++++++++++++--- src/mlia/nn/rewrite/core/train.py | 67 ++++++++++++++++++++------- 4 files changed, 152 insertions(+), 37 deletions(-) (limited to 'src/mlia/nn/rewrite') diff --git a/src/mlia/nn/rewrite/core/extract.py b/src/mlia/nn/rewrite/core/extract.py index f609955..4fcf735 100644 --- a/src/mlia/nn/rewrite/core/extract.py +++ b/src/mlia/nn/rewrite/core/extract.py @@ -2,19 +2,62 @@ # SPDX-License-Identifier: Apache-2.0 """Extract module.""" # pylint: disable=too-many-arguments, too-many-locals +from __future__ import annotations + import os +from functools import partial +from pathlib import Path import tensorflow as tf from tensorflow.lite.python.schema_py_generated import SubGraphT from mlia.nn.rewrite.core.graph_edit.cut import cut_model +from mlia.nn.rewrite.core.graph_edit.record import dequantized_path from mlia.nn.rewrite.core.graph_edit.record import record_model - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +def _get_path( + ext: str, name: str, dir_path: str | Path, model_is_quantized: bool = False +) -> Path: + """Create a file path for extracted files.""" + path = Path(dir_path, f"{name}{ext}") + return dequantized_path(path) if model_is_quantized else path + + +class TFLitePaths: # pylint: disable=too-few-public-methods + """Provide safe access to TensorFlow Lite file paths.""" + + _get_path_tflite = partial(_get_path, ".tflite") + + start = partial(_get_path_tflite, "start") + replace = partial(_get_path_tflite, "replace") + end = partial(_get_path_tflite, "end") + + +class TFRecordPaths: # pylint: disable=too-few-public-methods + """Provide safe access to tfrec file paths.""" + + _get_path_tfrec = partial(_get_path, ".tfrec") + + input = partial(_get_path_tfrec, "input") + output = partial(_get_path_tfrec, "output") + end = partial(_get_path_tfrec, "end") + + +class ExtractPaths: # pylint: disable=too-few-public-methods + """Get paths to extract files. + + This is meant to be the single source of truth regarding all file names + created by the extract() function in an output directory. + """ + + tflite = TFLitePaths + tfrec = TFRecordPaths + + def extract( output_path: str, model_file: str, @@ -26,6 +69,7 @@ def extract( show_progress: bool = False, num_procs: int = 1, num_threads: int = 0, + dequantize_output: bool = False, ) -> None: """Extract a model after cut and record.""" try: @@ -33,7 +77,7 @@ def extract( except FileExistsError: pass - start_file = os.path.join(output_path, "start.tflite") + start_file = ExtractPaths.tflite.start(output_path) cut_model( model_file, input_names=None, @@ -42,7 +86,7 @@ def extract( output_file=start_file, ) - input_tfrec = os.path.join(output_path, "input.tfrec") + input_tfrec = ExtractPaths.tfrec.input(output_path) record_model( input_filename, start_file, @@ -50,9 +94,10 @@ def extract( show_progress=show_progress, num_procs=num_procs, num_threads=num_threads, + dequantize_output=dequantize_output, ) - replace_file = os.path.join(output_path, "replace.tflite") + replace_file = ExtractPaths.tflite.replace(output_path) cut_model( model_file, input_names=input_names, @@ -61,7 +106,7 @@ def extract( output_file=replace_file, ) - end_file = os.path.join(output_path, "end.tflite") + end_file = ExtractPaths.tflite.end(output_path) cut_model( model_file, input_names=output_names, @@ -71,7 +116,7 @@ def extract( ) if not skip_outputs: - output_tfrec = os.path.join(output_path, "output.tfrec") + output_tfrec = ExtractPaths.tfrec.output(output_path) record_model( input_tfrec, replace_file, @@ -79,9 +124,10 @@ def extract( show_progress=show_progress, num_procs=num_procs, num_threads=num_threads, + dequantize_output=dequantize_output, ) - end_tfrec = os.path.join(output_path, "end.tfrec") + end_tfrec = ExtractPaths.tfrec.end(output_path) record_model( output_tfrec, end_file, @@ -89,4 +135,5 @@ def extract( show_progress=show_progress, num_procs=num_procs, num_threads=num_threads, + dequantize_output=dequantize_output, ) diff --git a/src/mlia/nn/rewrite/core/graph_edit/cut.py b/src/mlia/nn/rewrite/core/graph_edit/cut.py index 13a5268..53d5389 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/cut.py +++ b/src/mlia/nn/rewrite/core/graph_edit/cut.py @@ -1,9 +1,11 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Cut module.""" +from __future__ import annotations + import os from collections import defaultdict -from typing import Optional +from pathlib import Path import tensorflow as tf from tensorflow.lite.python.schema_py_generated import ModelT @@ -25,8 +27,8 @@ def tensors_by_name(subgraph: SubGraphT, names: list) -> list: def cut_subgraph( subgraph: SubGraphT, - input_tensor_names: Optional[list], - output_tensor_names: Optional[list], + input_tensor_names: list | None, + output_tensor_names: list | None, ) -> None: """Change the global inputs and outputs of a graph to the provided named tensors.""" if input_tensor_names is not None: @@ -131,11 +133,11 @@ def filter_relabel(src_subgraph: SubGraphT, relabel_filter: set) -> tuple: def cut_model( - model_file: str, - input_names: Optional[list], - output_names: Optional[list], + model_file: str | Path, + input_names: list | None, + output_names: list | None, subgraph_index: int, - output_file: str, + output_file: str | Path, ) -> None: """Cut subgraphs and simplify a given model.""" model = load_fb(model_file) diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py index 90f3db8..f85433d 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/record.py +++ b/src/mlia/nn/rewrite/core/graph_edit/record.py @@ -6,6 +6,7 @@ from __future__ import annotations import math import os +from contextlib import ExitStack from pathlib import Path import tensorflow as tf @@ -15,11 +16,22 @@ from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFWriter from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel +from mlia.nn.tensorflow.config import NameToTensorMap os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +DEQUANT_SUFFIX = "_dequant" + + +def dequantized_path(filename: str | Path) -> Path: + """Append the de-quantization suffix to the given filename.""" + path = Path(filename) + path = Path(path.parent, f"{path.stem}{DEQUANT_SUFFIX}{path.suffix}") + return path + + def record_model( input_filename: str | Path, model_filename: str | Path, @@ -28,11 +40,14 @@ def record_model( show_progress: bool = False, num_procs: int = 1, num_threads: int = 0, + dequantize_output: bool = False, ) -> None: """Model recorder. num_procs: 0 => detect real cores on system num_threads: 0 => TFLite impl. specific setting, usually 3 + + dequantize: True => de-quantize the recorded output before saving """ model = ParallelTFLiteModel(model_filename, num_procs, num_threads, batch_size) if not batch_size: @@ -51,22 +66,38 @@ def record_model( dataset = dataset.batch(batch_size, drop_remainder=False) total = int(math.ceil(total / batch_size)) - with NumpyTFWriter(output_filename) as writer: - for _, named_x in enumerate( - track(dataset.as_numpy_iterator(), total=total, disable=not show_progress) - ): - named_y = model(named_x) + with ExitStack() as stack: + writer = stack.enter_context(NumpyTFWriter(output_filename)) + writer_dequant = None + if dequantize_output: + dequant_path = dequantized_path(output_filename) + writer_dequant = stack.enter_context(NumpyTFWriter(dequant_path)) + + def write(writer: NumpyTFWriter, data: NameToTensorMap) -> None: + """Write the data using the given NumpyTFWriter instance.""" if batch_size > 1: for i in range(batch_size): # Expand the batches and recreate each dict as a # batch-size 1 item for the tfrec output recreated_dict = { k: v[i : i + 1] # noqa: E203 - for k, v in named_y.items() + for k, v in data.items() if i < v.shape[0] } if recreated_dict: writer.write(recreated_dict) else: - writer.write(named_y) + writer.write(data) + + for _, named_x in enumerate( + track(dataset.as_numpy_iterator(), total=total, disable=not show_progress) + ): + named_y = model(named_x) + write(writer, named_y) + + if dequantize_output: + assert writer_dequant + named_y_dequant = model.dequantize_outputs(named_y) + write(writer_dequant, named_y_dequant) + model.close() diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index 82af747..6345f07 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -22,9 +22,11 @@ from typing import Literal import numpy as np import tensorflow as tf +import tensorflow_model_optimization as tfmot from numpy.random import Generator from mlia.nn.rewrite.core.extract import extract +from mlia.nn.rewrite.core.extract import ExtractPaths from mlia.nn.rewrite.core.graph_edit.diff import diff_stats from mlia.nn.rewrite.core.graph_edit.join import join_models from mlia.nn.rewrite.core.graph_edit.record import record_model @@ -34,6 +36,7 @@ from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel from mlia.nn.tensorflow.config import TFLiteModel from mlia.nn.tensorflow.tflite_graph import load_fb from mlia.nn.tensorflow.tflite_graph import save_fb +from mlia.nn.tensorflow.utils import get_tflite_converter from mlia.utils.logging import log_action @@ -91,6 +94,7 @@ def train( input_tfrec, input_tensors, output_tensors, + dequantize_output=True, ) else: unmodified_model_dir = None @@ -106,6 +110,7 @@ def train( output_tensors, num_procs=train_params.num_procs, num_threads=train_params.num_threads, + dequantize_output=True, ) tflite_filenames = train_in_dir( @@ -160,7 +165,10 @@ def train( def eval_in_dir( - target_dir: str, new_part: str, num_procs: int = 1, num_threads: int = 0 + target_dir: str, + new_part: str, + num_procs: int = 1, + num_threads: int = 0, ) -> tuple: """Evaluate a model in a given directory.""" model_input_path = Path(target_dir, "input_orig.tfrec") @@ -168,12 +176,12 @@ def eval_in_dir( model_input = ( model_input_path if model_input_path.exists() - else Path(target_dir, "input.tfrec") + else ExtractPaths.tfrec.input(target_dir, False) ) output = ( model_output_path if model_output_path.exists() - else Path(target_dir, "output.tfrec") + else ExtractPaths.tfrec.output(target_dir, False) ) with tempfile.TemporaryDirectory() as tmp_dir: @@ -194,8 +202,8 @@ def join_in_dir(model_dir: str, new_part: str, output_model: str) -> None: """Join two models in a given directory.""" with tempfile.TemporaryDirectory() as tmp_dir: new_end = Path(tmp_dir, "new_end.tflite") - join_models(new_part, Path(model_dir, "end.tflite"), new_end) - join_models(Path(model_dir, "start.tflite"), new_end, output_model) + join_models(new_part, ExtractPaths.tflite.end(model_dir), new_end) + join_models(ExtractPaths.tflite.start(model_dir), new_end, output_model) def _get_io_tensors(model: TFLiteModel) -> tuple[str, str]: @@ -244,7 +252,9 @@ def set_up_data_pipeline( input_name, output_name = _get_io_tensors(teacher) - input_filename = Path(train_dir, "input.tfrec") + model_is_quantized = replace.is_tensor_quantized(name=input_name) + + input_filename = ExtractPaths.tfrec.input(train_dir, model_is_quantized) total = numpytf_count(str(input_filename)) dict_inputs = numpytf_read(str(input_filename)) @@ -264,13 +274,13 @@ def set_up_data_pipeline( if any(augmentations): # Map the teacher inputs here because the augmentation stage passes these # through a TFLite model to get the outputs - teacher_outputs = numpytf_read(str(Path(teacher_dir, "input.tfrec"))).map( - lambda d: tf.squeeze(d[input_name], axis=0) - ) + teacher_outputs = numpytf_read( + str(ExtractPaths.tfrec.input(teacher_dir, model_is_quantized)) + ).map(lambda d: tf.squeeze(d[input_name], axis=0)) else: - teacher_outputs = numpytf_read(str(Path(teacher_dir, "output.tfrec"))).map( - lambda d: tf.squeeze(d[output_name], axis=0) - ) + teacher_outputs = numpytf_read( + str(ExtractPaths.tfrec.output(teacher_dir, model_is_quantized)) + ).map(lambda d: tf.squeeze(d[output_name], axis=0)) dataset = tf.data.Dataset.zip((inputs, teacher_outputs)) if epochs > 1: @@ -285,7 +295,23 @@ def set_up_data_pipeline( ) -> tuple: """Return results of train and teach based on augmentations.""" augmented_train = augment_train({input_name: train})[input_name] - augmented_teach = teacher(augment_teacher({input_name: teach}))[output_name] + # If augmentation of the input is enabled, we need to re-generate + # the reference output by running the augmented data through the + # teacher model. + if model_is_quantized: + # If the input model is quantized we have to additionally + # - quantize the augmented data before running it through the + # (quantized) teacher model + # - de-quantize the output for the training. + augmented_teach = teacher.dequantize_outputs( + teacher( + teacher.quantize_inputs(augment_teacher({input_name: teach})) + ) + )[output_name] + else: + augmented_teach = teacher(augment_teacher({input_name: teach}))[ + output_name + ] return (augmented_train, augmented_teach) dataset = dataset.map( @@ -329,15 +355,20 @@ def train_in_dir( """ teacher_dir = baseline_dir if baseline_dir else train_dir teacher = ParallelTFLiteModel( - f"{teacher_dir}/replace.tflite", + ExtractPaths.tflite.replace(teacher_dir), train_params.num_procs, train_params.num_threads, batch_size=train_params.batch_size, ) - replace = TFLiteModel(f"{train_dir}/replace.tflite") + replace = TFLiteModel(ExtractPaths.tflite.replace(train_dir)) input_name, output_name = _get_io_tensors(teacher) + model_is_quantized = replace.is_tensor_quantized(name=input_name) + + if model_is_quantized: + replace.check_datatypes(np.int8) + dataset = set_up_data_pipeline( teacher, replace, @@ -354,6 +385,8 @@ def train_in_dir( optimizer = tf.keras.optimizers.Nadam(learning_rate=train_params.learning_rate) loss_fn = tf.keras.losses.MeanSquaredError() + if model_is_quantized: + model = tfmot.quantization.keras.quantize_model(model) model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"]) logger.info(model.summary()) @@ -432,6 +465,7 @@ def train_in_dir( replace.shape_from_name[input_name], output_name, replace.shape_from_name[output_name], + model_is_quantized, ) output_filenames.append(checkpoint_filename) @@ -446,6 +480,7 @@ def save_as_tflite( input_shape: list, output_name: str, output_shape: list, + model_is_quantized: bool = False, ) -> None: """Save Keras model as TFLite file.""" @@ -464,7 +499,7 @@ def save_as_tflite( keras_model.input.set_shape(orig_shape) with fixed_input(keras_model, input_shape) as fixed_model: - converter = tf.lite.TFLiteConverter.from_keras_model(fixed_model) + converter = get_tflite_converter(fixed_model, quantized=model_is_quantized) tflite_model = converter.convert() with open(filename, "wb") as file: -- cgit v1.2.1