aboutsummaryrefslogtreecommitdiff
path: root/src/mlia
diff options
context:
space:
mode:
authorNathan Bailey <nathan.bailey@arm.com>2024-06-03 09:58:31 +0100
committerNathan Bailey <nathan.bailey@arm.com>2024-06-13 13:18:44 +0100
commit09b5122bab771161377321e3f17e05465171ad06 (patch)
tree74676eb296c110a925996448d86cc9dcde28b002 /src/mlia
parent9896c7e97da38cdaa14953fdce81a29397d1fca3 (diff)
downloadmlia-09b5122bab771161377321e3f17e05465171ad06.tar.gz
feat: Unstructured Sparsity Rewrites for Fully Connected and Conv2D Layers
Adds support for unstructured polynomial decay pruning rewrites Resolves: MLIA-1171 Signed-off-by: Nathan Bailey <nathan.bailey@arm.com> Change-Id: I9e753f35f8afe53aa24b87d794ff6986a571168f
Diffstat (limited to 'src/mlia')
-rw-r--r--src/mlia/nn/rewrite/core/rewrite.py130
-rw-r--r--src/mlia/nn/rewrite/library/sparsity.py74
-rw-r--r--src/mlia/resources/optimization_profiles/optimization-conv2d-unstructured-pruning.toml20
-rw-r--r--src/mlia/resources/optimization_profiles/optimization-fully-connected-unstructured-pruning.toml18
-rw-r--r--src/mlia/target/config.py2
5 files changed, 209 insertions, 35 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index a802c51..c2ad364 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -10,10 +10,13 @@ from abc import abstractmethod
from dataclasses import dataclass
from inspect import getfullargspec
from pathlib import Path
+from statistics import fmean
from typing import Any
from typing import Callable
+from typing import Generator
import numpy as np
+import tensorflow as tf
import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from tensorflow_model_optimization.python.core.sparsity.keras.pruning_utils import ( # pylint: disable=no-name-in-module
@@ -32,7 +35,9 @@ from mlia.nn.rewrite.library.clustering import conv2d_clustering_rewrite
from mlia.nn.rewrite.library.clustering import fc_clustering_rewrite
from mlia.nn.rewrite.library.fc_layer import fc_rewrite
from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_rewrite
+from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_unstructured_rewrite
from mlia.nn.rewrite.library.sparsity import fc_sparsity_rewrite
+from mlia.nn.rewrite.library.sparsity import fc_sparsity_unstructured_rewrite
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.utils.registry import Registry
@@ -117,9 +122,20 @@ class QuantizeAwareTrainingRewrite(Rewrite, ABC):
"""Apply optimization-aware quantization to a given model."""
return model
+ def check_optimization_generator(
+ self, model: keras.Model
+ ) -> Generator[tuple[tf.Tensor, keras.layers.Layer], None, None]:
+ """Loop for check_optimization function."""
+ for layer in model.layers:
+ for weight in layer.weights:
+ if "kernel" in weight.name:
+ if "kernel_min" in weight.name or "kernel_max" in weight.name:
+ continue
+ yield weight, layer
+
class SparsityRewrite(QuantizeAwareTrainingRewrite):
- """Rewrite class for sparsity rewrite e.g. fully-connected-sparsity."""
+ """Base rewrite class for sparsity rewrites."""
pruning_callback = tfmot.sparsity.keras.UpdatePruningStep
@@ -147,9 +163,54 @@ class SparsityRewrite(QuantizeAwareTrainingRewrite):
model,
tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(),
)
-
return model
+ def check_optimization(self, model: keras.Model) -> bool:
+ """Not needed here."""
+ return True
+
+
+class UnstructuredSparsityRewrite(SparsityRewrite):
+ """
+ Rewrite class for unstructured sparsity rewrite.
+
+ e.g. fully-connected-unstructured-sparsity.
+ """
+
+    def check_optimization(
+        self, model: keras.Model, final_sparsity: float = 0.5, **_: Any
+    ) -> bool:
+        """Check if unstructured pruning has produced the expected sparsity."""
+ found_sparsity_list = []
+ num_dec_places = str(final_sparsity)[::-1].find(".")
+ for weight, _ in self.check_optimization_generator(model=model):
+ weight_np = weight.numpy()
+            found_sparsity_list.append(
+                round(1 - np.count_nonzero(weight_np) / weight_np.size, num_dec_places)
+            )
+ if len(found_sparsity_list) == 0:
+ logger.warning(
+ "\nWARNING: Could not find any layers "
+ "in rewrite that could be sparsely pruned"
+ )
+ return False
+ found_sparsity = fmean(found_sparsity_list)
+ if found_sparsity != final_sparsity:
+ logger.warning(
+ "\nWARNING: Found total sparsity of "
+ "rewrite model: %.2f "
+ "expected total sparsity to be: "
+ "%.2f\n",
+ found_sparsity,
+ final_sparsity,
+ )
+ return False
+ return True
+
+
+class StructuredSparsityRewrite(SparsityRewrite):
+ """Rewrite class for structured sparsity rewrite e.g. fully-connected-sparsity."""
+
def check_optimization(
self,
model: keras.Model,
@@ -158,21 +219,17 @@ class SparsityRewrite(QuantizeAwareTrainingRewrite):
**_: Any,
) -> bool:
"""Check if sparity has produced the correct result."""
- for layer in model.layers:
- for weight in layer.weights:
- if "kernel" in weight.name:
- if "kernel_min" in weight.name or "kernel_max" in weight.name:
- continue
- if not is_pruned_m_by_n(weight, m_by_n=(sparsity_m, sparsity_n)):
- logger.warning(
- "\nWARNING: Could not find (%d, %d) sparsity, "
- "in layer %s for weight %s \n",
- sparsity_m,
- sparsity_n,
- layer.name,
- weight.name,
- )
- return False
+ for weight, layer in self.check_optimization_generator(model=model):
+ if not is_pruned_m_by_n(weight, m_by_n=(sparsity_m, sparsity_n)):
+ logger.warning(
+ "\nWARNING: Could not find (%d, %d) sparsity, "
+ "in layer %s for weight %s \n",
+ sparsity_m,
+ sparsity_n,
+ layer.name,
+ weight.name,
+ )
+ return False
return True
@@ -194,22 +251,18 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite):
self, model: keras.Model, num_clusters: int = 2, **_: Any
) -> bool:
"""Check if clustering has produced the correct result."""
- for layer in model.layers:
- for weight in layer.weights:
- if "kernel" in weight.name:
- if "kernel_min" in weight.name or "kernel_max" in weight.name:
- continue
- number_of_found_clusters = len(np.unique(weight))
- if number_of_found_clusters != num_clusters:
- logger.warning(
- "\nWARNING: Expected %d cluster(s), found %d "
- "cluster(s) in layer %s for weight %s \n",
- num_clusters,
- number_of_found_clusters,
- layer.name,
- weight.name,
- )
- return False
+ for weight, layer in self.check_optimization_generator(model=model):
+ number_of_found_clusters = len(np.unique(weight))
+ if number_of_found_clusters != num_clusters:
+ logger.warning(
+ "\nWARNING: Expected %d cluster(s), found %d "
+ "cluster(s) in layer %s for weight %s \n",
+ num_clusters,
+ number_of_found_clusters,
+ layer.name,
+ weight.name,
+ )
+ return False
return True
def training_callbacks(self) -> list:
@@ -261,10 +314,17 @@ class RewritingOptimizer(Optimizer):
registry = RewriteRegistry(
[
GenericRewrite("fully-connected", fc_rewrite),
- SparsityRewrite("fully-connected-sparsity", fc_sparsity_rewrite),
+ StructuredSparsityRewrite("fully-connected-sparsity", fc_sparsity_rewrite),
ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite),
ClusteringRewrite("conv2d-clustering", conv2d_clustering_rewrite),
- SparsityRewrite("conv2d-sparsity", conv2d_sparsity_rewrite),
+ StructuredSparsityRewrite("conv2d-sparsity", conv2d_sparsity_rewrite),
+ UnstructuredSparsityRewrite(
+ "conv2d-unstructured-sparsity", conv2d_sparsity_unstructured_rewrite
+ ),
+ UnstructuredSparsityRewrite(
+ "fully-connected-unstructured-sparsity",
+ fc_sparsity_unstructured_rewrite,
+ ),
]
)
diff --git a/src/mlia/nn/rewrite/library/sparsity.py b/src/mlia/nn/rewrite/library/sparsity.py
index 0937b13..1e53254 100644
--- a/src/mlia/nn/rewrite/library/sparsity.py
+++ b/src/mlia/nn/rewrite/library/sparsity.py
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Rewrite functions used to return layers ready for sparse pruning."""
+from __future__ import annotations
+
from typing import Any
import tensorflow_model_optimization as tfmot
@@ -10,6 +12,78 @@ from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters
from mlia.nn.rewrite.library.helper_functions import get_activation_function
+def fc_sparsity_unstructured_rewrite(
+ input_shape: Any,
+ output_shape: Any,
+ initial_sparsity: float = 0.5,
+ final_sparsity: float = 0.5,
+ begin_step: int = 0,
+ end_step: int = 48000,
+) -> keras.Model:
+ """Fully connected TensorFlow Lite model ready for unstructured sparse pruning."""
+ model = tfmot.sparsity.keras.prune_low_magnitude(
+ to_prune=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Reshape([-1]),
+ keras.layers.Dense(output_shape),
+ ]
+ ),
+ pruning_schedule=tfmot.sparsity.keras.PolynomialDecay(
+ initial_sparsity=initial_sparsity,
+ final_sparsity=final_sparsity,
+ begin_step=begin_step,
+ end_step=end_step,
+ ),
+ )
+
+ return model
+
+
+def conv2d_sparsity_unstructured_rewrite( # pylint: disable=dangerous-default-value
+ input_shape: Any,
+ output_shape: Any,
+ initial_sparsity: float = 0.5,
+ final_sparsity: float = 0.5,
+ begin_step: int = 0,
+ end_step: int = 48000,
+ activation: str = "relu",
+ kernel_size: list[int] = [3, 3],
+) -> keras.Model:
+ """Conv2d TensorFlow Lite model ready for unstructured sparse pruning."""
+ conv2d_parameters = compute_conv2d_parameters(
+ input_shape=input_shape,
+ output_shape=output_shape,
+ kernel_size_input=kernel_size,
+ )
+ activation_function, activation_function_extra_args = get_activation_function(
+ activation
+ )
+ activation_func_found = (
+ [activation_function(**activation_function_extra_args)]
+ if activation_function
+ else []
+ )
+ model = tfmot.sparsity.keras.prune_low_magnitude(
+ to_prune=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Conv2D(**conv2d_parameters),
+ keras.layers.BatchNormalization(),
+ *activation_func_found,
+ ]
+ ),
+ pruning_schedule=tfmot.sparsity.keras.PolynomialDecay(
+ initial_sparsity=initial_sparsity,
+ final_sparsity=final_sparsity,
+ begin_step=begin_step,
+ end_step=end_step,
+ ),
+ )
+
+ return model
+
+
def fc_sparsity_rewrite(
input_shape: Any, output_shape: Any, sparsity_m: int = 2, sparsity_n: int = 4
) -> keras.Model:
diff --git a/src/mlia/resources/optimization_profiles/optimization-conv2d-unstructured-pruning.toml b/src/mlia/resources/optimization_profiles/optimization-conv2d-unstructured-pruning.toml
new file mode 100644
index 0000000..67740ca
--- /dev/null
+++ b/src/mlia/resources/optimization_profiles/optimization-conv2d-unstructured-pruning.toml
@@ -0,0 +1,20 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+[rewrite.training_parameters]
+batch_size = 32
+learning_rate = 1e-3
+show_progress = true
+steps = 48000
+learning_rate_schedule = "cosine"
+num_procs = 1
+num_threads = 0
+augmentations.gaussian_strength = 0.0
+augmentations.mixup_strength = 0.0
+
+[rewrite.conv2d-unstructured-sparsity]
+initial_sparsity = 0.25
+final_sparsity = 0.5
+end_step = 48000
+activation = "relu"
+kernel_size = [3, 3]
diff --git a/src/mlia/resources/optimization_profiles/optimization-fully-connected-unstructured-pruning.toml b/src/mlia/resources/optimization_profiles/optimization-fully-connected-unstructured-pruning.toml
new file mode 100644
index 0000000..cd5f745
--- /dev/null
+++ b/src/mlia/resources/optimization_profiles/optimization-fully-connected-unstructured-pruning.toml
@@ -0,0 +1,18 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+[rewrite.training_parameters]
+batch_size = 32
+learning_rate = 1e-3
+show_progress = true
+steps = 48000
+learning_rate_schedule = "cosine"
+num_procs = 1
+num_threads = 0
+augmentations.gaussian_strength = 0.0
+augmentations.mixup_strength = 0.0
+
+[rewrite.fully-connected-unstructured-sparsity]
+initial_sparsity = 0.25
+final_sparsity = 0.5
+end_step = 48000
diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py
index 236511c..8a5b360 100644
--- a/src/mlia/target/config.py
+++ b/src/mlia/target/config.py
@@ -76,8 +76,10 @@ BUILTIN_SUPPORTED_OPTIMIZATION_NAMES = [
"optimization-custom-augmentation",
"optimization-fully-connected-clustering",
"optimization-fully-connected-pruning",
+ "optimization-fully-connected-unstructured-pruning",
"optimization-conv2d-clustering",
"optimization-conv2d-pruning",
+ "optimization-conv2d-unstructured-pruning",
]