diff options
author | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-07-19 16:35:57 +0100 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-10-11 16:06:17 +0100 |
commit | 3cd84481fa25e64c29e57396d4bf32d7a3ca490a (patch) | |
tree | ad81fb520a965bd3a3c7c983833b7cd48f9b8dea /src/mlia/nn/rewrite/core/rewrite.py | |
parent | f3e6597dd50ec70f043d692b773f2d9fd31519ae (diff) | |
download | mlia-3cd84481fa25e64c29e57396d4bf32d7a3ca490a.tar.gz |
Bug-fixes and re-factoring for the rewrite module
- Fix input shape of rewrite replacement:
During and after training of the replacement model for a rewrite the
Keras model is converted and saved in TensorFlow Lite format. If the
input shape does not match the teacher model exactly, e.g. if the
batch size is undefined, the TFLiteConverter adds extra operators
during conversion.
- Fix rewritten model output
- Save the model output with the rewritten operator in the output dir
- Log MAE and NRMSE of the rewrite
- Remove 'verbose' flag from rewrite module and rely on the logging
mechanism to control verbose output.
- Re-factor utility classes for rewrites
- Merge the two TFLiteModel classes
- Move functionality to load/save TensorFlow Lite flatbuffers to
nn/tensorflow/tflite_graph
- Fix issue with unknown shape in datasets
After upgrading to TensorFlow 2.12 the unknown shape of the
TFRecordDataset is causing problems when training the replacement models
for rewrites. By explicitly setting the right shape of the tensors we
can work around the issue.
- Adapt default parameters for rewrites. The training steps especially
had to be increased significantly to be effective.
Resolves: MLIA-895, MLIA-907, MLIA-946, MLIA-979
Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Change-Id: I887ad165aed0f2c6e5a0041f64cec5e6c5ab5c5c
Diffstat (limited to 'src/mlia/nn/rewrite/core/rewrite.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/rewrite.py | 77 |
1 files changed, 33 insertions, 44 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index 0d182df..6b27984 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -4,6 +4,7 @@ from __future__ import annotations import importlib +import logging import tempfile from dataclasses import dataclass from pathlib import Path @@ -12,13 +13,14 @@ from typing import Any from mlia.core.errors import ConfigurationError from mlia.nn.common import Optimizer from mlia.nn.common import OptimizerConfiguration -from mlia.nn.rewrite.core.train import eval_in_dir -from mlia.nn.rewrite.core.train import join_in_dir from mlia.nn.rewrite.core.train import train -from mlia.nn.rewrite.core.train import train_in_dir +from mlia.nn.rewrite.core.train import TrainingParameters from mlia.nn.tensorflow.config import TFLiteModel +logger = logging.getLogger(__name__) + + @dataclass class RewriteConfiguration(OptimizerConfiguration): """Rewrite configuration.""" @@ -26,6 +28,7 @@ class RewriteConfiguration(OptimizerConfiguration): optimization_target: str layers_to_optimize: list[str] | None = None dataset: Path | None = None + train_params: TrainingParameters = TrainingParameters() def __str__(self) -> str: """Return string representation of the configuration.""" @@ -40,8 +43,8 @@ class Rewriter(Optimizer): ): """Init Rewriter instance.""" self.model = TFLiteModel(tflite_model_path) + self.model_path = tflite_model_path self.optimizer_configuration = optimizer_configuration - self.train_dir = "" def apply_optimization(self) -> None: """Apply the rewrite flow.""" @@ -61,50 +64,36 @@ class Rewriter(Optimizer): replace_fn = get_function(replace_function) - augmentation_preset = (None, None) use_unmodified_model = True tflite_model = self.model.model_path tfrecord = str(self.optimizer_configuration.dataset) - with tempfile.TemporaryDirectory() as tmp_dir: - tmp_output = Path(tmp_dir, "output.tflite") - - if self.train_dir: - tmp_new = Path(tmp_dir, "new.tflite") - new_part = train_in_dir( - train_dir=self.train_dir, - baseline_dir=None, - output_filename=tmp_new, - replace_fn=replace_fn, - augmentations=augmentation_preset, - steps=32, - learning_rate=1e-3, - batch_size=1, - verbose=True, - show_progress=True, - ) - eval_in_dir(self.train_dir, new_part[0]) - join_in_dir(self.train_dir, new_part[0], str(tmp_output)) - else: - if not self.optimizer_configuration.layers_to_optimize: - raise ConfigurationError( - "Input and output tensor names need to be set for rewrite." - ) - train( - source_model=tflite_model, - unmodified_model=tflite_model if use_unmodified_model else None, - output_model=str(tmp_output), - input_tfrec=str(tfrecord), - replace_fn=replace_fn, - input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], - output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], - augment=augmentation_preset, - steps=32, - learning_rate=1e-3, - batch_size=1, - verbose=True, - show_progress=True, - ) + tmp_dir = tempfile.mkdtemp() + tmp_output = Path(tmp_dir, "output.tflite") + + if not self.optimizer_configuration.layers_to_optimize: + raise ConfigurationError( + "Input and output tensor names need to be set for rewrite." + ) + result = train( + source_model=tflite_model, + unmodified_model=tflite_model if use_unmodified_model else None, + output_model=str(tmp_output), + input_tfrec=str(tfrecord), + replace_fn=replace_fn, + input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], + output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], + train_params=self.optimizer_configuration.train_params, + ) + + self.model = TFLiteModel(tmp_output) + + if result: + stats_as_str = ", ".join(str(stats) for stats in result) + logger.info( + "The MAE and NRMSE between original and replacement [%s]", + stats_as_str, + ) def get_model(self) -> TFLiteModel: """Return optimized model.""" |