aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/rewrite.py
diff options
context:
space:
mode:
author: Benjamin Klimczak <benjamin.klimczak@arm.com> 2023-07-19 16:35:57 +0100
committer: Benjamin Klimczak <benjamin.klimczak@arm.com> 2023-10-11 16:06:17 +0100
commit: 3cd84481fa25e64c29e57396d4bf32d7a3ca490a (patch)
tree: ad81fb520a965bd3a3c7c983833b7cd48f9b8dea /src/mlia/nn/rewrite/core/rewrite.py
parent: f3e6597dd50ec70f043d692b773f2d9fd31519ae (diff)
download: mlia-3cd84481fa25e64c29e57396d4bf32d7a3ca490a.tar.gz
Bug-fixes and re-factoring for the rewrite module
- Fix input shape of rewrite replacement: During and after training of the replacement model for a rewrite, the Keras model is converted and saved in TensorFlow Lite format. If the input shape does not match the teacher model exactly, e.g. if the batch size is undefined, the TFLiteConverter adds extra operators during conversion.
- Fix rewritten model output
- Save the model output with the rewritten operator in the output dir
- Log MAE and NRMSE of the rewrite
- Remove 'verbose' flag from rewrite module and rely on the logging mechanism to control verbose output.
- Re-factor utility classes for rewrites
- Merge the two TFLiteModel classes
- Move functionality to load/save TensorFlow Lite flatbuffers to nn/tensorflow/tflite_graph
- Fix issue with unknown shape in datasets: After upgrading to TensorFlow 2.12, the unknown shape of the TFRecordDataset causes problems when training the replacement models for rewrites. By explicitly setting the right shape of the tensors we can work around the issue.
- Adapt default parameters for rewrites. The training steps especially had to be increased significantly to be effective.

Resolves: MLIA-895, MLIA-907, MLIA-946, MLIA-979
Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Change-Id: I887ad165aed0f2c6e5a0041f64cec5e6c5ab5c5c
Diffstat (limited to 'src/mlia/nn/rewrite/core/rewrite.py')
-rw-r--r--src/mlia/nn/rewrite/core/rewrite.py77
1 files changed, 33 insertions, 44 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index 0d182df..6b27984 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -4,6 +4,7 @@
from __future__ import annotations
import importlib
+import logging
import tempfile
from dataclasses import dataclass
from pathlib import Path
@@ -12,13 +13,14 @@ from typing import Any
from mlia.core.errors import ConfigurationError
from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
-from mlia.nn.rewrite.core.train import eval_in_dir
-from mlia.nn.rewrite.core.train import join_in_dir
from mlia.nn.rewrite.core.train import train
-from mlia.nn.rewrite.core.train import train_in_dir
+from mlia.nn.rewrite.core.train import TrainingParameters
from mlia.nn.tensorflow.config import TFLiteModel
+logger = logging.getLogger(__name__)
+
+
@dataclass
class RewriteConfiguration(OptimizerConfiguration):
"""Rewrite configuration."""
@@ -26,6 +28,7 @@ class RewriteConfiguration(OptimizerConfiguration):
optimization_target: str
layers_to_optimize: list[str] | None = None
dataset: Path | None = None
+ train_params: TrainingParameters = TrainingParameters()
def __str__(self) -> str:
"""Return string representation of the configuration."""
@@ -40,8 +43,8 @@ class Rewriter(Optimizer):
):
"""Init Rewriter instance."""
self.model = TFLiteModel(tflite_model_path)
+ self.model_path = tflite_model_path
self.optimizer_configuration = optimizer_configuration
- self.train_dir = ""
def apply_optimization(self) -> None:
"""Apply the rewrite flow."""
@@ -61,50 +64,36 @@ class Rewriter(Optimizer):
replace_fn = get_function(replace_function)
- augmentation_preset = (None, None)
use_unmodified_model = True
tflite_model = self.model.model_path
tfrecord = str(self.optimizer_configuration.dataset)
- with tempfile.TemporaryDirectory() as tmp_dir:
- tmp_output = Path(tmp_dir, "output.tflite")
-
- if self.train_dir:
- tmp_new = Path(tmp_dir, "new.tflite")
- new_part = train_in_dir(
- train_dir=self.train_dir,
- baseline_dir=None,
- output_filename=tmp_new,
- replace_fn=replace_fn,
- augmentations=augmentation_preset,
- steps=32,
- learning_rate=1e-3,
- batch_size=1,
- verbose=True,
- show_progress=True,
- )
- eval_in_dir(self.train_dir, new_part[0])
- join_in_dir(self.train_dir, new_part[0], str(tmp_output))
- else:
- if not self.optimizer_configuration.layers_to_optimize:
- raise ConfigurationError(
- "Input and output tensor names need to be set for rewrite."
- )
- train(
- source_model=tflite_model,
- unmodified_model=tflite_model if use_unmodified_model else None,
- output_model=str(tmp_output),
- input_tfrec=str(tfrecord),
- replace_fn=replace_fn,
- input_tensors=[self.optimizer_configuration.layers_to_optimize[0]],
- output_tensors=[self.optimizer_configuration.layers_to_optimize[1]],
- augment=augmentation_preset,
- steps=32,
- learning_rate=1e-3,
- batch_size=1,
- verbose=True,
- show_progress=True,
- )
+ tmp_dir = tempfile.mkdtemp()
+ tmp_output = Path(tmp_dir, "output.tflite")
+
+ if not self.optimizer_configuration.layers_to_optimize:
+ raise ConfigurationError(
+ "Input and output tensor names need to be set for rewrite."
+ )
+ result = train(
+ source_model=tflite_model,
+ unmodified_model=tflite_model if use_unmodified_model else None,
+ output_model=str(tmp_output),
+ input_tfrec=str(tfrecord),
+ replace_fn=replace_fn,
+ input_tensors=[self.optimizer_configuration.layers_to_optimize[0]],
+ output_tensors=[self.optimizer_configuration.layers_to_optimize[1]],
+ train_params=self.optimizer_configuration.train_params,
+ )
+
+ self.model = TFLiteModel(tmp_output)
+
+ if result:
+ stats_as_str = ", ".join(str(stats) for stats in result)
+ logger.info(
+ "The MAE and NRMSE between original and replacement [%s]",
+ stats_as_str,
+ )
def get_model(self) -> TFLiteModel:
"""Return optimized model."""