diff options
Diffstat (limited to 'src/mlia/nn/rewrite/library/clustering.py')
-rw-r--r-- | src/mlia/nn/rewrite/library/clustering.py | 79 |
1 files changed, 79 insertions, 0 deletions
diff --git a/src/mlia/nn/rewrite/library/clustering.py b/src/mlia/nn/rewrite/library/clustering.py new file mode 100644 index 0000000..9247457 --- /dev/null +++ b/src/mlia/nn/rewrite/library/clustering.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Rewrite functions used to return layers ready for clustering.""" +from typing import Any + +import tensorflow_model_optimization as tfmot +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 + +from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters +from mlia.nn.rewrite.library.helper_functions import get_activation_function + + +def fc_clustering_rewrite( + input_shape: Any, + output_shape: Any, + num_clusters: int = 2, + cluster_centroids_init: tfmot.clustering.keras.CentroidInitialization = tfmot.clustering.keras.CentroidInitialization( # pylint: disable=line-too-long + "CentroidInitialization.LINEAR" + ), +) -> keras.Model: + """Fully connected TensorFlow Lite model ready for clustering.""" + rewrite_params = { + "number_of_clusters": num_clusters, + "cluster_centroids_init": cluster_centroids_init, + } + model = tfmot.clustering.keras.cluster_weights( + to_cluster=keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Flatten(), + keras.layers.Dense(units=output_shape), + ] + ), + **rewrite_params + ) + return model + + +def conv2d_clustering_rewrite( # pylint: disable=dangerous-default-value + input_shape: Any, + output_shape: Any, + num_clusters: int = 2, + cluster_centroids_init: tfmot.clustering.keras.CentroidInitialization = tfmot.clustering.keras.CentroidInitialization( # pylint: disable=line-too-long + "CentroidInitialization.LINEAR" + ), + activation: str = "relu", + kernel_size: list[int] = [3, 3], + layer_type: type[keras.layers.Layer] = keras.layers.Conv2D, +) -> keras.Model: + """Conv2d TensorFlow Lite model ready for clustering.""" + rewrite_params = { + "number_of_clusters": num_clusters, + "cluster_centroids_init": cluster_centroids_init, + } + conv2d_parameters = compute_conv2d_parameters( + input_shape=input_shape, + output_shape=output_shape, + kernel_size_input=kernel_size, + ) + activation_function, activation_function_extra_args = get_activation_function( + activation + ) + activation_func_found = ( # pylint: disable=duplicate-code + [activation_function(**activation_function_extra_args)] + if activation_function + else [] + ) + model = tfmot.clustering.keras.cluster_weights( + to_cluster=keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + layer_type(**conv2d_parameters), + keras.layers.BatchNormalization(), + *activation_func_found, + ] + ), + **rewrite_params + ) + return model |