Diffstat (limited to 'src/mlia/nn/tensorflow/optimizations/pruning.py')
 src/mlia/nn/tensorflow/optimizations/pruning.py | 168 ++++++++++++++++++++++++
 1 file changed, 168 insertions(+), 0 deletions(-)
diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py
new file mode 100644
index 0000000..f629ba1
--- /dev/null
+++ b/src/mlia/nn/tensorflow/optimizations/pruning.py
@@ -0,0 +1,168 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""
+Contains class Pruner to prune a model to a specified sparsity.
+
+In order to do this, we need to have a base model and corresponding training data.
+We also have to specify a subset of layers we want to prune. For more details,
+please refer to the documentation for TensorFlow Model Optimization Toolkit.
+"""
+from dataclasses import dataclass
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import tensorflow as tf
+import tensorflow_model_optimization as tfmot
+from tensorflow_model_optimization.python.core.sparsity.keras import ( # pylint: disable=no-name-in-module
+ pruning_wrapper,
+)
+
+from mlia.nn.tensorflow.optimizations.common import Optimizer
+from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration
+
+
+@dataclass
+class PruningConfiguration(OptimizerConfiguration):
+ """Pruning configuration."""
+
+ optimization_target: float
+ layers_to_optimize: Optional[List[str]] = None
+    x_train: Optional[np.ndarray] = None
+    y_train: Optional[np.ndarray] = None
+ batch_size: int = 1
+ num_epochs: int = 1
+
+ def __str__(self) -> str:
+ """Return string representation of the configuration."""
+ return f"pruning: {self.optimization_target}"
+
+ def has_training_data(self) -> bool:
+ """Return True if training data provided."""
+ return self.x_train is not None and self.y_train is not None
+
+
+class Pruner(Optimizer):
+ """
+ Pruner class. Used to prune a model to a specified sparsity.
+
+ Sample usage:
+ pruner = Pruner(
+ base_model,
+ optimizer_configuration)
+
+    pruner.apply_optimization()
+ pruned_model = pruner.get_model()
+ """
+
+ def __init__(
+ self, model: tf.keras.Model, optimizer_configuration: PruningConfiguration
+ ):
+ """Init Pruner instance."""
+ self.model = model
+ self.optimizer_configuration = optimizer_configuration
+
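+        # Fall back to randomly generated single-sample data when no
+        # training data is provided.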
+ if not optimizer_configuration.has_training_data():
+ mock_x_train, mock_y_train = self._mock_train_data()
+
+ self.optimizer_configuration.x_train = mock_x_train
+ self.optimizer_configuration.y_train = mock_y_train
+
+ def optimization_config(self) -> str:
+ """Return string representation of the optimization config."""
+ return str(self.optimizer_configuration)
+
+    def _mock_train_data(self) -> Tuple[np.ndarray, np.ndarray]:
+        # Strip the batch (None) dimension from the input and output shapes.
+ input_shape = tuple(x for x in self.model.input_shape if x is not None)
+ output_shape = tuple(x for x in self.model.output_shape if x is not None)
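+        # For example, if model.input_shape == (None, 28, 28) and
+        # model.output_shape == (None, 10), this returns one random sample
+        # of shape (28, 28) and a single integer label drawn from [0, 10).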
+
+ return (
+ np.random.rand(*input_shape),
+            np.random.randint(0, output_shape[-1], output_shape[:-1]),
+ )
+
+ def _setup_pruning_params(self) -> dict:
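+        # PolynomialDecay raises sparsity from initial_sparsity at
+        # begin_step to final_sparsity at end_step, updating the pruning
+        # mask every `frequency` optimizer steps. end_step is set to
+        # num_epochs, which matches one step per epoch when the default
+        # single-sample mock data and batch_size=1 are used.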
+ return {
+ "pruning_schedule": tfmot.sparsity.keras.PolynomialDecay(
+ initial_sparsity=0,
+ final_sparsity=self.optimizer_configuration.optimization_target,
+ begin_step=0,
+ end_step=self.optimizer_configuration.num_epochs,
+ frequency=1,
+ ),
+ }
+
+ def _apply_pruning_to_layer(
+ self, layer: tf.keras.layers.Layer
+ ) -> tf.keras.layers.Layer:
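+        # Called via `clone_model`: wrap only the selected layers with
+        # `prune_low_magnitude` and return all other layers unchanged.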
+ layers_to_optimize = self.optimizer_configuration.layers_to_optimize
+        assert layers_to_optimize, "List of layers to optimize is empty"
+
+ if layer.name not in layers_to_optimize:
+ return layer
+
+ pruning_params = self._setup_pruning_params()
+ return tfmot.sparsity.keras.prune_low_magnitude(layer, **pruning_params)
+
+ def _init_for_pruning(self) -> None:
+        # If no specific layers are requested, wrap the whole model with
+        # `prune_low_magnitude`; otherwise use `tf.keras.models.clone_model`
+        # to wrap only the layers selected by `_apply_pruning_to_layer`.
+ if not self.optimizer_configuration.layers_to_optimize:
+ pruning_params = self._setup_pruning_params()
+ prunable_model = tfmot.sparsity.keras.prune_low_magnitude(
+ self.model, **pruning_params
+ )
+ else:
+ prunable_model = tf.keras.models.clone_model(
+ self.model, clone_function=self._apply_pruning_to_layer
+ )
+
+ self.model = prunable_model
+
+ def _train_pruning(self) -> None:
+ loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
+ self.model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
+
+        # UpdatePruningStep is required when training a pruned model: it
+        # advances the pruning schedule and applies the pruning masks at
+        # each optimizer step.
+ callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]
+
+        # Fine-tune the model so the schedule can reach the target sparsity.
+ self.model.fit(
+ self.optimizer_configuration.x_train,
+ self.optimizer_configuration.y_train,
+ batch_size=self.optimizer_configuration.batch_size,
+ epochs=self.optimizer_configuration.num_epochs,
+ callbacks=callbacks,
+ verbose=0,
+ )
+
+ def _assert_sparsity_reached(self) -> None:
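+        # Verify that the achieved sparsity of every pruned layer matches
+        # the requested target to two significant figures.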
+ for layer in self.model.layers:
+ if not isinstance(layer, pruning_wrapper.PruneLowMagnitude):
+ continue
+
+ for weight in layer.layer.get_prunable_weights():
+ nonzero_weights = np.count_nonzero(tf.keras.backend.get_value(weight))
+ all_weights = tf.keras.backend.get_value(weight).size
+
+ np.testing.assert_approx_equal(
+ self.optimizer_configuration.optimization_target,
+ 1 - nonzero_weights / all_weights,
+ significant=2,
+ )
+
+ def _strip_pruning(self) -> None:
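+        # `strip_pruning` removes the PruneLowMagnitude wrappers, leaving a
+        # plain Keras model whose weights keep the induced sparsity.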
+ self.model = tfmot.sparsity.keras.strip_pruning(self.model)
+
+ def apply_optimization(self) -> None:
+ """Apply all steps of pruning sequentially."""
+ self._init_for_pruning()
+ self._train_pruning()
+ self._assert_sparsity_reached()
+ self._strip_pruning()
+
+ def get_model(self) -> tf.keras.Model:
+ """Get model."""
+ return self.model