From ecc4264b93d4a89fa2cb40518b225d8371b7ffad Mon Sep 17 00:00:00 2001 From: Benjamin Klimczak Date: Wed, 12 Jul 2023 15:18:26 +0100 Subject: Enable rewrites for quantized input models If the input model for rewriting is quantized: - Record de-quantized TFRecords - enable writing de-quantized calibration data for the training - re-generate augmented training data, if needed - Use quantization-aware training (QAT) to train the replacement models - Check if replacement model is quantized: If source model is quantized, we make sure rewrite's output model is quantized too. Right now, only int8 is supported so raising an error if any other datatype is present in the output. Resolves: MLIA-907, MLIA-908, MLIA-927 Signed-off-by: Benjamin Klimczak Change-Id: Icb4070a9e6f1fdb5ce36120d73823986e89ac955 --- src/mlia/nn/rewrite/core/graph_edit/record.py | 45 ++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 7 deletions(-) (limited to 'src/mlia/nn/rewrite/core/graph_edit/record.py') diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py index 90f3db8..f85433d 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/record.py +++ b/src/mlia/nn/rewrite/core/graph_edit/record.py @@ -6,6 +6,7 @@ from __future__ import annotations import math import os +from contextlib import ExitStack from pathlib import Path import tensorflow as tf @@ -15,11 +16,22 @@ from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFWriter from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel +from mlia.nn.tensorflow.config import NameToTensorMap os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +DEQUANT_SUFFIX = "_dequant" + + +def dequantized_path(filename: str | Path) -> Path: + """Append the de-quantization suffix to the given filename.""" + path = Path(filename) + path = Path(path.parent, f"{path.stem}{DEQUANT_SUFFIX}{path.suffix}") + return path + + def record_model( input_filename: str | Path, model_filename: str | Path, @@ -28,11 +40,14 @@ def record_model( show_progress: bool = False, num_procs: int = 1, num_threads: int = 0, + dequantize_output: bool = False, ) -> None: """Model recorder. num_procs: 0 => detect real cores on system num_threads: 0 => TFLite impl. specific setting, usually 3 + + dequantize: True => de-quantize the recorded output before saving """ model = ParallelTFLiteModel(model_filename, num_procs, num_threads, batch_size) if not batch_size: @@ -51,22 +66,38 @@ def record_model( dataset = dataset.batch(batch_size, drop_remainder=False) total = int(math.ceil(total / batch_size)) - with NumpyTFWriter(output_filename) as writer: - for _, named_x in enumerate( - track(dataset.as_numpy_iterator(), total=total, disable=not show_progress) - ): - named_y = model(named_x) + with ExitStack() as stack: + writer = stack.enter_context(NumpyTFWriter(output_filename)) + writer_dequant = None + if dequantize_output: + dequant_path = dequantized_path(output_filename) + writer_dequant = stack.enter_context(NumpyTFWriter(dequant_path)) + + def write(writer: NumpyTFWriter, data: NameToTensorMap) -> None: + """Write the data using the given NumpyTFWriter instance.""" if batch_size > 1: for i in range(batch_size): # Expand the batches and recreate each dict as a # batch-size 1 item for the tfrec output recreated_dict = { k: v[i : i + 1] # noqa: E203 - for k, v in named_y.items() + for k, v in data.items() if i < v.shape[0] } if recreated_dict: writer.write(recreated_dict) else: - writer.write(named_y) + writer.write(data) + + for _, named_x in enumerate( + track(dataset.as_numpy_iterator(), total=total, disable=not show_progress) + ): + named_y = model(named_x) + write(writer, named_y) + + if dequantize_output: + assert writer_dequant + named_y_dequant = model.dequantize_outputs(named_y) + write(writer_dequant, named_y_dequant) + model.close() -- cgit v1.2.1