diff options
Diffstat (limited to 'src/mlia/nn/rewrite/library/sparsity.py')
-rw-r--r-- | src/mlia/nn/rewrite/library/sparsity.py | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/src/mlia/nn/rewrite/library/sparsity.py b/src/mlia/nn/rewrite/library/sparsity.py index 8b74b72..709593a 100644 --- a/src/mlia/nn/rewrite/library/sparsity.py +++ b/src/mlia/nn/rewrite/library/sparsity.py @@ -6,6 +6,8 @@ from typing import Any import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 +from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters + def fc_sparsity_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: """Generate TensorFlow Lite model for rewrite.""" @@ -23,15 +25,18 @@ def fc_sparsity_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: return model -def conv2d_sparsity_rewrite( - input_shape: Any, conv2d_parameters: dict[str, Any] -) -> keras.Model: +def conv2d_sparsity_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: """Generate TensorFlow Lite model for rewrite.""" + conv2d_parameters = compute_conv2d_parameters( + input_shape=input_shape, output_shape=output_shape + ) model = tfmot.sparsity.keras.prune_low_magnitude( to_prune=keras.Sequential( [ keras.layers.InputLayer(input_shape=input_shape), keras.layers.Conv2D(**conv2d_parameters), + keras.layers.BatchNormalization(), + keras.layers.ReLU(), ] ), sparsity_m_by_n=(2, 4), |