diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-04-08 15:42:11 +0100 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-05-07 11:48:11 +0000 |
commit | 0999ba0f4381ce1e2e8b06a932bfe693692223e2 (patch) | |
tree | cdc1061b885f9bfa7548b5c9bc569750a7ba957c | |
parent | 7565b9a5460ddc7d74eb74dfb6d0c4264f99ebaf (diff) | |
download | mlia-0999ba0f4381ce1e2e8b06a932bfe693692223e2.tar.gz |
fix: Fixes MAE discrepancies of rewrites
- Adds model.evaluate stage to retrieve correct MAE
- Sets augmentations to none by default
- Enables MAE calculations using dequantized data (if needed)
Resolves: MLIA-972
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: Ia838b2d86caf7b10ad6e4d87bf6aa9e27c80bb72
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/record.py | 10 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/train.py | 33 |
2 files changed, 30 insertions, 13 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py
index f85433d..7d9f219 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/record.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/record.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates.
 # SPDX-License-Identifier: Apache-2.0
 """Save subgraph data."""
 # pylint: disable=too-many-locals
@@ -32,7 +32,7 @@ def dequantized_path(filename: str | Path) -> Path:
     return path
 
 
-def record_model(
+def record_model(  # pylint: disable=too-many-arguments
     input_filename: str | Path,
     model_filename: str | Path,
     output_filename: str | Path,
@@ -41,6 +41,7 @@ def record_model(
     num_procs: int = 1,
     num_threads: int = 0,
     dequantize_output: bool = False,
+    quantize_input: bool = False,
 ) -> None:
     """Model recorder.
 
@@ -92,7 +93,10 @@ def record_model(
         for _, named_x in enumerate(
             track(dataset.as_numpy_iterator(), total=total, disable=not show_progress)
         ):
-            named_y = model(named_x)
+            if quantize_input:
+                named_y = model(model.quantize_inputs(named_x))
+            else:
+                named_y = model(named_x)
             write(writer, named_y)
 
     if dequantize_output:
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 88efa23..4204978 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -62,7 +62,7 @@ LEARNING_RATE_SCHEDULES = get_args(LearningRateSchedule)
 class TrainingParameters:
     """Define default parameters for the training."""
 
-    augmentations: tuple[float | None, float | None] = AUGMENTATION_PRESETS["gaussian"]
+    augmentations: tuple[float | None, float | None] = AUGMENTATION_PRESETS["none"]
     batch_size: int = 32
     steps: int = 48000
     learning_rate: float = 1e-3
@@ -147,7 +147,8 @@ def train(  # pylint: disable=too-many-arguments
         # Assess the output diff between the parts after the rewrite subgraph
         # in original and optimized model
         optimized_end_path = Path(train_dir, "optimized_end.tfrec")
-        end_path = Path(train_dir, "end.tfrec")
+        optimized_end_path_dequant = Path(train_dir, "optimized_end_dequant.tfrec")
+        end_path = Path(train_dir, "end_dequant.tfrec")
 
         record_model(
             str(input_tfrec),
@@ -155,8 +156,10 @@ def train(  # pylint: disable=too-many-arguments
             optimized_end_path,
             num_procs=train_params.num_procs,
             num_threads=train_params.num_threads,
+            dequantize_output=True,
         )
-        mae, nrmse = diff_stats(end_path, str(optimized_end_path))
+
+        mae, nrmse = diff_stats(end_path, optimized_end_path_dequant)
 
         if unmodified_model_dir:
             cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup()
@@ -179,24 +182,27 @@ def eval_in_dir(
     model_input = (
         model_input_path
         if model_input_path.exists()
-        else ExtractPaths.tfrec.input(target_dir, False)
+        else ExtractPaths.tfrec.input(target_dir, True)
     )
     output = (
         model_output_path
         if model_output_path.exists()
-        else ExtractPaths.tfrec.output(target_dir, False)
+        else ExtractPaths.tfrec.output(target_dir, True)
     )
 
     with tempfile.TemporaryDirectory() as tmp_dir:
         predict = Path(tmp_dir, "predict.tfrec")
+        predict_dequant = Path(tmp_dir, "predict_dequant.tfrec")
         record_model(
             str(model_input),
             new_part,
             str(predict),
             num_procs=num_procs,
             num_threads=num_threads,
+            dequantize_output=True,
+            quantize_input=True,
         )
-        mae, nrmse = diff_stats(str(output), str(predict))
+        mae, nrmse = diff_stats(str(output), predict_dequant)
 
     return mae, nrmse
 
@@ -249,7 +255,7 @@ def set_up_data_pipeline(
     augmentations: tuple[float | None, float | None],
     steps: int,
     batch_size: int = 32,
-) -> tf.data.Dataset:
+) -> tuple[tf.data.Dataset, int]:
     """Create a data pipeline for training of the replacement model."""
     _check_model_compatibility(teacher, replace)
 
@@ -340,7 +346,7 @@ def set_up_data_pipeline(
     dataset = dataset.map(restore_shapes)
     dataset = dataset.prefetch(tf.data.AUTOTUNE)
 
-    return dataset
+    return dataset, steps_per_epoch
 
 
 def train_in_dir(
@@ -373,7 +379,7 @@ def train_in_dir(
     if model_is_quantized:
         replace.check_datatypes(np.int8)
 
-    dataset = set_up_data_pipeline(
+    dataset, steps_per_epoch = set_up_data_pipeline(
         teacher,
         replace,
         train_dir,
@@ -453,6 +459,7 @@ def train_in_dir(
         input_shape,
         output_shape,
         loss_fn,
+        steps_per_epoch,
         post_process=True,
     )
 
@@ -492,6 +499,7 @@ def train_in_dir(
             input_shape,
             output_shape,
             loss_fn,
+            steps_per_epoch,
         )
         # Placeholder for now, will be parametrized later (MLIA-1114)
         # rewrite.check_optimization(  # type: ignore[attr-defined]
@@ -548,6 +556,7 @@ def model_fit(  # pylint: disable=too-many-arguments
     input_shape: int,
     output_shape: int,
     loss_fn: Callable,
+    steps_per_epoch: int,
     post_process: bool = False,
 ) -> keras.Model:
     """Train a tflite model."""
@@ -593,8 +602,12 @@ def model_fit(  # pylint: disable=too-many-arguments
             model_to_save = model
         else:
             checkpoint_filename = str(output_filename)
+            logger.info("Evaluate final Keras Model using %d steps", steps_per_epoch)
+            model.evaluate(
+                dataset,
+                steps=steps_per_epoch,
+            )
             model_to_save = model
-
         with log_action(
             f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}"
         ):