# src/mlia/nn/tensorflow/optimizations/pruning.py
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""
Contains class Pruner to prune a model to a specified sparsity.

In order to do this, we need to have a base model and corresponding training data.
We also have to specify a subset of layers we want to prune. For more details,
please refer to the documentation for TensorFlow Model Optimization Toolkit.
"""
from __future__ import annotations

import logging
import typing
from dataclasses import dataclass
from typing import Any

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.sparsity.keras import (  # pylint: disable=no-name-in-module
    prune_registry,
)
from tensorflow_model_optimization.python.core.sparsity.keras import (  # pylint: disable=no-name-in-module
    pruning_wrapper,
)

from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration

logger = logging.getLogger(__name__)


@dataclass
class PruningConfiguration(OptimizerConfiguration):
    """Pruning configuration."""

    optimization_target: float
    layers_to_optimize: list[str] | None = None
    x_train: np.ndarray | None = None
    y_train: np.ndarray | None = None
    batch_size: int = 1
    num_epochs: int = 1

    def __str__(self) -> str:
        """Return string representation of the configuration."""
        return f"pruning: {self.optimization_target}"

    def has_training_data(self) -> bool:
        """Return True if training data provided."""
        return self.x_train is not None and self.y_train is not None

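# Illustrative sketch (not part of the original module): constructing a
# PruningConfiguration for a 50% target sparsity. The layer names, shapes and
# hyperparameters below are assumptions for demonstration only.
#
#     config = PruningConfiguration(
#         optimization_target=0.5,
#         layers_to_optimize=["conv2d", "dense"],   # hypothetical layer names
#         x_train=np.random.rand(32, 28, 28, 1),    # assumed input shape
#         y_train=np.random.rand(32, 10),           # assumed output shape
#         batch_size=8,
#         num_epochs=3,
#     )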

@dataclass
class PrunableLayerPolicy(tfmot.sparsity.keras.PruningPolicy):
    """A policy to skip unsupported layers.

    PrunableLayerPolicy makes sure that all layers subject for pruning
    are compatible with the pruning API, and that the model supports pruning.
    """

    def allow_pruning(self, layer: tf.keras.layers.Layer) -> Any:
        """Allow pruning only for layers that are prunable.

        Checks the PruneRegistry in TensorFlow Model Optimization Toolkit.
        """
        layer_is_supported = prune_registry.PruneRegistry.supports(layer)
        if not layer_is_supported:
            logger.warning(
                "Layer %s is not supported for pruning, will be skipped.", layer.name
            )

        return layer_is_supported

    def ensure_model_supports_pruning(self, model: tf.keras.Model) -> None:
        """Ensure that the model contains only supported layers."""
        # Check whether the model is a Keras model.
        if not isinstance(model, tf.keras.Model):
            raise ValueError(
                "Models that are not part of the \
                            tf.keras.Model base class \
                            are not supported currently."
            )

        if not model.built:
            raise ValueError("Unbuilt models are not supported currently.")


class Pruner(Optimizer):
    """
    Pruner class. Used to prune a model to a specified sparsity.

    Sample usage:
        pruner = Pruner(
            base_model,
            optimizer_configuration)

        pruner.apply_optimization()
        pruned_model = pruner.get_model()
    """

    def __init__(
        self, model: tf.keras.Model, optimizer_configuration: PruningConfiguration
    ):
        """Init Pruner instance."""
        self.model = model
        self.optimizer_configuration = optimizer_configuration

        if not optimizer_configuration.has_training_data():
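            # No training data was provided: fall back to a single random
            # sample so the short fine-tuning loop in _train_pruning() can
            # still run.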
            mock_x_train, mock_y_train = self._mock_train_data(1)

            self.optimizer_configuration.x_train = mock_x_train
            self.optimizer_configuration.y_train = mock_y_train

    def optimization_config(self) -> str:
        """Return string representation of the optimization config."""
        return str(self.optimizer_configuration)

    def _mock_train_data(self, batch_size: int) -> tuple[np.ndarray, np.ndarray]:
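        # Random inputs and targets shaped like the model's input and output,
        # with the leading batch dimension replaced by `batch_size`.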
        return (
            np.random.rand(batch_size, *self.model.input_shape[1:]),
            np.random.rand(batch_size, *self.model.output_shape[1:]),
        )

    def _setup_pruning_params(self) -> dict:
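        # Ramp sparsity from 0 up to the optimization target with a polynomial
        # decay schedule that updates the pruning mask on every step.
        # `end_step` equals `num_epochs`, i.e. one pruning step per epoch,
        # which is exactly the case when the mock training data is used.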
        return {
            "pruning_schedule": tfmot.sparsity.keras.PolynomialDecay(
                initial_sparsity=0,
                final_sparsity=self.optimizer_configuration.optimization_target,
                begin_step=0,
                end_step=self.optimizer_configuration.num_epochs,
                frequency=1,
            ),
        }

    def _apply_pruning_to_layer(
        self, layer: tf.keras.layers.Layer
    ) -> tf.keras.layers.Layer:
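        # Used as the `clone_function` in _init_for_pruning(): wrap only the
        # layers listed in `layers_to_optimize` with `prune_low_magnitude` and
        # return all other layers unchanged.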
        layers_to_optimize = self.optimizer_configuration.layers_to_optimize
        assert layers_to_optimize, "List of the layers to optimize is empty"

        if layer.name not in layers_to_optimize:
            return layer

        pruning_params = self._setup_pruning_params()
        return tfmot.sparsity.keras.prune_low_magnitude(layer, **pruning_params)

    def _init_for_pruning(self) -> None:
        # If no specific layers are requested, wrap the whole model with
        # `prune_low_magnitude`, using PrunableLayerPolicy to skip unsupported
        # layers. Otherwise, use `tf.keras.models.clone_model` to apply
        # `_apply_pruning_to_layer` only to the requested layers.
        if not self.optimizer_configuration.layers_to_optimize:
            pruning_params = self._setup_pruning_params()
            prunable_model = tfmot.sparsity.keras.prune_low_magnitude(
                self.model, pruning_policy=PrunableLayerPolicy(), **pruning_params
            )
        else:
            prunable_model = tf.keras.models.clone_model(
                self.model, clone_function=self._apply_pruning_to_layer
            )

        self.model = prunable_model

    def _train_pruning(self) -> None:
        loss_fn = tf.keras.losses.MeanAbsolutePercentageError()
        self.model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])

        # The UpdatePruningStep callback is required during training: it
        # advances the pruning step and applies the schedule to the wrapped
        # layers.
        callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]

        # Fit the model so the pruning schedule is applied to the weights.
        self.model.fit(
            self.optimizer_configuration.x_train,
            self.optimizer_configuration.y_train,
            batch_size=self.optimizer_configuration.batch_size,
            epochs=self.optimizer_configuration.num_epochs,
            callbacks=callbacks,
            verbose=0,
        )

    @typing.no_type_check
    def _assert_sparsity_reached(self) -> None:
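        # For every pruned layer, check that the achieved sparsity
        # (1 - nonzero weights / total weights) matches the optimization
        # target to two significant figures.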
        for layer in self.model.layers:
            if not isinstance(layer, pruning_wrapper.PruneLowMagnitude):
                continue

            for weight in layer.layer.get_prunable_weights():
                nonzero_weights = np.count_nonzero(tf.keras.backend.get_value(weight))
                all_weights = tf.keras.backend.get_value(weight).size

                # Type checking is disabled for this function (see the
                # @typing.no_type_check decorator) because
                # np.testing.assert_approx_equal has no type annotations,
                # while the surrounding code does.
                np.testing.assert_approx_equal(
                    self.optimizer_configuration.optimization_target,
                    1 - nonzero_weights / all_weights,
                    significant=2,
                )

    def _strip_pruning(self) -> None:
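        # Remove the pruning wrappers and auxiliary variables, keeping the
        # (now sparse) weights in a plain Keras model.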
        self.model = tfmot.sparsity.keras.strip_pruning(self.model)

    def apply_optimization(self) -> None:
        """Apply all steps of pruning sequentially."""
        self._init_for_pruning()
        self._train_pruning()
        self._assert_sparsity_reached()
        self._strip_pruning()

    def get_model(self) -> tf.keras.Model:
        """Get model."""
        return self.model