aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/library
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/library')
-rw-r--r--src/mlia/nn/rewrite/library/clustering.py5
-rw-r--r--src/mlia/nn/rewrite/library/fc_layer.py18
-rw-r--r--src/mlia/nn/rewrite/library/helper_functions.py2
-rw-r--r--src/mlia/nn/rewrite/library/layers.py53
-rw-r--r--src/mlia/nn/rewrite/library/sparsity.py80
5 files changed, 135 insertions, 23 deletions
diff --git a/src/mlia/nn/rewrite/library/clustering.py b/src/mlia/nn/rewrite/library/clustering.py
index 48914dc..9247457 100644
--- a/src/mlia/nn/rewrite/library/clustering.py
+++ b/src/mlia/nn/rewrite/library/clustering.py
@@ -45,6 +45,7 @@ def conv2d_clustering_rewrite( # pylint: disable=dangerous-default-value
),
activation: str = "relu",
kernel_size: list[int] = [3, 3],
+ layer_type: type[keras.layers.Layer] = keras.layers.Conv2D,
) -> keras.Model:
"""Conv2d TensorFlow Lite model ready for clustering."""
rewrite_params = {
@@ -59,7 +60,7 @@ def conv2d_clustering_rewrite( # pylint: disable=dangerous-default-value
activation_function, activation_function_extra_args = get_activation_function(
activation
)
- activation_func_found = (
+ activation_func_found = ( # pylint: disable=duplicate-code
[activation_function(**activation_function_extra_args)]
if activation_function
else []
@@ -68,7 +69,7 @@ def conv2d_clustering_rewrite( # pylint: disable=dangerous-default-value
to_cluster=keras.Sequential(
[
keras.layers.InputLayer(input_shape=input_shape),
- keras.layers.Conv2D(**conv2d_parameters),
+ layer_type(**conv2d_parameters),
keras.layers.BatchNormalization(),
*activation_func_found,
]
diff --git a/src/mlia/nn/rewrite/library/fc_layer.py b/src/mlia/nn/rewrite/library/fc_layer.py
deleted file mode 100644
index 92195d1..0000000
--- a/src/mlia/nn/rewrite/library/fc_layer.py
+++ /dev/null
@@ -1,18 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Rewrite function used to return regular layers."""
-from typing import Any
-
-from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
-
-
-def fc_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
- """Fully connected TensorFlow Lite model for rewrite."""
- model = keras.Sequential(
- (
- keras.layers.InputLayer(input_shape=input_shape),
- keras.layers.Reshape([-1]),
- keras.layers.Dense(output_shape),
- )
- )
- return model
diff --git a/src/mlia/nn/rewrite/library/helper_functions.py b/src/mlia/nn/rewrite/library/helper_functions.py
index 1237c17..fd5993b 100644
--- a/src/mlia/nn/rewrite/library/helper_functions.py
+++ b/src/mlia/nn/rewrite/library/helper_functions.py
@@ -45,7 +45,7 @@ def compute_conv2d_parameters( # pylint: disable=dangerous-default-value
assert len(input_shape) == 3
assert len(output_shape) == 3
kernel_size = tuple(kernel_size_input)
- num_filters = (output_shape[-1] - input_shape[-1]) + input_shape[-1]
+ num_filters = output_shape[-1]
padding = "valid"
stride_h = round(input_shape[0] / output_shape[0])
check_output_size_h = math.floor((input_shape[0] - kernel_size[0]) / stride_h) + 1
diff --git a/src/mlia/nn/rewrite/library/layers.py b/src/mlia/nn/rewrite/library/layers.py
new file mode 100644
index 0000000..abf0a4c
--- /dev/null
+++ b/src/mlia/nn/rewrite/library/layers.py
@@ -0,0 +1,53 @@
+# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Rewrite functions used to return regular layers."""
+from typing import Any
+
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
+
+from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters
+from mlia.nn.rewrite.library.helper_functions import get_activation_function
+
+
+def fc_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
+ """Fully connected TensorFlow Lite model for rewrite."""
+ model = keras.Sequential(
+ (
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Reshape([-1]),
+ keras.layers.Dense(output_shape),
+ )
+ )
+ return model
+
+
+def conv2d_rewrite( # pylint: disable=dangerous-default-value
+ input_shape: Any,
+ output_shape: Any,
+ activation: str = "relu",
+ kernel_size: list[int] = [3, 3],
+ layer_type: type[keras.layers.Layer] = keras.layers.Conv2D,
+) -> keras.Model:
+    """Conv2d TensorFlow Lite model for rewrite."""
+ conv2d_parameters = compute_conv2d_parameters(
+ input_shape=input_shape,
+ output_shape=output_shape,
+ kernel_size_input=kernel_size,
+ )
+ activation_function, activation_function_extra_args = get_activation_function(
+ activation
+ )
+ activation_func_found = ( # pylint: disable=duplicate-code
+ [activation_function(**activation_function_extra_args)]
+ if activation_function
+ else []
+ )
+ model = keras.Sequential(
+ (
+ keras.layers.InputLayer(input_shape=input_shape),
+ layer_type(**conv2d_parameters),
+ keras.layers.BatchNormalization(),
+ *activation_func_found,
+ )
+ )
+ return model
diff --git a/src/mlia/nn/rewrite/library/sparsity.py b/src/mlia/nn/rewrite/library/sparsity.py
index 0937b13..5102094 100644
--- a/src/mlia/nn/rewrite/library/sparsity.py
+++ b/src/mlia/nn/rewrite/library/sparsity.py
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Rewrite functions used to return layers ready for sparse pruning."""
+from __future__ import annotations
+
from typing import Any
import tensorflow_model_optimization as tfmot
@@ -10,6 +12,79 @@ from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters
from mlia.nn.rewrite.library.helper_functions import get_activation_function
+def fc_sparsity_unstructured_rewrite(
+ input_shape: Any,
+ output_shape: Any,
+ initial_sparsity: float = 0.5,
+ final_sparsity: float = 0.5,
+ begin_step: int = 0,
+ end_step: int = 48000,
+) -> keras.Model:
+ """Fully connected TensorFlow Lite model ready for unstructured sparse pruning."""
+ model = tfmot.sparsity.keras.prune_low_magnitude(
+ to_prune=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Reshape([-1]),
+ keras.layers.Dense(output_shape),
+ ]
+ ),
+ pruning_schedule=tfmot.sparsity.keras.PolynomialDecay(
+ initial_sparsity=initial_sparsity,
+ final_sparsity=final_sparsity,
+ begin_step=begin_step,
+ end_step=end_step,
+ ),
+ )
+
+ return model
+
+
+def conv2d_sparsity_unstructured_rewrite( # pylint: disable=dangerous-default-value, too-many-arguments
+ input_shape: Any,
+ output_shape: Any,
+ initial_sparsity: float = 0.5,
+ final_sparsity: float = 0.5,
+ begin_step: int = 0,
+ end_step: int = 48000,
+ activation: str = "relu",
+ kernel_size: list[int] = [3, 3],
+ layer_type: type[keras.layers.Layer] = keras.layers.Conv2D,
+) -> keras.Model:
+ """Conv2d TensorFlow Lite model ready for unstructured sparse pruning."""
+ conv2d_parameters = compute_conv2d_parameters(
+ input_shape=input_shape,
+ output_shape=output_shape,
+ kernel_size_input=kernel_size,
+ )
+ activation_function, activation_function_extra_args = get_activation_function(
+ activation
+ )
+ activation_func_found = (
+ [activation_function(**activation_function_extra_args)]
+ if activation_function
+ else []
+ )
+ model = tfmot.sparsity.keras.prune_low_magnitude(
+ to_prune=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ layer_type(**conv2d_parameters),
+ keras.layers.BatchNormalization(),
+ *activation_func_found,
+ ]
+ ),
+ pruning_schedule=tfmot.sparsity.keras.PolynomialDecay(
+ initial_sparsity=initial_sparsity,
+ final_sparsity=final_sparsity,
+ begin_step=begin_step,
+ end_step=end_step,
+ ),
+ )
+
+ return model
+
+
def fc_sparsity_rewrite(
input_shape: Any, output_shape: Any, sparsity_m: int = 2, sparsity_n: int = 4
) -> keras.Model:
@@ -38,6 +113,7 @@ def conv2d_sparsity_rewrite( # pylint: disable=dangerous-default-value
sparsity_n: int = 4,
activation: str = "relu",
kernel_size: list[int] = [3, 3],
+ layer_type: type[keras.layers.Layer] = keras.layers.Conv2D,
) -> keras.Model:
"""Conv2d TensorFlow Lite model ready for sparse pruning."""
conv2d_parameters = compute_conv2d_parameters(
@@ -48,7 +124,7 @@ def conv2d_sparsity_rewrite( # pylint: disable=dangerous-default-value
activation_function, activation_function_extra_args = get_activation_function(
activation
)
- activation_func_found = (
+ activation_func_found = ( # pylint: disable=duplicate-code
[activation_function(**activation_function_extra_args)]
if activation_function
else []
@@ -57,7 +133,7 @@ def conv2d_sparsity_rewrite( # pylint: disable=dangerous-default-value
to_prune=keras.Sequential(
[
keras.layers.InputLayer(input_shape=input_shape),
- keras.layers.Conv2D(**conv2d_parameters),
+ layer_type(**conv2d_parameters),
keras.layers.BatchNormalization(),
*activation_func_found,
]