aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/optimizations/clustering.py
blob: 16d9e4b3663038b94c474466a904e9a73aea9bd0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""
Contains class Clusterer that clusters the weights of each layer to a specified number of unique values.

In order to do this, we need to have a base model and corresponding training data.
We also have to specify a subset of layers we want to cluster. For more details,
please refer to the documentation for TensorFlow Model Optimization Toolkit.
"""
from dataclasses import dataclass
from typing import Any
from typing import Dict
from typing import List
from typing import Optional

import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.clustering.keras.experimental import (  # pylint: disable=no-name-in-module
    cluster as experimental_cluster,
)

from mlia.nn.tensorflow.optimizations.common import Optimizer
from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration


@dataclass
class ClusteringConfiguration(OptimizerConfiguration):
    """Settings that drive the clustering optimization."""

    # Passed to the clustering API as "number_of_clusters" by Clusterer.
    optimization_target: int
    # Names of the layers to cluster; when None or empty, Clusterer
    # clusters the whole model instead of individual layers.
    layers_to_optimize: Optional[List[str]] = None

    def __str__(self) -> str:
        """Return the configuration as ``clustering: <optimization_target>``."""
        target = self.optimization_target
        return f"clustering: {target}"


class Clusterer(Optimizer):
    """
    Weight-clustering optimizer for Keras models.

    Clusters a model to a specified number of unique weights per layer,
    either for the whole model or only for the layers named in the
    configuration.

    Sample usage:
        clusterer = Clusterer(
            base_model,
            optimizer_configuration)

        clusterer.apply_optimization()
        clustered_model = clusterer.get_model()
    """

    def __init__(
        self, model: tf.keras.Model, optimizer_configuration: ClusteringConfiguration
    ):
        """Store the model to cluster together with the clustering settings."""
        self.optimizer_configuration = optimizer_configuration
        self.model = model

    def optimization_config(self) -> str:
        """Return the optimizer configuration rendered as a string."""
        config = self.optimizer_configuration
        return str(config)

    def _setup_clustering_params(self) -> Dict[str, Any]:
        """Build the keyword arguments handed to `cluster_weights`."""
        keras_clustering = tfmot.clustering.keras
        params: Dict[str, Any] = {
            "number_of_clusters": self.optimizer_configuration.optimization_target,
            "cluster_centroids_init": keras_clustering.CentroidInitialization.LINEAR,
            "preserve_sparsity": True,
        }
        return params

    def _apply_clustering_to_layer(
        self, layer: tf.keras.layers.Layer
    ) -> tf.keras.layers.Layer:
        """Wrap `layer` for clustering if its name was selected.

        Used as the `clone_function` of `tf.keras.models.clone_model`; layers
        whose name is not listed in `layers_to_optimize` are returned as is.
        """
        selected_names = self.optimizer_configuration.layers_to_optimize
        assert selected_names, "List of the layers to optimize is empty"

        if layer.name in selected_names:
            params = self._setup_clustering_params()
            return experimental_cluster.cluster_weights(layer, **params)

        # Layers that were not requested are passed through untouched.
        return layer

    def _init_for_clustering(self) -> None:
        """Replace `self.model` with its clustering-wrapped counterpart."""
        if self.optimizer_configuration.layers_to_optimize:
            # Specific layers requested: let `clone_model` run
            # `_apply_clustering_to_layer` over the layers of the model.
            self.model = tf.keras.models.clone_model(
                self.model, clone_function=self._apply_clustering_to_layer
            )
            return

        # No layer selection given: cluster the model as a whole.
        params = self._setup_clustering_params()
        self.model = experimental_cluster.cluster_weights(self.model, **params)

    def _strip_clustering(self) -> None:
        """Replace `self.model` with the result of `strip_clustering`."""
        stripped_model = tfmot.clustering.keras.strip_clustering(self.model)
        self.model = stripped_model

    def apply_optimization(self) -> None:
        """Run the complete clustering flow: wrap the model, then strip it."""
        self._init_for_clustering()
        self._strip_clustering()

    def get_model(self) -> tf.keras.Model:
        """Return the model currently held by the clusterer."""
        return self.model