# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """ Contains class Clusterer that clusters unique weights per layer to a specified number. In order to do this, we need to have a base model and corresponding training data. We also have to specify a subset of layers we want to cluster. For more details, please refer to the documentation for TensorFlow Model Optimization Toolkit. """ from dataclasses import dataclass from typing import Any from typing import Dict from typing import List from typing import Optional import tensorflow as tf import tensorflow_model_optimization as tfmot from tensorflow_model_optimization.python.core.clustering.keras.experimental import ( # pylint: disable=no-name-in-module cluster as experimental_cluster, ) from mlia.nn.tensorflow.optimizations.common import Optimizer from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration @dataclass class ClusteringConfiguration(OptimizerConfiguration): """Clustering configuration.""" optimization_target: int layers_to_optimize: Optional[List[str]] = None def __str__(self) -> str: """Return string representation of the configuration.""" return f"clustering: {self.optimization_target}" class Clusterer(Optimizer): """ Clusterer class. Used to cluster a model to a specified number of unique weights per layer. Sample usage: clusterer = Clusterer( base_model, optimizer_configuration) clusterer.apply_clustering() clustered_model = clusterer.get_model() """ def __init__( self, model: tf.keras.Model, optimizer_configuration: ClusteringConfiguration ): """Init Clusterer instance.""" self.model = model self.optimizer_configuration = optimizer_configuration def optimization_config(self) -> str: """Return string representation of the optimization config.""" return str(self.optimizer_configuration) def _setup_clustering_params(self) -> Dict[str, Any]: CentroidInitialization = tfmot.clustering.keras.CentroidInitialization return { "number_of_clusters": self.optimizer_configuration.optimization_target, "cluster_centroids_init": CentroidInitialization.LINEAR, "preserve_sparsity": True, } def _apply_clustering_to_layer( self, layer: tf.keras.layers.Layer ) -> tf.keras.layers.Layer: layers_to_optimize = self.optimizer_configuration.layers_to_optimize assert layers_to_optimize, "List of the layers to optimize is empty" if layer.name not in layers_to_optimize: return layer clustering_params = self._setup_clustering_params() return experimental_cluster.cluster_weights(layer, **clustering_params) def _init_for_clustering(self) -> None: # Use `tf.keras.models.clone_model` to apply `apply_clustering_to_layer` # to the layers of the model if not self.optimizer_configuration.layers_to_optimize: clustering_params = self._setup_clustering_params() clustered_model = experimental_cluster.cluster_weights( self.model, **clustering_params ) else: clustered_model = tf.keras.models.clone_model( self.model, clone_function=self._apply_clustering_to_layer ) self.model = clustered_model def _strip_clustering(self) -> None: self.model = tfmot.clustering.keras.strip_clustering(self.model) def apply_optimization(self) -> None: """Apply all steps of clustering at once.""" self._init_for_clustering() self._strip_clustering() def get_model(self) -> tf.keras.Model: """Get model.""" return self.model