diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/rewrite.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/rewrite.py | 77 |
1 files changed, 33 insertions, 44 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index 0d182df..6b27984 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -4,6 +4,7 @@ from __future__ import annotations import importlib +import logging import tempfile from dataclasses import dataclass from pathlib import Path @@ -12,13 +13,14 @@ from typing import Any from mlia.core.errors import ConfigurationError from mlia.nn.common import Optimizer from mlia.nn.common import OptimizerConfiguration -from mlia.nn.rewrite.core.train import eval_in_dir -from mlia.nn.rewrite.core.train import join_in_dir from mlia.nn.rewrite.core.train import train -from mlia.nn.rewrite.core.train import train_in_dir +from mlia.nn.rewrite.core.train import TrainingParameters from mlia.nn.tensorflow.config import TFLiteModel +logger = logging.getLogger(__name__) + + @dataclass class RewriteConfiguration(OptimizerConfiguration): """Rewrite configuration.""" @@ -26,6 +28,7 @@ class RewriteConfiguration(OptimizerConfiguration): optimization_target: str layers_to_optimize: list[str] | None = None dataset: Path | None = None + train_params: TrainingParameters = TrainingParameters() def __str__(self) -> str: """Return string representation of the configuration.""" @@ -40,8 +43,8 @@ class Rewriter(Optimizer): ): """Init Rewriter instance.""" self.model = TFLiteModel(tflite_model_path) + self.model_path = tflite_model_path self.optimizer_configuration = optimizer_configuration - self.train_dir = "" def apply_optimization(self) -> None: """Apply the rewrite flow.""" @@ -61,50 +64,36 @@ class Rewriter(Optimizer): replace_fn = get_function(replace_function) - augmentation_preset = (None, None) use_unmodified_model = True tflite_model = self.model.model_path tfrecord = str(self.optimizer_configuration.dataset) - with tempfile.TemporaryDirectory() as tmp_dir: - tmp_output = Path(tmp_dir, "output.tflite") - - if self.train_dir: - tmp_new = Path(tmp_dir, "new.tflite") - new_part = train_in_dir( - train_dir=self.train_dir, - baseline_dir=None, - output_filename=tmp_new, - replace_fn=replace_fn, - augmentations=augmentation_preset, - steps=32, - learning_rate=1e-3, - batch_size=1, - verbose=True, - show_progress=True, - ) - eval_in_dir(self.train_dir, new_part[0]) - join_in_dir(self.train_dir, new_part[0], str(tmp_output)) - else: - if not self.optimizer_configuration.layers_to_optimize: - raise ConfigurationError( - "Input and output tensor names need to be set for rewrite." - ) - train( - source_model=tflite_model, - unmodified_model=tflite_model if use_unmodified_model else None, - output_model=str(tmp_output), - input_tfrec=str(tfrecord), - replace_fn=replace_fn, - input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], - output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], - augment=augmentation_preset, - steps=32, - learning_rate=1e-3, - batch_size=1, - verbose=True, - show_progress=True, - ) + tmp_dir = tempfile.mkdtemp() + tmp_output = Path(tmp_dir, "output.tflite") + + if not self.optimizer_configuration.layers_to_optimize: + raise ConfigurationError( + "Input and output tensor names need to be set for rewrite." + ) + result = train( + source_model=tflite_model, + unmodified_model=tflite_model if use_unmodified_model else None, + output_model=str(tmp_output), + input_tfrec=str(tfrecord), + replace_fn=replace_fn, + input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], + output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], + train_params=self.optimizer_configuration.train_params, + ) + + self.model = TFLiteModel(tmp_output) + + if result: + stats_as_str = ", ".join(str(stats) for stats in result) + logger.info( + "The MAE and NRMSE between original and replacement [%s]", + stats_as_str, + ) def get_model(self) -> TFLiteModel: """Return optimized model.""" |