aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/train.py
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2023-07-12 15:18:26 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-10-11 16:16:32 +0100
commitecc4264b93d4a89fa2cb40518b225d8371b7ffad (patch)
tree47244d2d67ab6c50bfc15eab768252359eae0df6 /src/mlia/nn/rewrite/core/train.py
parentbaaf4de286762c1955c874f78cd802d4703a8ba5 (diff)
downloadmlia-ecc4264b93d4a89fa2cb40518b225d8371b7ffad.tar.gz
Enable rewrites for quantized input models
If the input model for rewriting is quantized: - Record de-quantized TFRecords - enable writing de-quantized calibration data for the training - re-generate augmented training data, if needed - Use quantization-aware training (QAT) to train the replacement models - Check if replacement model is quantized: If source model is quantized, we make sure rewrite's output model is quantized too. Right now, only int8 is supported so raising an error if any other datatype is present in the output. Resolves: MLIA-907, MLIA-908, MLIA-927 Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com> Change-Id: Icb4070a9e6f1fdb5ce36120d73823986e89ac955
Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r--src/mlia/nn/rewrite/core/train.py67
1 files changed, 51 insertions, 16 deletions
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 82af747..6345f07 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -22,9 +22,11 @@ from typing import Literal
import numpy as np
import tensorflow as tf
+import tensorflow_model_optimization as tfmot
from numpy.random import Generator
from mlia.nn.rewrite.core.extract import extract
+from mlia.nn.rewrite.core.extract import ExtractPaths
from mlia.nn.rewrite.core.graph_edit.diff import diff_stats
from mlia.nn.rewrite.core.graph_edit.join import join_models
from mlia.nn.rewrite.core.graph_edit.record import record_model
@@ -34,6 +36,7 @@ from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.nn.tensorflow.tflite_graph import load_fb
from mlia.nn.tensorflow.tflite_graph import save_fb
+from mlia.nn.tensorflow.utils import get_tflite_converter
from mlia.utils.logging import log_action
@@ -91,6 +94,7 @@ def train(
input_tfrec,
input_tensors,
output_tensors,
+ dequantize_output=True,
)
else:
unmodified_model_dir = None
@@ -106,6 +110,7 @@ def train(
output_tensors,
num_procs=train_params.num_procs,
num_threads=train_params.num_threads,
+ dequantize_output=True,
)
tflite_filenames = train_in_dir(
@@ -160,7 +165,10 @@ def train(
def eval_in_dir(
- target_dir: str, new_part: str, num_procs: int = 1, num_threads: int = 0
+ target_dir: str,
+ new_part: str,
+ num_procs: int = 1,
+ num_threads: int = 0,
) -> tuple:
"""Evaluate a model in a given directory."""
model_input_path = Path(target_dir, "input_orig.tfrec")
@@ -168,12 +176,12 @@ def eval_in_dir(
model_input = (
model_input_path
if model_input_path.exists()
- else Path(target_dir, "input.tfrec")
+ else ExtractPaths.tfrec.input(target_dir, False)
)
output = (
model_output_path
if model_output_path.exists()
- else Path(target_dir, "output.tfrec")
+ else ExtractPaths.tfrec.output(target_dir, False)
)
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -194,8 +202,8 @@ def join_in_dir(model_dir: str, new_part: str, output_model: str) -> None:
"""Join two models in a given directory."""
with tempfile.TemporaryDirectory() as tmp_dir:
new_end = Path(tmp_dir, "new_end.tflite")
- join_models(new_part, Path(model_dir, "end.tflite"), new_end)
- join_models(Path(model_dir, "start.tflite"), new_end, output_model)
+ join_models(new_part, ExtractPaths.tflite.end(model_dir), new_end)
+ join_models(ExtractPaths.tflite.start(model_dir), new_end, output_model)
def _get_io_tensors(model: TFLiteModel) -> tuple[str, str]:
@@ -244,7 +252,9 @@ def set_up_data_pipeline(
input_name, output_name = _get_io_tensors(teacher)
- input_filename = Path(train_dir, "input.tfrec")
+ model_is_quantized = replace.is_tensor_quantized(name=input_name)
+
+ input_filename = ExtractPaths.tfrec.input(train_dir, model_is_quantized)
total = numpytf_count(str(input_filename))
dict_inputs = numpytf_read(str(input_filename))
@@ -264,13 +274,13 @@ def set_up_data_pipeline(
if any(augmentations):
# Map the teacher inputs here because the augmentation stage passes these
# through a TFLite model to get the outputs
- teacher_outputs = numpytf_read(str(Path(teacher_dir, "input.tfrec"))).map(
- lambda d: tf.squeeze(d[input_name], axis=0)
- )
+ teacher_outputs = numpytf_read(
+ str(ExtractPaths.tfrec.input(teacher_dir, model_is_quantized))
+ ).map(lambda d: tf.squeeze(d[input_name], axis=0))
else:
- teacher_outputs = numpytf_read(str(Path(teacher_dir, "output.tfrec"))).map(
- lambda d: tf.squeeze(d[output_name], axis=0)
- )
+ teacher_outputs = numpytf_read(
+ str(ExtractPaths.tfrec.output(teacher_dir, model_is_quantized))
+ ).map(lambda d: tf.squeeze(d[output_name], axis=0))
dataset = tf.data.Dataset.zip((inputs, teacher_outputs))
if epochs > 1:
@@ -285,7 +295,23 @@ def set_up_data_pipeline(
) -> tuple:
"""Return results of train and teach based on augmentations."""
augmented_train = augment_train({input_name: train})[input_name]
- augmented_teach = teacher(augment_teacher({input_name: teach}))[output_name]
+ # If augmentation of the input is enabled, we need to re-generate
+ # the reference output by running the augmented data through the
+ # teacher model.
+ if model_is_quantized:
+ # If the input model is quantized we have to additionally
+ # - quantize the augmented data before running it through the
+ # (quantized) teacher model
+ # - de-quantize the output for the training.
+ augmented_teach = teacher.dequantize_outputs(
+ teacher(
+ teacher.quantize_inputs(augment_teacher({input_name: teach}))
+ )
+ )[output_name]
+ else:
+ augmented_teach = teacher(augment_teacher({input_name: teach}))[
+ output_name
+ ]
return (augmented_train, augmented_teach)
dataset = dataset.map(
@@ -329,15 +355,20 @@ def train_in_dir(
"""
teacher_dir = baseline_dir if baseline_dir else train_dir
teacher = ParallelTFLiteModel(
- f"{teacher_dir}/replace.tflite",
+ ExtractPaths.tflite.replace(teacher_dir),
train_params.num_procs,
train_params.num_threads,
batch_size=train_params.batch_size,
)
- replace = TFLiteModel(f"{train_dir}/replace.tflite")
+ replace = TFLiteModel(ExtractPaths.tflite.replace(train_dir))
input_name, output_name = _get_io_tensors(teacher)
+ model_is_quantized = replace.is_tensor_quantized(name=input_name)
+
+ if model_is_quantized:
+ replace.check_datatypes(np.int8)
+
dataset = set_up_data_pipeline(
teacher,
replace,
@@ -354,6 +385,8 @@ def train_in_dir(
optimizer = tf.keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
loss_fn = tf.keras.losses.MeanSquaredError()
+ if model_is_quantized:
+ model = tfmot.quantization.keras.quantize_model(model)
model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
logger.info(model.summary())
@@ -432,6 +465,7 @@ def train_in_dir(
replace.shape_from_name[input_name],
output_name,
replace.shape_from_name[output_name],
+ model_is_quantized,
)
output_filenames.append(checkpoint_filename)
@@ -446,6 +480,7 @@ def save_as_tflite(
input_shape: list,
output_name: str,
output_shape: list,
+ model_is_quantized: bool = False,
) -> None:
"""Save Keras model as TFLite file."""
@@ -464,7 +499,7 @@ def save_as_tflite(
keras_model.input.set_shape(orig_shape)
with fixed_input(keras_model, input_shape) as fixed_model:
- converter = tf.lite.TFLiteConverter.from_keras_model(fixed_model)
+ converter = get_tflite_converter(fixed_model, quantized=model_is_quantized)
tflite_model = converter.convert()
with open(filename, "wb") as file: