aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/library
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/library')
-rw-r--r--src/mlia/nn/rewrite/library/clustering.py51
-rw-r--r--src/mlia/nn/rewrite/library/fc_layer.py6
-rw-r--r--src/mlia/nn/rewrite/library/helper_functions.py32
-rw-r--r--src/mlia/nn/rewrite/library/sparsity.py45
4 files changed, 131 insertions, 3 deletions
diff --git a/src/mlia/nn/rewrite/library/clustering.py b/src/mlia/nn/rewrite/library/clustering.py
new file mode 100644
index 0000000..81bfd90
--- /dev/null
+++ b/src/mlia/nn/rewrite/library/clustering.py
@@ -0,0 +1,51 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Rewrite functions used to return layers ready for clustering."""
+from typing import Any
+
+import tensorflow_model_optimization as tfmot
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
+
+from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters
+
+
+def fc_clustering_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
+ """Fully connected TensorFlow Lite model ready for clustering."""
+ rewrite_params = {
+ "number_of_clusters": 4,
+ "cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization.LINEAR,
+ }
+ model = tfmot.clustering.keras.cluster_weights(
+ to_cluster=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Flatten(),
+ keras.layers.Dense(units=output_shape),
+ ]
+ ),
+ **rewrite_params
+ )
+ return model
+
+
+def conv2d_clustering_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
+ """Conv2d TensorFlow Lite model ready for clustering."""
+ rewrite_params = {
+ "number_of_clusters": 4,
+ "cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization.LINEAR,
+ }
+ conv2d_parameters = compute_conv2d_parameters(
+ input_shape=input_shape, output_shape=output_shape
+ )
+ model = tfmot.clustering.keras.cluster_weights(
+ to_cluster=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Conv2D(**conv2d_parameters),
+ keras.layers.BatchNormalization(),
+ keras.layers.ReLU(),
+ ]
+ ),
+ **rewrite_params
+ )
+ return model
diff --git a/src/mlia/nn/rewrite/library/fc_layer.py b/src/mlia/nn/rewrite/library/fc_layer.py
index 041ce85..92195d1 100644
--- a/src/mlia/nn/rewrite/library/fc_layer.py
+++ b/src/mlia/nn/rewrite/library/fc_layer.py
@@ -1,13 +1,13 @@
# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
-"""Example rewrite with one fully connected layer."""
+"""Rewrite function used to return regular layers."""
from typing import Any
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
-def get_keras_model(input_shape: Any, output_shape: Any) -> keras.Model:
- """Generate TensorFlow Lite model for rewrite."""
+def fc_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
+ """Fully connected TensorFlow Lite model for rewrite."""
model = keras.Sequential(
(
keras.layers.InputLayer(input_shape=input_shape),
diff --git a/src/mlia/nn/rewrite/library/helper_functions.py b/src/mlia/nn/rewrite/library/helper_functions.py
new file mode 100644
index 0000000..4f08170
--- /dev/null
+++ b/src/mlia/nn/rewrite/library/helper_functions.py
@@ -0,0 +1,32 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Helper functions for the rewrite library."""
+import math
+from typing import Any
+
+import numpy as np
+
+
+def compute_conv2d_parameters(
+ input_shape: np.ndarray, output_shape: np.ndarray
+) -> dict[str, Any]:
+ """Compute needed kernel size and strides for a given input and output_shape."""
+ input_shape = input_shape.tolist()
+ output_shape = output_shape.tolist()
+ assert len(input_shape) == 3
+ assert len(output_shape) == 3
+ num_filters = (output_shape[-1] - input_shape[-1]) + input_shape[-1]
+ padding = "valid"
+ kernel_size = (3, 3)
+ stride_h = round(input_shape[0] / output_shape[0])
+ check_output_size_h = math.floor((input_shape[0] - kernel_size[0]) / stride_h) + 1
+ stride_w = round(input_shape[1] / output_shape[1])
+ check_output_size_w = math.floor((input_shape[1] - kernel_size[1]) / stride_w) + 1
+ if check_output_size_h != output_shape[0] or check_output_size_w != output_shape[1]:
+ padding = "same"
+ return {
+ "filters": num_filters,
+ "kernel_size": kernel_size,
+ "padding": padding,
+ "strides": (stride_h, stride_w),
+ }
diff --git a/src/mlia/nn/rewrite/library/sparsity.py b/src/mlia/nn/rewrite/library/sparsity.py
new file mode 100644
index 0000000..745fa8b
--- /dev/null
+++ b/src/mlia/nn/rewrite/library/sparsity.py
@@ -0,0 +1,45 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Rewrite functions used to return layers ready for sparse pruning."""
+from typing import Any
+
+import tensorflow_model_optimization as tfmot
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
+
+from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters
+
+
+def fc_sparsity_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
+ """Fully connected TensorFlow Lite model ready for sparse pruning."""
+ model = tfmot.sparsity.keras.prune_low_magnitude(
+ to_prune=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Reshape([-1]),
+ keras.layers.Dense(output_shape),
+ ]
+ ),
+ sparsity_m_by_n=(2, 4),
+ )
+
+ return model
+
+
+def conv2d_sparsity_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
+ """Conv2d TensorFlow Lite model ready for sparse pruning."""
+ conv2d_parameters = compute_conv2d_parameters(
+ input_shape=input_shape, output_shape=output_shape
+ )
+ model = tfmot.sparsity.keras.prune_low_magnitude(
+ to_prune=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Conv2D(**conv2d_parameters),
+ keras.layers.BatchNormalization(),
+ keras.layers.ReLU(),
+ ]
+ ),
+ sparsity_m_by_n=(2, 4),
+ )
+
+ return model