# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""
Contains class Pruner to prune a model to a specified sparsity.

In order to do this, we need to have a base model and corresponding training data.
We also have to specify a subset of layers we want to prune. For more details,
please refer to the documentation for TensorFlow Model Optimization Toolkit.
"""
from dataclasses import dataclass
from typing import List
from typing import Optional
from typing import Tuple

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.sparsity.keras import (  # pylint: disable=no-name-in-module
    pruning_wrapper,
)

from mlia.nn.tensorflow.optimizations.common import Optimizer
from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration


@dataclass
class PruningConfiguration(OptimizerConfiguration):
    """Pruning configuration."""

    optimization_target: float
    layers_to_optimize: Optional[List[str]] = None
    x_train: Optional[np.ndarray] = None
    y_train: Optional[np.ndarray] = None
    batch_size: int = 1
    num_epochs: int = 1
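
    # Note: `optimization_target` is the final sparsity, i.e. the fraction of
    # weights forced to zero (e.g. 0.5 for 50%). If `layers_to_optimize` is
    # None, the whole model is pruned.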

    def __str__(self) -> str:
        """Return string representation of the configuration."""
        return f"pruning: {self.optimization_target}"

    def has_training_data(self) -> bool:
        """Return True if training data provided."""
        return self.x_train is not None and self.y_train is not None


class Pruner(Optimizer):
    """
    Pruner class. Used to prune a model to a specified sparsity.

    Sample usage:
        pruner = Pruner(
            base_model,
            optimizer_configuration)

        pruner.apply_optimization()
        pruned_model = pruner.get_model()
    """

    def __init__(
        self, model: tf.keras.Model, optimizer_configuration: PruningConfiguration
    ):
        """Init Pruner instance."""
        self.model = model
        self.optimizer_configuration = optimizer_configuration

        if not optimizer_configuration.has_training_data():
            mock_x_train, mock_y_train = self._mock_train_data()

            self.optimizer_configuration.x_train = mock_x_train
            self.optimizer_configuration.y_train = mock_y_train

    def optimization_config(self) -> str:
        """Return string representation of the optimization config."""
        return str(self.optimizer_configuration)

    def _mock_train_data(self) -> Tuple[np.ndarray, np.ndarray]:
        # Drop the batch_size dimension (None) from the input and output shapes
        input_shape = tuple(x for x in self.model.input_shape if x is not None)
        output_shape = tuple(x for x in self.model.output_shape if x is not None)

        return (
            np.random.rand(*input_shape),
            np.random.randint(0, output_shape[-1], output_shape[:-1]),
        )

    def _setup_pruning_params(self) -> dict:
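        # PolynomialDecay ramps sparsity from initial_sparsity at begin_step
        # up to final_sparsity at end_step, recomputing the pruning mask every
        # `frequency` steps. Note that end_step counts optimizer steps, not
        # epochs; it is set to num_epochs here, which assumes one training
        # step per epoch.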
        return {
            "pruning_schedule": tfmot.sparsity.keras.PolynomialDecay(
                initial_sparsity=0,
                final_sparsity=self.optimizer_configuration.optimization_target,
                begin_step=0,
                end_step=self.optimizer_configuration.num_epochs,
                frequency=1,
            ),
        }

    def _apply_pruning_to_layer(
        self, layer: tf.keras.layers.Layer
    ) -> tf.keras.layers.Layer:
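        # Used as the `clone_function` for `tf.keras.models.clone_model`:
        # layers not listed in layers_to_optimize are returned unchanged,
        # while listed layers are wrapped with prune_low_magnitude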
        layers_to_optimize = self.optimizer_configuration.layers_to_optimize
        assert layers_to_optimize, "List of layers to optimize is empty"

        if layer.name not in layers_to_optimize:
            return layer

        pruning_params = self._setup_pruning_params()
        return tfmot.sparsity.keras.prune_low_magnitude(layer, **pruning_params)

    def _init_for_pruning(self) -> None:
        # If no layers are specified, prune the whole model; otherwise use
        # `tf.keras.models.clone_model` to apply `_apply_pruning_to_layer`
        # to the selected layers only
        if not self.optimizer_configuration.layers_to_optimize:
            pruning_params = self._setup_pruning_params()
            prunable_model = tfmot.sparsity.keras.prune_low_magnitude(
                self.model, **pruning_params
            )
        else:
            prunable_model = tf.keras.models.clone_model(
                self.model, clone_function=self._apply_pruning_to_layer
            )

        self.model = prunable_model

    def _train_pruning(self) -> None:
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
        self.model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])

        # UpdatePruningStep is required when training a pruned model: it keeps
        # the pruning schedule's step counter in sync with the optimizer
        callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]

        # Fit the model so the pruning schedule can reach the target sparsity
        self.model.fit(
            self.optimizer_configuration.x_train,
            self.optimizer_configuration.y_train,
            batch_size=self.optimizer_configuration.batch_size,
            epochs=self.optimizer_configuration.num_epochs,
            callbacks=callbacks,
            verbose=0,
        )

    def _assert_sparsity_reached(self) -> None:
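        # Verify that the achieved sparsity (fraction of zero-valued weights)
        # of every pruned layer matches the optimization target to two
        # significant figures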
        for layer in self.model.layers:
            if not isinstance(layer, pruning_wrapper.PruneLowMagnitude):
                continue

            for weight in layer.layer.get_prunable_weights():
                weight_value = tf.keras.backend.get_value(weight)
                nonzero_weights = np.count_nonzero(weight_value)
                all_weights = weight_value.size

                np.testing.assert_approx_equal(
                    self.optimizer_configuration.optimization_target,
                    1 - nonzero_weights / all_weights,
                    significant=2,
                )

    def _strip_pruning(self) -> None:
        self.model = tfmot.sparsity.keras.strip_pruning(self.model)

    def apply_optimization(self) -> None:
        """Apply all steps of pruning sequentially."""
        self._init_for_pruning()
        self._train_pruning()
        self._assert_sparsity_reached()
        self._strip_pruning()

    def get_model(self) -> tf.keras.Model:
        """Get model."""
        return self.model
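

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module. The model,
    # layer names and data below are hypothetical; any small Keras model with
    # matching training data would do.
    demo_model = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=(784,)),
            tf.keras.layers.Dense(64, activation="relu", name="dense_1"),
            tf.keras.layers.Dense(10, activation="softmax", name="output"),
        ]
    )

    demo_config = PruningConfiguration(
        optimization_target=0.5,
        layers_to_optimize=["dense_1"],
        x_train=np.random.rand(128, 784),
        y_train=np.random.randint(0, 10, 128),
        batch_size=16,
        num_epochs=4,
    )

    pruner = Pruner(demo_model, demo_config)
    pruner.apply_optimization()
    pruner.get_model().summary()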