aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/library/sparsity.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/library/sparsity.py')
-rw-r--r--src/mlia/nn/rewrite/library/sparsity.py10
1 files changed, 6 insertions, 4 deletions
diff --git a/src/mlia/nn/rewrite/library/sparsity.py b/src/mlia/nn/rewrite/library/sparsity.py
index 1e53254..5102094 100644
--- a/src/mlia/nn/rewrite/library/sparsity.py
+++ b/src/mlia/nn/rewrite/library/sparsity.py
@@ -40,7 +40,7 @@ def fc_sparsity_unstructured_rewrite(
return model
-def conv2d_sparsity_unstructured_rewrite( # pylint: disable=dangerous-default-value
+def conv2d_sparsity_unstructured_rewrite( # pylint: disable=dangerous-default-value, too-many-arguments
input_shape: Any,
output_shape: Any,
initial_sparsity: float = 0.5,
@@ -49,6 +49,7 @@ def conv2d_sparsity_unstructured_rewrite( # pylint: disable=dangerous-default-v
end_step: int = 48000,
activation: str = "relu",
kernel_size: list[int] = [3, 3],
+ layer_type: type[keras.layers.Layer] = keras.layers.Conv2D,
) -> keras.Model:
"""Conv2d TensorFlow Lite model ready for unstructured sparse pruning."""
conv2d_parameters = compute_conv2d_parameters(
@@ -68,7 +69,7 @@ def conv2d_sparsity_unstructured_rewrite( # pylint: disable=dangerous-default-v
to_prune=keras.Sequential(
[
keras.layers.InputLayer(input_shape=input_shape),
- keras.layers.Conv2D(**conv2d_parameters),
+ layer_type(**conv2d_parameters),
keras.layers.BatchNormalization(),
*activation_func_found,
]
@@ -112,6 +113,7 @@ def conv2d_sparsity_rewrite( # pylint: disable=dangerous-default-value
sparsity_n: int = 4,
activation: str = "relu",
kernel_size: list[int] = [3, 3],
+ layer_type: type[keras.layers.Layer] = keras.layers.Conv2D,
) -> keras.Model:
"""Conv2d TensorFlow Lite model ready for sparse pruning."""
conv2d_parameters = compute_conv2d_parameters(
@@ -122,7 +124,7 @@ def conv2d_sparsity_rewrite( # pylint: disable=dangerous-default-value
activation_function, activation_function_extra_args = get_activation_function(
activation
)
- activation_func_found = (
+ activation_func_found = ( # pylint: disable=duplicate-code
[activation_function(**activation_function_extra_args)]
if activation_function
else []
@@ -131,7 +133,7 @@ def conv2d_sparsity_rewrite( # pylint: disable=dangerous-default-value
to_prune=keras.Sequential(
[
keras.layers.InputLayer(input_shape=input_shape),
- keras.layers.Conv2D(**conv2d_parameters),
+ layer_type(**conv2d_parameters),
keras.layers.BatchNormalization(),
*activation_func_found,
]