diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/train.py | 67 |
1 files changed, 51 insertions, 16 deletions
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index 82af747..6345f07 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -22,9 +22,11 @@ from typing import Literal import numpy as np import tensorflow as tf +import tensorflow_model_optimization as tfmot from numpy.random import Generator from mlia.nn.rewrite.core.extract import extract +from mlia.nn.rewrite.core.extract import ExtractPaths from mlia.nn.rewrite.core.graph_edit.diff import diff_stats from mlia.nn.rewrite.core.graph_edit.join import join_models from mlia.nn.rewrite.core.graph_edit.record import record_model @@ -34,6 +36,7 @@ from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel from mlia.nn.tensorflow.config import TFLiteModel from mlia.nn.tensorflow.tflite_graph import load_fb from mlia.nn.tensorflow.tflite_graph import save_fb +from mlia.nn.tensorflow.utils import get_tflite_converter from mlia.utils.logging import log_action @@ -91,6 +94,7 @@ def train( input_tfrec, input_tensors, output_tensors, + dequantize_output=True, ) else: unmodified_model_dir = None @@ -106,6 +110,7 @@ def train( output_tensors, num_procs=train_params.num_procs, num_threads=train_params.num_threads, + dequantize_output=True, ) tflite_filenames = train_in_dir( @@ -160,7 +165,10 @@ def train( def eval_in_dir( - target_dir: str, new_part: str, num_procs: int = 1, num_threads: int = 0 + target_dir: str, + new_part: str, + num_procs: int = 1, + num_threads: int = 0, ) -> tuple: """Evaluate a model in a given directory.""" model_input_path = Path(target_dir, "input_orig.tfrec") @@ -168,12 +176,12 @@ def eval_in_dir( model_input = ( model_input_path if model_input_path.exists() - else Path(target_dir, "input.tfrec") + else ExtractPaths.tfrec.input(target_dir, False) ) output = ( model_output_path if model_output_path.exists() - else Path(target_dir, "output.tfrec") + else ExtractPaths.tfrec.output(target_dir, False) ) with tempfile.TemporaryDirectory() as tmp_dir: @@ -194,8 +202,8 @@ def join_in_dir(model_dir: str, new_part: str, output_model: str) -> None: """Join two models in a given directory.""" with tempfile.TemporaryDirectory() as tmp_dir: new_end = Path(tmp_dir, "new_end.tflite") - join_models(new_part, Path(model_dir, "end.tflite"), new_end) - join_models(Path(model_dir, "start.tflite"), new_end, output_model) + join_models(new_part, ExtractPaths.tflite.end(model_dir), new_end) + join_models(ExtractPaths.tflite.start(model_dir), new_end, output_model) def _get_io_tensors(model: TFLiteModel) -> tuple[str, str]: @@ -244,7 +252,9 @@ def set_up_data_pipeline( input_name, output_name = _get_io_tensors(teacher) - input_filename = Path(train_dir, "input.tfrec") + model_is_quantized = replace.is_tensor_quantized(name=input_name) + + input_filename = ExtractPaths.tfrec.input(train_dir, model_is_quantized) total = numpytf_count(str(input_filename)) dict_inputs = numpytf_read(str(input_filename)) @@ -264,13 +274,13 @@ def set_up_data_pipeline( if any(augmentations): # Map the teacher inputs here because the augmentation stage passes these # through a TFLite model to get the outputs - teacher_outputs = numpytf_read(str(Path(teacher_dir, "input.tfrec"))).map( - lambda d: tf.squeeze(d[input_name], axis=0) - ) + teacher_outputs = numpytf_read( + str(ExtractPaths.tfrec.input(teacher_dir, model_is_quantized)) + ).map(lambda d: tf.squeeze(d[input_name], axis=0)) else: - teacher_outputs = numpytf_read(str(Path(teacher_dir, "output.tfrec"))).map( - lambda d: tf.squeeze(d[output_name], axis=0) - ) + teacher_outputs = numpytf_read( + str(ExtractPaths.tfrec.output(teacher_dir, model_is_quantized)) + ).map(lambda d: tf.squeeze(d[output_name], axis=0)) dataset = tf.data.Dataset.zip((inputs, teacher_outputs)) if epochs > 1: @@ -285,7 +295,23 @@ def set_up_data_pipeline( ) -> tuple: """Return results of train and teach based on augmentations.""" augmented_train = augment_train({input_name: train})[input_name] - augmented_teach = teacher(augment_teacher({input_name: teach}))[output_name] + # If augmentation of the input is enabled, we need to re-generate + # the reference output by running the augmented data through the + # teacher model. + if model_is_quantized: + # If the input model is quantized we have to additionally + # - quantize the augmented data before running it through the + # (quantized) teacher model + # - de-quantize the output for the training. + augmented_teach = teacher.dequantize_outputs( + teacher( + teacher.quantize_inputs(augment_teacher({input_name: teach})) + ) + )[output_name] + else: + augmented_teach = teacher(augment_teacher({input_name: teach}))[ + output_name + ] return (augmented_train, augmented_teach) dataset = dataset.map( @@ -329,15 +355,20 @@ def train_in_dir( """ teacher_dir = baseline_dir if baseline_dir else train_dir teacher = ParallelTFLiteModel( - f"{teacher_dir}/replace.tflite", + ExtractPaths.tflite.replace(teacher_dir), train_params.num_procs, train_params.num_threads, batch_size=train_params.batch_size, ) - replace = TFLiteModel(f"{train_dir}/replace.tflite") + replace = TFLiteModel(ExtractPaths.tflite.replace(train_dir)) input_name, output_name = _get_io_tensors(teacher) + model_is_quantized = replace.is_tensor_quantized(name=input_name) + + if model_is_quantized: + replace.check_datatypes(np.int8) + dataset = set_up_data_pipeline( teacher, replace, @@ -354,6 +385,8 @@ def train_in_dir( optimizer = tf.keras.optimizers.Nadam(learning_rate=train_params.learning_rate) loss_fn = tf.keras.losses.MeanSquaredError() + if model_is_quantized: + model = tfmot.quantization.keras.quantize_model(model) model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"]) logger.info(model.summary()) @@ -432,6 +465,7 @@ def train_in_dir( replace.shape_from_name[input_name], output_name, replace.shape_from_name[output_name], + model_is_quantized, ) output_filenames.append(checkpoint_filename) @@ -446,6 +480,7 @@ def save_as_tflite( input_shape: list, output_name: str, output_shape: list, + model_is_quantized: bool = False, ) -> None: """Save Keras model as TFLite file.""" @@ -464,7 +499,7 @@ def save_as_tflite( keras_model.input.set_shape(orig_shape) with fixed_input(keras_model, input_shape) as fixed_model: - converter = tf.lite.TFLiteConverter.from_keras_model(fixed_model) + converter = get_tflite_converter(fixed_model, quantized=model_is_quantized) tflite_model = converter.convert() with open(filename, "wb") as file: |