diff options
Diffstat (limited to 'src/mlia/nn/tensorflow/optimizations')
-rw-r--r-- | src/mlia/nn/tensorflow/optimizations/__init__.py | 3 | ||||
-rw-r--r-- | src/mlia/nn/tensorflow/optimizations/clustering.py | 109 | ||||
-rw-r--r-- | src/mlia/nn/tensorflow/optimizations/common.py | 29 | ||||
-rw-r--r-- | src/mlia/nn/tensorflow/optimizations/pruning.py | 168 | ||||
-rw-r--r-- | src/mlia/nn/tensorflow/optimizations/select.py | 179 |
5 files changed, 488 insertions, 0 deletions
diff --git a/src/mlia/nn/tensorflow/optimizations/__init__.py b/src/mlia/nn/tensorflow/optimizations/__init__.py new file mode 100644 index 0000000..201c130 --- /dev/null +++ b/src/mlia/nn/tensorflow/optimizations/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Optimizations module.""" diff --git a/src/mlia/nn/tensorflow/optimizations/clustering.py b/src/mlia/nn/tensorflow/optimizations/clustering.py new file mode 100644 index 0000000..16d9e4b --- /dev/null +++ b/src/mlia/nn/tensorflow/optimizations/clustering.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +""" +Contains class Clusterer that clusters unique weights per layer to a specified number. + +In order to do this, we need to have a base model and corresponding training data. +We also have to specify a subset of layers we want to cluster. For more details, +please refer to the documentation for TensorFlow Model Optimization Toolkit. +""" +from dataclasses import dataclass +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +import tensorflow as tf +import tensorflow_model_optimization as tfmot +from tensorflow_model_optimization.python.core.clustering.keras.experimental import ( # pylint: disable=no-name-in-module + cluster as experimental_cluster, +) + +from mlia.nn.tensorflow.optimizations.common import Optimizer +from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration + + +@dataclass +class ClusteringConfiguration(OptimizerConfiguration): + """Clustering configuration.""" + + optimization_target: int + layers_to_optimize: Optional[List[str]] = None + + def __str__(self) -> str: + """Return string representation of the configuration.""" + return f"clustering: {self.optimization_target}" + + +class Clusterer(Optimizer): + """ + Clusterer class. 
+ + Used to cluster a model to a specified number of unique weights per layer. + + Sample usage: + clusterer = Clusterer( + base_model, + optimizer_configuration) + + clusterer.apply_optimization() + clustered_model = clusterer.get_model() + """ + + def __init__( + self, model: tf.keras.Model, optimizer_configuration: ClusteringConfiguration + ): + """Init Clusterer instance.""" + self.model = model + self.optimizer_configuration = optimizer_configuration + + def optimization_config(self) -> str: + """Return string representation of the optimization config.""" + return str(self.optimizer_configuration) + + def _setup_clustering_params(self) -> Dict[str, Any]: + CentroidInitialization = tfmot.clustering.keras.CentroidInitialization + return { + "number_of_clusters": self.optimizer_configuration.optimization_target, + "cluster_centroids_init": CentroidInitialization.LINEAR, + "preserve_sparsity": True, + } + + def _apply_clustering_to_layer( + self, layer: tf.keras.layers.Layer + ) -> tf.keras.layers.Layer: + layers_to_optimize = self.optimizer_configuration.layers_to_optimize + assert layers_to_optimize, "List of the layers to optimize is empty" + + if layer.name not in layers_to_optimize: + return layer + + clustering_params = self._setup_clustering_params() + return experimental_cluster.cluster_weights(layer, **clustering_params) + + def _init_for_clustering(self) -> None: + # Use `tf.keras.models.clone_model` to apply `_apply_clustering_to_layer` + # to the layers of the model + if not self.optimizer_configuration.layers_to_optimize: + clustering_params = self._setup_clustering_params() + clustered_model = experimental_cluster.cluster_weights( + self.model, **clustering_params + ) + else: + clustered_model = tf.keras.models.clone_model( + self.model, clone_function=self._apply_clustering_to_layer + ) + + self.model = clustered_model + + def _strip_clustering(self) -> None: + self.model = tfmot.clustering.keras.strip_clustering(self.model) + + def 
apply_optimization(self) -> None: + """Apply all steps of clustering at once.""" + self._init_for_clustering() + self._strip_clustering() + + def get_model(self) -> tf.keras.Model: + """Get model.""" + return self.model diff --git a/src/mlia/nn/tensorflow/optimizations/common.py b/src/mlia/nn/tensorflow/optimizations/common.py new file mode 100644 index 0000000..1dce0b2 --- /dev/null +++ b/src/mlia/nn/tensorflow/optimizations/common.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Common items for the optimizations module.""" +from abc import ABC +from abc import abstractmethod +from dataclasses import dataclass + +import tensorflow as tf + + +@dataclass +class OptimizerConfiguration: + """Abstract optimizer configuration.""" + + +class Optimizer(ABC): + """Abstract class for the optimizer.""" + + @abstractmethod + def get_model(self) -> tf.keras.Model: + """Abstract method to return the model instance from the optimizer.""" + + @abstractmethod + def apply_optimization(self) -> None: + """Abstract method to apply optimization to the model.""" + + @abstractmethod + def optimization_config(self) -> str: + """Return string representation of the optimization config.""" diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py new file mode 100644 index 0000000..f629ba1 --- /dev/null +++ b/src/mlia/nn/tensorflow/optimizations/pruning.py @@ -0,0 +1,168 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +""" +Contains class Pruner to prune a model to a specified sparsity. + +In order to do this, we need to have a base model and corresponding training data. +We also have to specify a subset of layers we want to prune. For more details, +please refer to the documentation for TensorFlow Model Optimization Toolkit. 
+""" +from dataclasses import dataclass +from typing import List +from typing import Optional +from typing import Tuple + +import numpy as np +import tensorflow as tf +import tensorflow_model_optimization as tfmot +from tensorflow_model_optimization.python.core.sparsity.keras import ( # pylint: disable=no-name-in-module + pruning_wrapper, +) + +from mlia.nn.tensorflow.optimizations.common import Optimizer +from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration + + +@dataclass +class PruningConfiguration(OptimizerConfiguration): + """Pruning configuration.""" + + optimization_target: float + layers_to_optimize: Optional[List[str]] = None + x_train: Optional[np.array] = None + y_train: Optional[np.array] = None + batch_size: int = 1 + num_epochs: int = 1 + + def __str__(self) -> str: + """Return string representation of the configuration.""" + return f"pruning: {self.optimization_target}" + + def has_training_data(self) -> bool: + """Return True if training data provided.""" + return self.x_train is not None and self.y_train is not None + + +class Pruner(Optimizer): + """ + Pruner class. Used to prune a model to a specified sparsity. 
+ + Sample usage: + pruner = Pruner( + base_model, + optimizer_configuration) + + pruner.apply_optimization() + pruned_model = pruner.get_model() + """ + + def __init__( + self, model: tf.keras.Model, optimizer_configuration: PruningConfiguration + ): + """Init Pruner instance.""" + self.model = model + self.optimizer_configuration = optimizer_configuration + + if not optimizer_configuration.has_training_data(): + mock_x_train, mock_y_train = self._mock_train_data() + + self.optimizer_configuration.x_train = mock_x_train + self.optimizer_configuration.y_train = mock_y_train + + def optimization_config(self) -> str: + """Return string representation of the optimization config.""" + return str(self.optimizer_configuration) + + def _mock_train_data(self) -> Tuple[np.array, np.array]: + # get rid of the batch_size dimension in input and output shape + input_shape = tuple(x for x in self.model.input_shape if x is not None) + output_shape = tuple(x for x in self.model.output_shape if x is not None) + + return ( + np.random.rand(*input_shape), + np.random.randint(0, output_shape[-1], (output_shape[:-1])), + ) + + def _setup_pruning_params(self) -> dict: + return { + "pruning_schedule": tfmot.sparsity.keras.PolynomialDecay( + initial_sparsity=0, + final_sparsity=self.optimizer_configuration.optimization_target, + begin_step=0, + end_step=self.optimizer_configuration.num_epochs, + frequency=1, + ), + } + + def _apply_pruning_to_layer( + self, layer: tf.keras.layers.Layer + ) -> tf.keras.layers.Layer: + layers_to_optimize = self.optimizer_configuration.layers_to_optimize + assert layers_to_optimize, "List of the layers to optimize is empty" + + if layer.name not in layers_to_optimize: + return layer + + pruning_params = self._setup_pruning_params() + return tfmot.sparsity.keras.prune_low_magnitude(layer, **pruning_params) + + def _init_for_pruning(self) -> None: + # Use `tf.keras.models.clone_model` to apply `_apply_pruning_to_layer` + # to the layers of the model + if not 
self.optimizer_configuration.layers_to_optimize: + pruning_params = self._setup_pruning_params() + prunable_model = tfmot.sparsity.keras.prune_low_magnitude( + self.model, **pruning_params + ) + else: + prunable_model = tf.keras.models.clone_model( + self.model, clone_function=self._apply_pruning_to_layer + ) + + self.model = prunable_model + + def _train_pruning(self) -> None: + loss_fn = tf.keras.losses.SparseCategoricalCrossentropy() + self.model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"]) + + # Model callbacks + callbacks = [tfmot.sparsity.keras.UpdatePruningStep()] + + # Fitting data + self.model.fit( + self.optimizer_configuration.x_train, + self.optimizer_configuration.y_train, + batch_size=self.optimizer_configuration.batch_size, + epochs=self.optimizer_configuration.num_epochs, + callbacks=callbacks, + verbose=0, + ) + + def _assert_sparsity_reached(self) -> None: + for layer in self.model.layers: + if not isinstance(layer, pruning_wrapper.PruneLowMagnitude): + continue + + for weight in layer.layer.get_prunable_weights(): + nonzero_weights = np.count_nonzero(tf.keras.backend.get_value(weight)) + all_weights = tf.keras.backend.get_value(weight).size + + np.testing.assert_approx_equal( + self.optimizer_configuration.optimization_target, + 1 - nonzero_weights / all_weights, + significant=2, + ) + + def _strip_pruning(self) -> None: + self.model = tfmot.sparsity.keras.strip_pruning(self.model) + + def apply_optimization(self) -> None: + """Apply all steps of pruning sequentially.""" + self._init_for_pruning() + self._train_pruning() + self._assert_sparsity_reached() + self._strip_pruning() + + def get_model(self) -> tf.keras.Model: + """Get model.""" + return self.model diff --git a/src/mlia/nn/tensorflow/optimizations/select.py b/src/mlia/nn/tensorflow/optimizations/select.py new file mode 100644 index 0000000..1b0c755 --- /dev/null +++ b/src/mlia/nn/tensorflow/optimizations/select.py @@ -0,0 +1,179 @@ +# SPDX-FileCopyrightText: Copyright 
2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for optimization selection.""" +import math +from typing import List +from typing import NamedTuple +from typing import Optional +from typing import Tuple +from typing import Union + +import tensorflow as tf + +from mlia.core.errors import ConfigurationError +from mlia.nn.tensorflow.config import KerasModel +from mlia.nn.tensorflow.optimizations.clustering import Clusterer +from mlia.nn.tensorflow.optimizations.clustering import ClusteringConfiguration +from mlia.nn.tensorflow.optimizations.common import Optimizer +from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration +from mlia.nn.tensorflow.optimizations.pruning import Pruner +from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration +from mlia.utils.types import is_list_of + + +class OptimizationSettings(NamedTuple): + """Optimization settings.""" + + optimization_type: str + optimization_target: Union[int, float] + layers_to_optimize: Optional[List[str]] + + @staticmethod + def create_from( + optimizer_params: List[Tuple[str, float]], + layers_to_optimize: Optional[List[str]] = None, + ) -> List["OptimizationSettings"]: + """Create optimization settings from the provided parameters.""" + return [ + OptimizationSettings( + optimization_type=opt_type, + optimization_target=opt_target, + layers_to_optimize=layers_to_optimize, + ) + for opt_type, opt_target in optimizer_params + ] + + def __str__(self) -> str: + """Return string representation.""" + return f"{self.optimization_type}: {self.optimization_target}" + + def next_target(self) -> "OptimizationSettings": + """Return next optimization target.""" + if self.optimization_type == "pruning": + next_target = round(min(self.optimization_target + 0.1, 0.9), 2) + return OptimizationSettings( + self.optimization_type, next_target, self.layers_to_optimize + ) + + if self.optimization_type == "clustering": + # return next lowest power of two 
for clustering + next_target = math.log(self.optimization_target, 2) + if next_target.is_integer(): + next_target -= 1 + + next_target = max(int(2 ** int(next_target)), 4) + return OptimizationSettings( + self.optimization_type, next_target, self.layers_to_optimize + ) + + raise Exception(f"Unknown optimization type {self.optimization_type}") + + +class MultiStageOptimizer(Optimizer): + """Optimizer with multiple stages.""" + + def __init__( + self, + model: tf.keras.Model, + optimizations: List[OptimizerConfiguration], + ) -> None: + """Init MultiStageOptimizer instance.""" + self.model = model + self.optimizations = optimizations + + def optimization_config(self) -> str: + """Return string representation of the optimization config.""" + return " - ".join(str(opt) for opt in self.optimizations) + + def get_model(self) -> tf.keras.Model: + """Return optimized model.""" + return self.model + + def apply_optimization(self) -> None: + """Apply optimization to the model.""" + for config in self.optimizations: + optimizer = get_optimizer(self.model, config) + optimizer.apply_optimization() + self.model = optimizer.get_model() + + +def get_optimizer( + model: Union[tf.keras.Model, KerasModel], + config: Union[ + OptimizerConfiguration, OptimizationSettings, List[OptimizationSettings] + ], +) -> Optimizer: + """Get optimizer for provided configuration.""" + if isinstance(model, KerasModel): + model = model.get_keras_model() + + if isinstance(config, PruningConfiguration): + return Pruner(model, config) + + if isinstance(config, ClusteringConfiguration): + return Clusterer(model, config) + + if isinstance(config, OptimizationSettings) or is_list_of( + config, OptimizationSettings + ): + return _get_optimizer(model, config) # type: ignore + + raise ConfigurationError(f"Unknown optimization configuration {config}") + + +def _get_optimizer( + model: tf.keras.Model, + optimization_settings: Union[OptimizationSettings, List[OptimizationSettings]], +) -> Optimizer: + if 
isinstance(optimization_settings, OptimizationSettings): + optimization_settings = [optimization_settings] + + optimizer_configs = [] + for opt_type, opt_target, layers_to_optimize in optimization_settings: + _check_optimizer_params(opt_type, opt_target) + + opt_config = _get_optimizer_configuration( + opt_type, opt_target, layers_to_optimize + ) + optimizer_configs.append(opt_config) + + if len(optimizer_configs) == 1: + return get_optimizer(model, optimizer_configs[0]) + + return MultiStageOptimizer(model, optimizer_configs) + + +def _get_optimizer_configuration( + optimization_type: str, + optimization_target: Union[int, float], + layers_to_optimize: Optional[List[str]] = None, +) -> OptimizerConfiguration: + """Get optimizer configuration for provided parameters.""" + _check_optimizer_params(optimization_type, optimization_target) + + opt_type = optimization_type.lower() + if opt_type == "pruning": + return PruningConfiguration(optimization_target, layers_to_optimize) + + if opt_type == "clustering": + # make sure an integer is given as clustering target + if optimization_target == int(optimization_target): + return ClusteringConfiguration(int(optimization_target), layers_to_optimize) + + raise ConfigurationError( + "Optimization target should be a positive integer. " + f"Optimization target provided: {optimization_target}" + ) + + raise ConfigurationError(f"Unsupported optimization type: {optimization_type}") + + +def _check_optimizer_params( + optimization_type: str, optimization_target: Union[int, float] +) -> None: + """Check optimizer params.""" + if not optimization_target: + raise ConfigurationError("Optimization target is not provided") + + if not optimization_type: + raise ConfigurationError("Optimization type is not provided") |