From 0efca3cadbad5517a59884576ddb90cfe7ac30f8 Mon Sep 17 00:00:00 2001
From: Diego Russo
Date: Mon, 30 May 2022 13:34:14 +0100
Subject: Add MLIA codebase

Add MLIA codebase including sources and tests.

Change-Id: Id41707559bd721edd114793618d12ccd188d8dbd
---
 src/mlia/nn/tensorflow/optimizations/pruning.py | 168 ++++++++++++++++++++++++
 1 file changed, 168 insertions(+)
 create mode 100644 src/mlia/nn/tensorflow/optimizations/pruning.py

(limited to 'src/mlia/nn/tensorflow/optimizations/pruning.py')

diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py
new file mode 100644
index 0000000..f629ba1
--- /dev/null
+++ b/src/mlia/nn/tensorflow/optimizations/pruning.py
@@ -0,0 +1,168 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""
+Contains class Pruner to prune a model to a specified sparsity.
+
+In order to do this, we need to have a base model and corresponding training data.
+We also have to specify a subset of layers we want to prune. For more details,
+please refer to the documentation for TensorFlow Model Optimization Toolkit.
+"""
+from dataclasses import dataclass
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import tensorflow as tf
+import tensorflow_model_optimization as tfmot
+from tensorflow_model_optimization.python.core.sparsity.keras import (  # pylint: disable=no-name-in-module
+    pruning_wrapper,
+)
+
+from mlia.nn.tensorflow.optimizations.common import Optimizer
+from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration
+
+
+@dataclass
+class PruningConfiguration(OptimizerConfiguration):
+    """Pruning configuration."""
+
+    optimization_target: float
+    layers_to_optimize: Optional[List[str]] = None
+    x_train: Optional[np.array] = None
+    y_train: Optional[np.array] = None
+    batch_size: int = 1
+    num_epochs: int = 1
+
+    def __str__(self) -> str:
+        """Return string representation of the configuration."""
+        return f"pruning: {self.optimization_target}"
+
+    def has_training_data(self) -> bool:
+        """Return True if training data provided."""
+        return self.x_train is not None and self.y_train is not None
+
+
+class Pruner(Optimizer):
+    """
+    Pruner class. Used to prune a model to a specified sparsity.
+
+    Sample usage:
+        pruner = Pruner(
+            base_model,
+            optimizer_configuration)
+
+        pruner.apply_optimization()
+        pruned_model = pruner.get_model()
+    """
+
+    def __init__(
+        self, model: tf.keras.Model, optimizer_configuration: PruningConfiguration
+    ):
+        """Init Pruner instance."""
+        self.model = model
+        self.optimizer_configuration = optimizer_configuration
+
+        if not optimizer_configuration.has_training_data():
+            mock_x_train, mock_y_train = self._mock_train_data()
+
+            self.optimizer_configuration.x_train = mock_x_train
+            self.optimizer_configuration.y_train = mock_y_train
+
+    def optimization_config(self) -> str:
+        """Return string representation of the optimization config."""
+        return str(self.optimizer_configuration)
+
+    def _mock_train_data(self) -> Tuple[np.array, np.array]:
+        # get rid of the batch_size dimension in input and output shape
+        input_shape = tuple(x for x in self.model.input_shape if x is not None)
+        output_shape = tuple(x for x in self.model.output_shape if x is not None)
+
+        return (
+            np.random.rand(*input_shape),
+            np.random.randint(0, output_shape[-1], (output_shape[:-1])),
+        )
+
+    def _setup_pruning_params(self) -> dict:
+        return {
+            "pruning_schedule": tfmot.sparsity.keras.PolynomialDecay(
+                initial_sparsity=0,
+                final_sparsity=self.optimizer_configuration.optimization_target,
+                begin_step=0,
+                end_step=self.optimizer_configuration.num_epochs,
+                frequency=1,
+            ),
+        }
+
+    def _apply_pruning_to_layer(
+        self, layer: tf.keras.layers.Layer
+    ) -> tf.keras.layers.Layer:
+        layers_to_optimize = self.optimizer_configuration.layers_to_optimize
+        assert layers_to_optimize, "List of the layers to optimize is empty"
+
+        if layer.name not in layers_to_optimize:
+            return layer
+
+        pruning_params = self._setup_pruning_params()
+        return tfmot.sparsity.keras.prune_low_magnitude(layer, **pruning_params)
+
+    def _init_for_pruning(self) -> None:
+        # Use `tf.keras.models.clone_model` to apply `apply_pruning_to_layer`
+        # to the layers of the model
+        if not self.optimizer_configuration.layers_to_optimize:
+            pruning_params = self._setup_pruning_params()
+            prunable_model = tfmot.sparsity.keras.prune_low_magnitude(
+                self.model, **pruning_params
+            )
+        else:
+            prunable_model = tf.keras.models.clone_model(
+                self.model, clone_function=self._apply_pruning_to_layer
+            )
+
+        self.model = prunable_model
+
+    def _train_pruning(self) -> None:
+        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
+        self.model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
+
+        # Model callbacks
+        callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]
+
+        # Fitting data
+        self.model.fit(
+            self.optimizer_configuration.x_train,
+            self.optimizer_configuration.y_train,
+            batch_size=self.optimizer_configuration.batch_size,
+            epochs=self.optimizer_configuration.num_epochs,
+            callbacks=callbacks,
+            verbose=0,
+        )
+
+    def _assert_sparsity_reached(self) -> None:
+        for layer in self.model.layers:
+            if not isinstance(layer, pruning_wrapper.PruneLowMagnitude):
+                continue
+
+            for weight in layer.layer.get_prunable_weights():
+                nonzero_weights = np.count_nonzero(tf.keras.backend.get_value(weight))
+                all_weights = tf.keras.backend.get_value(weight).size
+
+                np.testing.assert_approx_equal(
+                    self.optimizer_configuration.optimization_target,
+                    1 - nonzero_weights / all_weights,
+                    significant=2,
+                )
+
+    def _strip_pruning(self) -> None:
+        self.model = tfmot.sparsity.keras.strip_pruning(self.model)
+
+    def apply_optimization(self) -> None:
+        """Apply all steps of pruning sequentially."""
+        self._init_for_pruning()
+        self._train_pruning()
+        self._assert_sparsity_reached()
+        self._strip_pruning()
+
+    def get_model(self) -> tf.keras.Model:
+        """Get model."""
+        return self.model
--
cgit v1.2.1
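
For reference, a minimal sketch of how the Pruner introduced above might be driven from user code. The toy model, the layer name "dense_1", the random stand-in data and the 0.5 sparsity target are assumptions made for illustration only; just the Pruner and PruningConfiguration APIs come from the file added by this commit.

    import numpy as np
    import tensorflow as tf

    from mlia.nn.tensorflow.optimizations.pruning import Pruner
    from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration

    # Toy Keras model; architecture and layer names are illustrative only.
    base_model = tf.keras.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(32, activation="relu", name="dense_1"),
            tf.keras.layers.Dense(10, activation="softmax", name="output"),
        ]
    )

    config = PruningConfiguration(
        optimization_target=0.5,  # final sparsity to reach (assumed value)
        layers_to_optimize=["dense_1"],  # None prunes the whole model instead
        x_train=np.random.rand(128, 28, 28),  # stand-in for real training data
        y_train=np.random.randint(0, 10, 128),
        batch_size=8,
        num_epochs=2,
    )

    pruner = Pruner(base_model, config)
    pruner.apply_optimization()
    pruned_model = pruner.get_model()

If x_train/y_train are left unset, the constructor falls back to the mock data produced by _mock_train_data(), so a configuration carrying only optimization_target is also valid.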