about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Nathan Bailey <nathan.bailey@arm.com> 2024-04-08 15:42:11 +0100
committer: Nathan Bailey <nathan.bailey@arm.com> 2024-05-07 11:48:11 +0000
commit: 0999ba0f4381ce1e2e8b06a932bfe693692223e2 (patch)
tree: cdc1061b885f9bfa7548b5c9bc569750a7ba957c
parent: 7565b9a5460ddc7d74eb74dfb6d0c4264f99ebaf (diff)
download: mlia-0999ba0f4381ce1e2e8b06a932bfe693692223e2.tar.gz
fix: Fixes MAE discrepancies of rewrites
- Adds model.evaluate stage to retrieve correct MAE
- Sets augmentations to none by default
- Enables MAE calculations using dequantized data (if needed)

Resolves: MLIA-972
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: Ia838b2d86caf7b10ad6e4d87bf6aa9e27c80bb72
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/record.py  | 10
-rw-r--r--  src/mlia/nn/rewrite/core/train.py              | 33
2 files changed, 30 insertions, 13 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py
index f85433d..7d9f219 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/record.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/record.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Save subgraph data."""
# pylint: disable=too-many-locals
@@ -32,7 +32,7 @@ def dequantized_path(filename: str | Path) -> Path:
return path
-def record_model(
+def record_model( # pylint: disable=too-many-arguments
input_filename: str | Path,
model_filename: str | Path,
output_filename: str | Path,
@@ -41,6 +41,7 @@ def record_model(
num_procs: int = 1,
num_threads: int = 0,
dequantize_output: bool = False,
+ quantize_input: bool = False,
) -> None:
"""Model recorder.
@@ -92,7 +93,10 @@ def record_model(
for _, named_x in enumerate(
track(dataset.as_numpy_iterator(), total=total, disable=not show_progress)
):
- named_y = model(named_x)
+ if quantize_input:
+ named_y = model(model.quantize_inputs(named_x))
+ else:
+ named_y = model(named_x)
write(writer, named_y)
if dequantize_output:
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 88efa23..4204978 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -62,7 +62,7 @@ LEARNING_RATE_SCHEDULES = get_args(LearningRateSchedule)
class TrainingParameters:
"""Define default parameters for the training."""
- augmentations: tuple[float | None, float | None] = AUGMENTATION_PRESETS["gaussian"]
+ augmentations: tuple[float | None, float | None] = AUGMENTATION_PRESETS["none"]
batch_size: int = 32
steps: int = 48000
learning_rate: float = 1e-3
@@ -147,7 +147,8 @@ def train( # pylint: disable=too-many-arguments
# Assess the output diff between the parts after the rewrite subgraph
# in original and optimized model
optimized_end_path = Path(train_dir, "optimized_end.tfrec")
- end_path = Path(train_dir, "end.tfrec")
+ optimized_end_path_dequant = Path(train_dir, "optimized_end_dequant.tfrec")
+ end_path = Path(train_dir, "end_dequant.tfrec")
record_model(
str(input_tfrec),
@@ -155,8 +156,10 @@ def train( # pylint: disable=too-many-arguments
optimized_end_path,
num_procs=train_params.num_procs,
num_threads=train_params.num_threads,
+ dequantize_output=True,
)
- mae, nrmse = diff_stats(end_path, str(optimized_end_path))
+
+ mae, nrmse = diff_stats(end_path, optimized_end_path_dequant)
if unmodified_model_dir:
cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup()
@@ -179,24 +182,27 @@ def eval_in_dir(
model_input = (
model_input_path
if model_input_path.exists()
- else ExtractPaths.tfrec.input(target_dir, False)
+ else ExtractPaths.tfrec.input(target_dir, True)
)
output = (
model_output_path
if model_output_path.exists()
- else ExtractPaths.tfrec.output(target_dir, False)
+ else ExtractPaths.tfrec.output(target_dir, True)
)
with tempfile.TemporaryDirectory() as tmp_dir:
predict = Path(tmp_dir, "predict.tfrec")
+ predict_dequant = Path(tmp_dir, "predict_dequant.tfrec")
record_model(
str(model_input),
new_part,
str(predict),
num_procs=num_procs,
num_threads=num_threads,
+ dequantize_output=True,
+ quantize_input=True,
)
- mae, nrmse = diff_stats(str(output), str(predict))
+ mae, nrmse = diff_stats(str(output), predict_dequant)
return mae, nrmse
@@ -249,7 +255,7 @@ def set_up_data_pipeline(
augmentations: tuple[float | None, float | None],
steps: int,
batch_size: int = 32,
-) -> tf.data.Dataset:
+) -> tuple[tf.data.Dataset, int]:
"""Create a data pipeline for training of the replacement model."""
_check_model_compatibility(teacher, replace)
@@ -340,7 +346,7 @@ def set_up_data_pipeline(
dataset = dataset.map(restore_shapes)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
- return dataset
+ return dataset, steps_per_epoch
def train_in_dir(
@@ -373,7 +379,7 @@ def train_in_dir(
if model_is_quantized:
replace.check_datatypes(np.int8)
- dataset = set_up_data_pipeline(
+ dataset, steps_per_epoch = set_up_data_pipeline(
teacher,
replace,
train_dir,
@@ -453,6 +459,7 @@ def train_in_dir(
input_shape,
output_shape,
loss_fn,
+ steps_per_epoch,
post_process=True,
)
@@ -492,6 +499,7 @@ def train_in_dir(
input_shape,
output_shape,
loss_fn,
+ steps_per_epoch,
)
# Placeholder for now, will be parametrized later (MLIA-1114)
# rewrite.check_optimization( # type: ignore[attr-defined]
@@ -548,6 +556,7 @@ def model_fit( # pylint: disable=too-many-arguments
input_shape: int,
output_shape: int,
loss_fn: Callable,
+ steps_per_epoch: int,
post_process: bool = False,
) -> keras.Model:
"""Train a tflite model."""
@@ -593,8 +602,12 @@ def model_fit( # pylint: disable=too-many-arguments
model_to_save = model
else:
checkpoint_filename = str(output_filename)
+ logger.info("Evaluate final Keras Model using %d steps", steps_per_epoch)
+ model.evaluate(
+ dataset,
+ steps=steps_per_epoch,
+ )
model_to_save = model
-
with log_action(
f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}"
):