aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/optimizations/clustering.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/tensorflow/optimizations/clustering.py')
-rw-r--r--src/mlia/nn/tensorflow/optimizations/clustering.py109
1 files changed, 109 insertions, 0 deletions
diff --git a/src/mlia/nn/tensorflow/optimizations/clustering.py b/src/mlia/nn/tensorflow/optimizations/clustering.py
new file mode 100644
index 0000000..16d9e4b
--- /dev/null
+++ b/src/mlia/nn/tensorflow/optimizations/clustering.py
@@ -0,0 +1,109 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""
+Contains class Clusterer that clusters unique weights per layer to a specified number.
+
+In order to do this, we need to have a base model and corresponding training data.
+We also have to specify a subset of layers we want to cluster. For more details,
+please refer to the documentation for TensorFlow Model Optimization Toolkit.
+"""
+from dataclasses import dataclass
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+
+import tensorflow as tf
+import tensorflow_model_optimization as tfmot
+from tensorflow_model_optimization.python.core.clustering.keras.experimental import ( # pylint: disable=no-name-in-module
+ cluster as experimental_cluster,
+)
+
+from mlia.nn.tensorflow.optimizations.common import Optimizer
+from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration
+
+
+@dataclass
+class ClusteringConfiguration(OptimizerConfiguration):
+ """Clustering configuration."""
+
+ optimization_target: int
+ layers_to_optimize: Optional[List[str]] = None
+
+ def __str__(self) -> str:
+ """Return string representation of the configuration."""
+ return f"clustering: {self.optimization_target}"
+
+
+class Clusterer(Optimizer):
+ """
+ Clusterer class.
+
+ Used to cluster a model to a specified number of unique weights per layer.
+
+ Sample usage:
+ clusterer = Clusterer(
+ base_model,
+ optimizer_configuration)
+
+ clusterer.apply_clustering()
+ clustered_model = clusterer.get_model()
+ """
+
+ def __init__(
+ self, model: tf.keras.Model, optimizer_configuration: ClusteringConfiguration
+ ):
+ """Init Clusterer instance."""
+ self.model = model
+ self.optimizer_configuration = optimizer_configuration
+
+ def optimization_config(self) -> str:
+ """Return string representation of the optimization config."""
+ return str(self.optimizer_configuration)
+
+ def _setup_clustering_params(self) -> Dict[str, Any]:
+ CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
+ return {
+ "number_of_clusters": self.optimizer_configuration.optimization_target,
+ "cluster_centroids_init": CentroidInitialization.LINEAR,
+ "preserve_sparsity": True,
+ }
+
+ def _apply_clustering_to_layer(
+ self, layer: tf.keras.layers.Layer
+ ) -> tf.keras.layers.Layer:
+ layers_to_optimize = self.optimizer_configuration.layers_to_optimize
+ assert layers_to_optimize, "List of the layers to optimize is empty"
+
+ if layer.name not in layers_to_optimize:
+ return layer
+
+ clustering_params = self._setup_clustering_params()
+ return experimental_cluster.cluster_weights(layer, **clustering_params)
+
+ def _init_for_clustering(self) -> None:
+ # Use `tf.keras.models.clone_model` to apply `apply_clustering_to_layer`
+ # to the layers of the model
+ if not self.optimizer_configuration.layers_to_optimize:
+ clustering_params = self._setup_clustering_params()
+ clustered_model = experimental_cluster.cluster_weights(
+ self.model, **clustering_params
+ )
+ else:
+ clustered_model = tf.keras.models.clone_model(
+ self.model, clone_function=self._apply_clustering_to_layer
+ )
+
+ self.model = clustered_model
+
+ def _strip_clustering(self) -> None:
+ self.model = tfmot.clustering.keras.strip_clustering(self.model)
+
+ def apply_optimization(self) -> None:
+ """Apply all steps of clustering at once."""
+ self._init_for_clustering()
+ self._strip_clustering()
+
+ def get_model(self) -> tf.keras.Model:
+ """Get model."""
+ return self.model