aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/rewrite.py
blob: 0d182df5bacdefd5d8e7744f45244452bafac718 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Contains class Rewriter to replace a subgraph/layer of a model."""
from __future__ import annotations

import importlib
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from mlia.core.errors import ConfigurationError
from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
from mlia.nn.rewrite.core.train import eval_in_dir
from mlia.nn.rewrite.core.train import join_in_dir
from mlia.nn.rewrite.core.train import train
from mlia.nn.rewrite.core.train import train_in_dir
from mlia.nn.tensorflow.config import TFLiteModel


@dataclass
class RewriteConfiguration(OptimizerConfiguration):
    """Settings for the rewrite optimization.

    Attributes:
        optimization_target: Kind of replacement to perform
            (e.g. ``"fully_connected"``).
        layers_to_optimize: Tensor names bounding the subgraph to replace;
            the first entry is used as the input tensor and the second as the
            output tensor.
        dataset: Path to the training data (presumably a TFRecord file, as it
            is passed on as ``input_tfrec``; confirm against callers).
    """

    optimization_target: str
    layers_to_optimize: list[str] | None = None
    dataset: Path | None = None

    def __str__(self) -> str:
        """Describe the configuration as ``rewrite: <optimization_target>``."""
        target = self.optimization_target
        return f"rewrite: {target}"


class Rewriter(Optimizer):
    """Rewriter class for basic rewrite flow.

    Replaces a subgraph/layer of a TFLite model using the training helpers
    from ``mlia.nn.rewrite.core.train``. After :meth:`apply_optimization`
    succeeds, :meth:`get_model` returns the rewritten model.
    """

    # Supported optimization targets mapped to the dotted import path of the
    # function that builds the replacement model.
    _REPLACE_FUNCTIONS = {
        "fully_connected": "mlia.nn.rewrite.library.fc_layer.get_keras_model",
    }

    def __init__(
        self, tflite_model_path: Path, optimizer_configuration: RewriteConfiguration
    ):
        """Init Rewriter instance.

        Args:
            tflite_model_path: Path to the TFLite model to rewrite.
            optimizer_configuration: Settings controlling the rewrite.
        """
        self.model = TFLiteModel(tflite_model_path)
        self.optimizer_configuration = optimizer_configuration
        # When non-empty, the train_in_dir/eval_in_dir/join_in_dir flow is
        # used with this directory instead of calling train() directly.
        self.train_dir = ""

    def apply_optimization(self) -> None:
        """Apply the rewrite flow and replace ``self.model`` with the result.

        The rewritten model is written to ``output.tflite`` inside a new
        directory created by :func:`tempfile.mkdtemp`. That directory is
        deliberately not removed, so the model returned by :meth:`get_model`
        stays on disk after this method returns.

        Raises:
            ConfigurationError: If the optimization target is unsupported, or
                if ``train_dir`` is not set and the input/output tensor names
                or the dataset are missing from the configuration.
        """
        replace_fn = self._get_replace_function(
            self.optimizer_configuration.optimization_target
        )
        # Validate up front so a bad configuration fails before any output
        # directory is created.
        train_inputs = None if self.train_dir else self._get_train_inputs()

        augmentation_preset = (None, None)
        # Hyper-parameters shared by both training paths.
        train_params: dict[str, Any] = {
            "steps": 32,
            "learning_rate": 1e-3,
            "batch_size": 1,
            "verbose": True,
            "show_progress": True,
        }

        # The output must outlive this call (get_model() hands it out), so it
        # cannot live in a TemporaryDirectory that is removed on exit.
        output_dir = tempfile.mkdtemp(prefix="mlia-rewrite-")
        output_model = Path(output_dir, "output.tflite")

        if train_inputs is None:
            # Only the intermediate replacement part is temporary.
            with tempfile.TemporaryDirectory() as tmp_dir:
                new_part = train_in_dir(
                    train_dir=self.train_dir,
                    baseline_dir=None,
                    output_filename=Path(tmp_dir, "new.tflite"),
                    replace_fn=replace_fn,
                    augmentations=augmentation_preset,
                    **train_params,
                )
                eval_in_dir(self.train_dir, new_part[0])
                join_in_dir(self.train_dir, new_part[0], str(output_model))
        else:
            input_tensor, output_tensor, tfrecord = train_inputs
            tflite_model = self.model.model_path
            train(
                source_model=tflite_model,
                # The unmodified source model is always used as reference.
                unmodified_model=tflite_model,
                output_model=str(output_model),
                input_tfrec=tfrecord,
                replace_fn=replace_fn,
                input_tensors=[input_tensor],
                output_tensors=[output_tensor],
                augment=augmentation_preset,
                **train_params,
            )

        self.model = TFLiteModel(output_model)

    @classmethod
    def _get_replace_function(cls, optimization_target: str) -> Any:
        """Import and return the replacement-model builder for a target.

        Raises:
            ConfigurationError: If ``optimization_target`` is not supported.
        """
        dotted_path = cls._REPLACE_FUNCTIONS.get(optimization_target)
        if dotted_path is None:
            raise ConfigurationError(
                "Only fully_connected replacement is supported in rewrite module."
            )
        module_name, _, fn_name = dotted_path.rpartition(".")
        return getattr(importlib.import_module(module_name), fn_name)

    def _get_train_inputs(self) -> tuple[str, str, str]:
        """Return ``(input tensor, output tensor, dataset path)`` for train().

        Raises:
            ConfigurationError: If fewer than two tensor names are configured
                or no dataset is set.
        """
        layers = self.optimizer_configuration.layers_to_optimize
        if not layers or len(layers) < 2:
            raise ConfigurationError(
                "Input and output tensor names need to be set for rewrite."
            )
        dataset = self.optimizer_configuration.dataset
        if dataset is None:
            raise ConfigurationError("Dataset needs to be set for rewrite.")
        return layers[0], layers[1], str(dataset)

    def get_model(self) -> TFLiteModel:
        """Return optimized model (the input model until optimization runs)."""
        return self.model

    def optimization_config(self) -> str:
        """Optimization configurations."""
        return str(self.optimizer_configuration)