diff options
Diffstat (limited to 'src')
5 files changed, 18 insertions, 7 deletions
diff --git a/src/mlia/nn/rewrite/library/clustering.py b/src/mlia/nn/rewrite/library/clustering.py index b159763..2d608c3 100644 --- a/src/mlia/nn/rewrite/library/clustering.py +++ b/src/mlia/nn/rewrite/library/clustering.py @@ -36,7 +36,7 @@ def fc_clustering_rewrite( return model -def conv2d_clustering_rewrite( +def conv2d_clustering_rewrite( # pylint: disable=dangerous-default-value input_shape: Any, output_shape: Any, num_clusters: int = 2, @@ -44,6 +44,7 @@ def conv2d_clustering_rewrite( "CentroidInitialization.LINEAR" ), activation: str = "relu", + kernel_size: list[int] = [3, 3], ) -> keras.Model: """Conv2d TensorFlow Lite model ready for clustering.""" rewrite_params = { @@ -51,7 +52,9 @@ def conv2d_clustering_rewrite( "cluster_centroids_init": cluster_centroids_init, } conv2d_parameters = compute_conv2d_parameters( - input_shape=input_shape, output_shape=output_shape + input_shape=input_shape, + output_shape=output_shape, + kernel_size_input=kernel_size, ) activation_function, activation_function_extra_args = get_activation_function( activation diff --git a/src/mlia/nn/rewrite/library/helper_functions.py b/src/mlia/nn/rewrite/library/helper_functions.py index 58d84b1..d8847a5 100644 --- a/src/mlia/nn/rewrite/library/helper_functions.py +++ b/src/mlia/nn/rewrite/library/helper_functions.py @@ -33,17 +33,20 @@ def get_activation_function( return activation_function, activation_function_extra_args -def compute_conv2d_parameters( - input_shape: np.ndarray, output_shape: np.ndarray +def compute_conv2d_parameters( # pylint: disable=dangerous-default-value + input_shape: np.ndarray, + output_shape: np.ndarray, + kernel_size_input: list[int] = [3, 3], ) -> dict[str, Any]: """Compute needed kernel size and strides for a given input and output_shape.""" input_shape = input_shape.tolist() output_shape = output_shape.tolist() + assert len(kernel_size_input) == 2, "Kernel size should have 2 entries" assert len(input_shape) == 3 assert len(output_shape) == 3 + kernel_size = tuple(kernel_size_input) num_filters = (output_shape[-1] - input_shape[-1]) + input_shape[-1] padding = "valid" - kernel_size = (3, 3) stride_h = round(input_shape[0] / output_shape[0]) check_output_size_h = math.floor((input_shape[0] - kernel_size[0]) / stride_h) + 1 stride_w = round(input_shape[1] / output_shape[1]) diff --git a/src/mlia/nn/rewrite/library/sparsity.py b/src/mlia/nn/rewrite/library/sparsity.py index 95f99a7..34d3eb7 100644 --- a/src/mlia/nn/rewrite/library/sparsity.py +++ b/src/mlia/nn/rewrite/library/sparsity.py @@ -31,16 +31,19 @@ def fc_sparsity_rewrite( return model -def conv2d_sparsity_rewrite( +def conv2d_sparsity_rewrite( # pylint: disable=dangerous-default-value input_shape: Any, output_shape: Any, sparsity_m: int = 2, sparsity_n: int = 4, activation: str = "relu", + kernel_size: list[int] = [3, 3], ) -> keras.Model: """Conv2d TensorFlow Lite model ready for sparse pruning.""" conv2d_parameters = compute_conv2d_parameters( - input_shape=input_shape, output_shape=output_shape + input_shape=input_shape, + output_shape=output_shape, + kernel_size_input=kernel_size, ) activation_function, activation_function_extra_args = get_activation_function( activation diff --git a/src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml b/src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml index fe50c31..3d8adfa 100644 --- a/src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml +++ b/src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml @@ -16,3 +16,4 @@ augmentations.mixup_strength = 0.0 num_clusters = 16 cluster_centroids_init = "CentroidInitialization.LINEAR" activation = "relu" +kernel_size = [3, 3] diff --git a/src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml b/src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml index d0e05a7..aa7f982 100644 --- a/src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml +++ b/src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml @@ -16,3 +16,4 @@ augmentations.mixup_strength = 0.0 sparsity_m = 2 sparsity_n = 4 activation = "relu" +kernel_size = [3, 3] |