From 8dec455c467b8019223a40e107378845e1419f5d Mon Sep 17 00:00:00 2001 From: Annie Tallund Date: Tue, 19 Sep 2023 16:34:14 +0200 Subject: MLIA-469 Support batch size > 1 for optimizations - Add a PruningPolicy to skip layers that are not supported by the Keras pruning API - Make dataset generation more generic to support use-cases beyond classification Signed-off-by: Annie Tallund Change-Id: I198dae2b53860f449f2fdbc71575babceed1ffcf --- src/mlia/nn/tensorflow/optimizations/pruning.py | 60 ++++++++++++++++++++----- src/mlia/nn/tensorflow/utils.py | 4 +- tests/test_nn_tensorflow_utils.py | 8 +--- 3 files changed, 51 insertions(+), 21 deletions(-) diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py index 41954b9..2d5ef0e 100644 --- a/src/mlia/nn/tensorflow/optimizations/pruning.py +++ b/src/mlia/nn/tensorflow/optimizations/pruning.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """ Contains class Pruner to prune a model to a specified sparsity. @@ -9,12 +9,17 @@ please refer to the documentation for TensorFlow Model Optimization Toolkit. """ from __future__ import annotations +import logging import typing from dataclasses import dataclass +from typing import Any import numpy as np import tensorflow as tf import tensorflow_model_optimization as tfmot +from tensorflow_model_optimization.python.core.sparsity.keras import ( # pylint: disable=no-name-in-module + prune_registry, +) from tensorflow_model_optimization.python.core.sparsity.keras import ( # pylint: disable=no-name-in-module pruning_wrapper, ) @@ -22,6 +27,8 @@ from tensorflow_model_optimization.python.core.sparsity.keras import ( # pylint from mlia.nn.tensorflow.optimizations.common import Optimizer from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration +logger = logging.getLogger(__name__) + @dataclass class PruningConfiguration(OptimizerConfiguration): @@ -43,6 +50,41 @@ class PruningConfiguration(OptimizerConfiguration): return self.x_train is not None and self.y_train is not None +@dataclass +class PrunableLayerPolicy(tfmot.sparsity.keras.PruningPolicy): + """A policy to skip unsupported layers. + + PrunableLayerPolicy makes sure that all layers subject for pruning + are compatible with the pruning API, and that the model supports pruning. + """ + + def allow_pruning(self, layer: tf.keras.layers.Layer) -> Any: + """Allow pruning only for layers that are prunable. + + Checks the PruneRegistry in TensorFlow Model Optimization Toolkit. + """ + layer_is_supported = prune_registry.PruneRegistry.supports(layer) + if not layer_is_supported: + logger.warning( + "Layer %s is not supported for pruning, will be skipped.", layer.name + ) + + return layer_is_supported + + def ensure_model_supports_pruning(self, model: tf.keras.Model) -> None: + """Ensure that the model contains only supported layers.""" + # Check whether the model is a Keras model. + if not isinstance(model, tf.keras.Model): + raise ValueError( + "Models that are not part of the \ + tf.keras.Model base class \ + are not supported currently." + ) + + if not model.built: + raise ValueError("Unbuilt models are not supported currently.") + + class Pruner(Optimizer): """ Pruner class. Used to prune a model to a specified sparsity. @@ -64,7 +106,7 @@ class Pruner(Optimizer): self.optimizer_configuration = optimizer_configuration if not optimizer_configuration.has_training_data(): - mock_x_train, mock_y_train = self._mock_train_data() + mock_x_train, mock_y_train = self._mock_train_data(1) self.optimizer_configuration.x_train = mock_x_train self.optimizer_configuration.y_train = mock_y_train @@ -73,14 +115,10 @@ class Pruner(Optimizer): """Return string representation of the optimization config.""" return str(self.optimizer_configuration) - def _mock_train_data(self) -> tuple[np.ndarray, np.ndarray]: - # get rid of the batch_size dimension in input and output shape - input_shape = tuple(x for x in self.model.input_shape if x is not None) - output_shape = tuple(x for x in self.model.output_shape if x is not None) - + def _mock_train_data(self, batch_size: int) -> tuple[np.ndarray, np.ndarray]: return ( - np.random.rand(*input_shape), - np.random.randint(0, output_shape[-1], (output_shape[:-1])), + np.random.rand(batch_size, *self.model.input_shape[1:]), + np.random.rand(batch_size, *self.model.output_shape[1:]), ) def _setup_pruning_params(self) -> dict: @@ -112,7 +150,7 @@ class Pruner(Optimizer): if not self.optimizer_configuration.layers_to_optimize: pruning_params = self._setup_pruning_params() prunable_model = tfmot.sparsity.keras.prune_low_magnitude( - self.model, **pruning_params + self.model, pruning_policy=PrunableLayerPolicy(), **pruning_params ) else: prunable_model = tf.keras.models.clone_model( @@ -122,7 +160,7 @@ class Pruner(Optimizer): self.model = prunable_model def _train_pruning(self) -> None: - loss_fn = tf.keras.losses.SparseCategoricalCrossentropy() + loss_fn = tf.keras.losses.MeanAbsolutePercentageError() self.model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"]) # Model callbacks diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py index d688a63..77ac529 100644 --- a/src/mlia/nn/tensorflow/utils.py +++ b/src/mlia/nn/tensorflow/utils.py @@ -21,12 +21,10 @@ def representative_dataset( input_shape: Any, sample_count: int = 100, input_dtype: type = np.float32 ) -> Callable: """Sample dataset used for quantization.""" - if input_shape[0] != 1: - raise ValueError("Only the input batch_size=1 is supported!") def dataset() -> Iterable: for _ in range(sample_count): - data = np.random.rand(*input_shape) + data = np.random.rand(1, *input_shape[1:]) yield [data.astype(input_dtype)] return dataset diff --git a/tests/test_nn_tensorflow_utils.py b/tests/test_nn_tensorflow_utils.py index 5131171..14b06c4 100644 --- a/tests/test_nn_tensorflow_utils.py +++ b/tests/test_nn_tensorflow_utils.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Test for module utils/test_utils.""" from pathlib import Path @@ -31,12 +31,6 @@ def test_generate_representative_dataset() -> None: assert isinstance(ndarray, np.ndarray) -def test_generate_representative_dataset_wrong_shape() -> None: - """Test that only shape with batch size=1 is supported.""" - with pytest.raises(Exception, match="Only the input batch_size=1 is supported!"): - representative_dataset([2, 3, 3], 5) - - def test_convert_saved_model_to_tflite(test_tf_model: Path) -> None: """Test converting SavedModel to TensorFlow Lite.""" result = convert_to_tflite(test_tf_model.as_posix()) -- cgit v1.2.1