aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/rewrite.py
blob: 6b2798434e9312db48f074787f1158a876c413c4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Contains class Rewriter to replace a subgraph/layer of a model."""
from __future__ import annotations

import importlib
import logging
import tempfile
from dataclasses import dataclass
from dataclasses import field
from pathlib import Path
from typing import Any

from mlia.core.errors import ConfigurationError
from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
from mlia.nn.rewrite.core.train import train
from mlia.nn.rewrite.core.train import TrainingParameters
from mlia.nn.tensorflow.config import TFLiteModel


logger = logging.getLogger(__name__)


@dataclass
class RewriteConfiguration(OptimizerConfiguration):
    """Configuration of a rewrite optimization.

    Each instance gets its own ``TrainingParameters`` object when none is
    supplied, so configurations never share a default created at import time.
    """

    # Name of the rewrite to apply; Rewriter compares it with "fully_connected".
    optimization_target: str
    # Tensor names delimiting the part of the model to rewrite: Rewriter uses
    # element 0 as the input tensor and element 1 as the output tensor.
    layers_to_optimize: list[str] | None = None
    # Location of the dataset that Rewriter hands to train() as ``input_tfrec``.
    dataset: Path | None = None
    # Training settings forwarded unchanged to train().
    train_params: TrainingParameters = field(default_factory=TrainingParameters)

    def __str__(self) -> str:
        """Return string representation of the configuration."""
        return f"rewrite: {self.optimization_target}"


class Rewriter(Optimizer):
    """Optimizer that replaces part of a TFLite model with a trained rewrite.

    The basic flow is: validate the configuration, resolve the function that
    builds the replacement, call ``train`` to produce a new ``.tflite`` file
    in a temporary directory and make that file the current model.
    """

    def __init__(
        self, tflite_model_path: Path, optimizer_configuration: RewriteConfiguration
    ):
        """Init Rewriter instance.

        :param tflite_model_path: path of the TFLite model to be rewritten
        :param optimizer_configuration: settings describing the rewrite
        """
        self.model = TFLiteModel(tflite_model_path)
        self.model_path = tflite_model_path
        self.optimizer_configuration = optimizer_configuration

    def apply_optimization(self) -> None:
        """Apply the rewrite flow.

        After a successful run ``self.model`` refers to the rewritten model
        stored in a freshly created temporary directory.

        :raises ConfigurationError: if the optimization target is not
            "fully_connected", or if fewer than two tensor names (input and
            output) are given in ``layers_to_optimize``
        """

        def get_function(arg: str) -> Any:
            """Import and return the object named by a dotted path."""
            module_name, _, fn_name = arg.rpartition(".")
            module = importlib.import_module(module_name)
            return getattr(module, fn_name)

        config = self.optimizer_configuration

        if config.optimization_target == "fully_connected":
            replace_function = "mlia.nn.rewrite.library.fc_layer.get_keras_model"
        else:
            raise ConfigurationError(
                "Only fully_connected replacement is supported in rewrite module."
            )

        # Validate before touching the file system so that a bad configuration
        # does not leave an orphan temporary directory behind. Both an input
        # and an output tensor name are indexed below, so a single name is as
        # invalid as none at all.
        layers = config.layers_to_optimize
        if not layers or len(layers) < 2:
            raise ConfigurationError(
                "Input and output tensor names need to be set for rewrite."
            )
        input_tensor, output_tensor = layers[0], layers[1]

        replace_fn = get_function(replace_function)

        tflite_model = self.model.model_path

        # The directory is deliberately not removed here: self.model keeps
        # pointing at the file inside it after this method returns.
        tmp_dir = tempfile.mkdtemp()
        tmp_output = Path(tmp_dir, "output.tflite")

        # NOTE(review): when no dataset is configured this passes the literal
        # string "None" to train(); confirm whether train() should receive
        # None instead or whether a missing dataset should be rejected here.
        result = train(
            source_model=tflite_model,
            # The original model is also handed over as the unmodified model.
            unmodified_model=tflite_model,
            output_model=str(tmp_output),
            input_tfrec=str(config.dataset),
            replace_fn=replace_fn,
            input_tensors=[input_tensor],
            output_tensors=[output_tensor],
            train_params=config.train_params,
        )

        self.model = TFLiteModel(tmp_output)

        if result:
            stats_as_str = ", ".join(str(stats) for stats in result)
            logger.info(
                "The MAE and NRMSE between original and replacement [%s]",
                stats_as_str,
            )

    def get_model(self) -> TFLiteModel:
        """Return optimized model.

        Before ``apply_optimization`` has run this is the original model.
        """
        return self.model

    def optimization_config(self) -> str:
        """Return the string form of the rewrite configuration."""
        return str(self.optimizer_configuration)