aboutsummaryrefslogtreecommitdiff
path: root/src/mlia
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia')
-rw-r--r--src/mlia/nn/rewrite/library/clustering.py7
-rw-r--r--src/mlia/nn/rewrite/library/helper_functions.py9
-rw-r--r--src/mlia/nn/rewrite/library/sparsity.py7
-rw-r--r--src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml1
-rw-r--r--src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml1
5 files changed, 18 insertions, 7 deletions
diff --git a/src/mlia/nn/rewrite/library/clustering.py b/src/mlia/nn/rewrite/library/clustering.py
index b159763..2d608c3 100644
--- a/src/mlia/nn/rewrite/library/clustering.py
+++ b/src/mlia/nn/rewrite/library/clustering.py
@@ -36,7 +36,7 @@ def fc_clustering_rewrite(
return model
-def conv2d_clustering_rewrite(
+def conv2d_clustering_rewrite( # pylint: disable=dangerous-default-value
input_shape: Any,
output_shape: Any,
num_clusters: int = 2,
@@ -44,6 +44,7 @@ def conv2d_clustering_rewrite(
"CentroidInitialization.LINEAR"
),
activation: str = "relu",
+ kernel_size: list[int] = [3, 3],
) -> keras.Model:
"""Conv2d TensorFlow Lite model ready for clustering."""
rewrite_params = {
@@ -51,7 +52,9 @@ def conv2d_clustering_rewrite(
"cluster_centroids_init": cluster_centroids_init,
}
conv2d_parameters = compute_conv2d_parameters(
- input_shape=input_shape, output_shape=output_shape
+ input_shape=input_shape,
+ output_shape=output_shape,
+ kernel_size_input=kernel_size,
)
activation_function, activation_function_extra_args = get_activation_function(
activation
diff --git a/src/mlia/nn/rewrite/library/helper_functions.py b/src/mlia/nn/rewrite/library/helper_functions.py
index 58d84b1..d8847a5 100644
--- a/src/mlia/nn/rewrite/library/helper_functions.py
+++ b/src/mlia/nn/rewrite/library/helper_functions.py
@@ -33,17 +33,20 @@ def get_activation_function(
return activation_function, activation_function_extra_args
-def compute_conv2d_parameters(
- input_shape: np.ndarray, output_shape: np.ndarray
+def compute_conv2d_parameters( # pylint: disable=dangerous-default-value
+ input_shape: np.ndarray,
+ output_shape: np.ndarray,
+ kernel_size_input: list[int] = [3, 3],
) -> dict[str, Any]:
"""Compute needed kernel size and strides for a given input and output_shape."""
input_shape = input_shape.tolist()
output_shape = output_shape.tolist()
+ assert len(kernel_size_input) == 2, "Kernel size should have 2 entries"
assert len(input_shape) == 3
assert len(output_shape) == 3
+ kernel_size = tuple(kernel_size_input)
num_filters = (output_shape[-1] - input_shape[-1]) + input_shape[-1]
padding = "valid"
- kernel_size = (3, 3)
stride_h = round(input_shape[0] / output_shape[0])
check_output_size_h = math.floor((input_shape[0] - kernel_size[0]) / stride_h) + 1
stride_w = round(input_shape[1] / output_shape[1])
diff --git a/src/mlia/nn/rewrite/library/sparsity.py b/src/mlia/nn/rewrite/library/sparsity.py
index 95f99a7..34d3eb7 100644
--- a/src/mlia/nn/rewrite/library/sparsity.py
+++ b/src/mlia/nn/rewrite/library/sparsity.py
@@ -31,16 +31,19 @@ def fc_sparsity_rewrite(
return model
-def conv2d_sparsity_rewrite(
+def conv2d_sparsity_rewrite( # pylint: disable=dangerous-default-value
input_shape: Any,
output_shape: Any,
sparsity_m: int = 2,
sparsity_n: int = 4,
activation: str = "relu",
+ kernel_size: list[int] = [3, 3],
) -> keras.Model:
"""Conv2d TensorFlow Lite model ready for sparse pruning."""
conv2d_parameters = compute_conv2d_parameters(
- input_shape=input_shape, output_shape=output_shape
+ input_shape=input_shape,
+ output_shape=output_shape,
+ kernel_size_input=kernel_size,
)
activation_function, activation_function_extra_args = get_activation_function(
activation
diff --git a/src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml b/src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml
index fe50c31..3d8adfa 100644
--- a/src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml
+++ b/src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml
@@ -16,3 +16,4 @@ augmentations.mixup_strength = 0.0
num_clusters = 16
cluster_centroids_init = "CentroidInitialization.LINEAR"
activation = "relu"
+kernel_size = [3, 3]
diff --git a/src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml b/src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml
index d0e05a7..aa7f982 100644
--- a/src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml
+++ b/src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml
@@ -16,3 +16,4 @@ augmentations.mixup_strength = 0.0
sparsity_m = 2
sparsity_n = 4
activation = "relu"
+kernel_size = [3, 3]