diff options
Diffstat (limited to 'src/mlia/nn/rewrite/library/sparsity.py')
-rw-r--r-- | src/mlia/nn/rewrite/library/sparsity.py | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/src/mlia/nn/rewrite/library/sparsity.py b/src/mlia/nn/rewrite/library/sparsity.py index 1e53254..5102094 100644 --- a/src/mlia/nn/rewrite/library/sparsity.py +++ b/src/mlia/nn/rewrite/library/sparsity.py @@ -40,7 +40,7 @@ def fc_sparsity_unstructured_rewrite( return model -def conv2d_sparsity_unstructured_rewrite( # pylint: disable=dangerous-default-value +def conv2d_sparsity_unstructured_rewrite( # pylint: disable=dangerous-default-value, too-many-arguments input_shape: Any, output_shape: Any, initial_sparsity: float = 0.5, @@ -49,6 +49,7 @@ def conv2d_sparsity_unstructured_rewrite( # pylint: disable=dangerous-default-v end_step: int = 48000, activation: str = "relu", kernel_size: list[int] = [3, 3], + layer_type: type[keras.layers.Layer] = keras.layers.Conv2D, ) -> keras.Model: """Conv2d TensorFlow Lite model ready for unstructured sparse pruning.""" conv2d_parameters = compute_conv2d_parameters( @@ -68,7 +69,7 @@ def conv2d_sparsity_unstructured_rewrite( # pylint: disable=dangerous-default-v to_prune=keras.Sequential( [ keras.layers.InputLayer(input_shape=input_shape), - keras.layers.Conv2D(**conv2d_parameters), + layer_type(**conv2d_parameters), keras.layers.BatchNormalization(), *activation_func_found, ] @@ -112,6 +113,7 @@ def conv2d_sparsity_rewrite( # pylint: disable=dangerous-default-value sparsity_n: int = 4, activation: str = "relu", kernel_size: list[int] = [3, 3], + layer_type: type[keras.layers.Layer] = keras.layers.Conv2D, ) -> keras.Model: """Conv2d TensorFlow Lite model ready for sparse pruning.""" conv2d_parameters = compute_conv2d_parameters( @@ -122,7 +124,7 @@ def conv2d_sparsity_rewrite( # pylint: disable=dangerous-default-value activation_function, activation_function_extra_args = get_activation_function( activation ) - activation_func_found = ( + activation_func_found = ( # pylint: disable=duplicate-code [activation_function(**activation_function_extra_args)] if activation_function else [] @@ -131,7 +133,7 @@ def conv2d_sparsity_rewrite( # pylint: disable=dangerous-default-value to_prune=keras.Sequential( [ keras.layers.InputLayer(input_shape=input_shape), - keras.layers.Conv2D(**conv2d_parameters), + layer_type(**conv2d_parameters), keras.layers.BatchNormalization(), *activation_func_found, ] |