aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/library/sparsity.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/library/sparsity.py')
-rw-r--r--src/mlia/nn/rewrite/library/sparsity.py8
1 files changed, 6 insertions, 2 deletions
diff --git a/src/mlia/nn/rewrite/library/sparsity.py b/src/mlia/nn/rewrite/library/sparsity.py
index 34d3eb7..0937b13 100644
--- a/src/mlia/nn/rewrite/library/sparsity.py
+++ b/src/mlia/nn/rewrite/library/sparsity.py
@@ -48,13 +48,18 @@ def conv2d_sparsity_rewrite( # pylint: disable=dangerous-default-value
activation_function, activation_function_extra_args = get_activation_function(
activation
)
+ activation_func_found = (
+ [activation_function(**activation_function_extra_args)]
+ if activation_function
+ else []
+ )
model = tfmot.sparsity.keras.prune_low_magnitude(
to_prune=keras.Sequential(
[
keras.layers.InputLayer(input_shape=input_shape),
keras.layers.Conv2D(**conv2d_parameters),
keras.layers.BatchNormalization(),
- activation_function(**activation_function_extra_args),
+ *activation_func_found,
]
),
sparsity_m_by_n=(
@@ -62,5 +67,4 @@ def conv2d_sparsity_rewrite( # pylint: disable=dangerous-default-value
sparsity_n,
),
)
-
return model