diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-06-03 09:58:31 +0100 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-06-13 13:18:44 +0100 |
commit | 09b5122bab771161377321e3f17e05465171ad06 (patch) | |
tree | 74676eb296c110a925996448d86cc9dcde28b002 /src/mlia | |
parent | 9896c7e97da38cdaa14953fdce81a29397d1fca3 (diff) | |
download | mlia-09b5122bab771161377321e3f17e05465171ad06.tar.gz |
feat: Unstructured Sparsity Rewrites for Fully Connected and Conv2D Layers
Adds support for unstructured polynomial decay pruning rewrites
Resolves: MLIA-1171
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: I9e753f35f8afe53aa24b87d794ff6986a571168f
Diffstat (limited to 'src/mlia')
5 files changed, 209 insertions, 35 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index a802c51..c2ad364 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -10,10 +10,13 @@ from abc import abstractmethod from dataclasses import dataclass from inspect import getfullargspec from pathlib import Path +from statistics import fmean from typing import Any from typing import Callable +from typing import Generator import numpy as np +import tensorflow as tf import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from tensorflow_model_optimization.python.core.sparsity.keras.pruning_utils import ( # pylint: disable=no-name-in-module @@ -32,7 +35,9 @@ from mlia.nn.rewrite.library.clustering import conv2d_clustering_rewrite from mlia.nn.rewrite.library.clustering import fc_clustering_rewrite from mlia.nn.rewrite.library.fc_layer import fc_rewrite from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_rewrite +from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_unstructured_rewrite from mlia.nn.rewrite.library.sparsity import fc_sparsity_rewrite +from mlia.nn.rewrite.library.sparsity import fc_sparsity_unstructured_rewrite from mlia.nn.tensorflow.config import TFLiteModel from mlia.utils.registry import Registry @@ -117,9 +122,20 @@ class QuantizeAwareTrainingRewrite(Rewrite, ABC): """Apply optimization-aware quantization to a given model.""" return model + def check_optimization_generator( + self, model: keras.Model + ) -> Generator[tuple[tf.Tensor, keras.layers.Layer], None, None]: + """Loop for check_optimization function.""" + for layer in model.layers: + for weight in layer.weights: + if "kernel" in weight.name: + if "kernel_min" in weight.name or "kernel_max" in weight.name: + continue + yield weight, layer + class SparsityRewrite(QuantizeAwareTrainingRewrite): - """Rewrite class for sparsity rewrite e.g. 
fully-connected-sparsity.""" + """Base rewrite class for sparsity rewrites.""" pruning_callback = tfmot.sparsity.keras.UpdatePruningStep @@ -147,9 +163,54 @@ class SparsityRewrite(QuantizeAwareTrainingRewrite): model, tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(), ) - return model + def check_optimization(self, model: keras.Model) -> bool: + """Not needed here.""" + return True + + +class UnstructuredSparsityRewrite(SparsityRewrite): + """ + Rewrite class for unstructured sparsity rewrite. + + e.g. fully-connected-unstructured-sparsity. + """ + + def check_optimization( + self, model: keras.Model, final_sparsity: float = 0.5, **_: Any + ) -> bool: + """Check if unstructured sparsity has produced the correct result.""" + found_sparsity_list = [] + num_dec_places = str(final_sparsity)[::-1].find(".") + for weight, _ in self.check_optimization_generator(model=model): + weight_np = weight.numpy() + found_sparsity_list.append( + round(np.count_nonzero(weight_np) / weight_np.size, num_dec_places) + ) + if len(found_sparsity_list) == 0: + logger.warning( + "\nWARNING: Could not find any layers " + "in rewrite that could be sparsely pruned" + ) + return False + found_sparsity = fmean(found_sparsity_list) + if found_sparsity != final_sparsity: + logger.warning( + "\nWARNING: Found total sparsity of " + "rewrite model: %.2f " + "expected total sparsity to be: " + "%.2f\n", + found_sparsity, + final_sparsity, + ) + return False + return True + + +class StructuredSparsityRewrite(SparsityRewrite): + """Rewrite class for structured sparsity rewrite e.g. 
fully-connected-sparsity.""" + def check_optimization( self, model: keras.Model, @@ -158,21 +219,17 @@ class SparsityRewrite(QuantizeAwareTrainingRewrite): **_: Any, ) -> bool: """Check if sparsity has produced the correct result.""" - for layer in model.layers: - for weight in layer.weights: - if "kernel" in weight.name: - if "kernel_min" in weight.name or "kernel_max" in weight.name: - continue - if not is_pruned_m_by_n(weight, m_by_n=(sparsity_m, sparsity_n)): - logger.warning( - "\nWARNING: Could not find (%d, %d) sparsity, " - "in layer %s for weight %s \n", - sparsity_m, - sparsity_n, - layer.name, - weight.name, - ) - return False + for weight, layer in self.check_optimization_generator(model=model): + if not is_pruned_m_by_n(weight, m_by_n=(sparsity_m, sparsity_n)): + logger.warning( + "\nWARNING: Could not find (%d, %d) sparsity, " + "in layer %s for weight %s \n", + sparsity_m, + sparsity_n, + layer.name, + weight.name, + ) + return False return True
weight.name, + ) + return False return True def training_callbacks(self) -> list: @@ -261,10 +314,17 @@ class RewritingOptimizer(Optimizer): registry = RewriteRegistry( [ GenericRewrite("fully-connected", fc_rewrite), - SparsityRewrite("fully-connected-sparsity", fc_sparsity_rewrite), + StructuredSparsityRewrite("fully-connected-sparsity", fc_sparsity_rewrite), ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite), ClusteringRewrite("conv2d-clustering", conv2d_clustering_rewrite), - SparsityRewrite("conv2d-sparsity", conv2d_sparsity_rewrite), + StructuredSparsityRewrite("conv2d-sparsity", conv2d_sparsity_rewrite), + UnstructuredSparsityRewrite( + "conv2d-unstructured-sparsity", conv2d_sparsity_unstructured_rewrite + ), + UnstructuredSparsityRewrite( + "fully-connected-unstructured-sparsity", + fc_sparsity_unstructured_rewrite, + ), ] ) diff --git a/src/mlia/nn/rewrite/library/sparsity.py b/src/mlia/nn/rewrite/library/sparsity.py index 0937b13..1e53254 100644 --- a/src/mlia/nn/rewrite/library/sparsity.py +++ b/src/mlia/nn/rewrite/library/sparsity.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. 
# SPDX-License-Identifier: Apache-2.0 """Rewrite functions used to return layers ready for sparse pruning.""" +from __future__ import annotations + from typing import Any import tensorflow_model_optimization as tfmot @@ -10,6 +12,78 @@ from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters from mlia.nn.rewrite.library.helper_functions import get_activation_function +def fc_sparsity_unstructured_rewrite( + input_shape: Any, + output_shape: Any, + initial_sparsity: float = 0.5, + final_sparsity: float = 0.5, + begin_step: int = 0, + end_step: int = 48000, +) -> keras.Model: + """Fully connected TensorFlow Lite model ready for unstructured sparse pruning.""" + model = tfmot.sparsity.keras.prune_low_magnitude( + to_prune=keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Reshape([-1]), + keras.layers.Dense(output_shape), + ] + ), + pruning_schedule=tfmot.sparsity.keras.PolynomialDecay( + initial_sparsity=initial_sparsity, + final_sparsity=final_sparsity, + begin_step=begin_step, + end_step=end_step, + ), + ) + + return model + + +def conv2d_sparsity_unstructured_rewrite( # pylint: disable=dangerous-default-value + input_shape: Any, + output_shape: Any, + initial_sparsity: float = 0.5, + final_sparsity: float = 0.5, + begin_step: int = 0, + end_step: int = 48000, + activation: str = "relu", + kernel_size: list[int] = [3, 3], +) -> keras.Model: + """Conv2d TensorFlow Lite model ready for unstructured sparse pruning.""" + conv2d_parameters = compute_conv2d_parameters( + input_shape=input_shape, + output_shape=output_shape, + kernel_size_input=kernel_size, + ) + activation_function, activation_function_extra_args = get_activation_function( + activation + ) + activation_func_found = ( + [activation_function(**activation_function_extra_args)] + if activation_function + else [] + ) + model = tfmot.sparsity.keras.prune_low_magnitude( + to_prune=keras.Sequential( + [ + 
keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Conv2D(**conv2d_parameters), + keras.layers.BatchNormalization(), + *activation_func_found, + ] + ), + pruning_schedule=tfmot.sparsity.keras.PolynomialDecay( + initial_sparsity=initial_sparsity, + final_sparsity=final_sparsity, + begin_step=begin_step, + end_step=end_step, + ), + ) + + return model + + def fc_sparsity_rewrite( input_shape: Any, output_shape: Any, sparsity_m: int = 2, sparsity_n: int = 4 ) -> keras.Model: diff --git a/src/mlia/resources/optimization_profiles/optimization-conv2d-unstructured-pruning.toml b/src/mlia/resources/optimization_profiles/optimization-conv2d-unstructured-pruning.toml new file mode 100644 index 0000000..67740ca --- /dev/null +++ b/src/mlia/resources/optimization_profiles/optimization-conv2d-unstructured-pruning.toml @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 + +[rewrite.training_parameters] +batch_size = 32 +learning_rate = 1e-3 +show_progress = true +steps = 48000 +learning_rate_schedule = "cosine" +num_procs = 1 +num_threads = 0 +augmentations.gaussian_strength = 0.0 +augmentations.mixup_strength = 0.0 + +[rewrite.conv2d-unstructured-sparsity] +initial_sparsity = 0.25 +final_sparsity = 0.5 +end_step = 48000 +activation = "relu" +kernel_size = [3, 3] diff --git a/src/mlia/resources/optimization_profiles/optimization-fully-connected-unstructured-pruning.toml b/src/mlia/resources/optimization_profiles/optimization-fully-connected-unstructured-pruning.toml new file mode 100644 index 0000000..cd5f745 --- /dev/null +++ b/src/mlia/resources/optimization_profiles/optimization-fully-connected-unstructured-pruning.toml @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. 
+# SPDX-License-Identifier: Apache-2.0 + +[rewrite.training_parameters] +batch_size = 32 +learning_rate = 1e-3 +show_progress = true +steps = 48000 +learning_rate_schedule = "cosine" +num_procs = 1 +num_threads = 0 +augmentations.gaussian_strength = 0.0 +augmentations.mixup_strength = 0.0 + +[rewrite.fully-connected-unstructured-sparsity] +initial_sparsity = 0.25 +final_sparsity = 0.5 +end_step = 48000 diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py index 236511c..8a5b360 100644 --- a/src/mlia/target/config.py +++ b/src/mlia/target/config.py @@ -76,8 +76,10 @@ BUILTIN_SUPPORTED_OPTIMIZATION_NAMES = [ "optimization-custom-augmentation", "optimization-fully-connected-clustering", "optimization-fully-connected-pruning", + "optimization-fully-connected-unstructured-pruning", "optimization-conv2d-clustering", "optimization-conv2d-pruning", + "optimization-conv2d-unstructured-pruning", ] |