diff options
Diffstat (limited to 'src/mlia/nn/rewrite/library/sparsity.py')
-rw-r--r-- | src/mlia/nn/rewrite/library/sparsity.py | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/src/mlia/nn/rewrite/library/sparsity.py b/src/mlia/nn/rewrite/library/sparsity.py index 34d3eb7..0937b13 100644 --- a/src/mlia/nn/rewrite/library/sparsity.py +++ b/src/mlia/nn/rewrite/library/sparsity.py @@ -48,13 +48,18 @@ def conv2d_sparsity_rewrite( # pylint: disable=dangerous-default-value activation_function, activation_function_extra_args = get_activation_function( activation ) + activation_func_found = ( + [activation_function(**activation_function_extra_args)] + if activation_function + else [] + ) model = tfmot.sparsity.keras.prune_low_magnitude( to_prune=keras.Sequential( [ keras.layers.InputLayer(input_shape=input_shape), keras.layers.Conv2D(**conv2d_parameters), keras.layers.BatchNormalization(), - activation_function(**activation_function_extra_args), + *activation_func_found, ] ), sparsity_m_by_n=( @@ -62,5 +67,4 @@ def conv2d_sparsity_rewrite( # pylint: disable=dangerous-default-value sparsity_n, ), ) - return model |