author    Annie Tallund <annie.tallund@arm.com>  2023-09-19 16:34:14 +0200
committer Annie Tallund <annie.tallund@arm.com>  2023-09-26 14:24:39 +0200
commit    8dec455c467b8019223a40e107378845e1419f5d (patch)
tree      823051d364c78e3ce7f70023eb94021d8b051274
parent    ba251631768f25b840e93ece6a4af3db119e6dd1 (diff)
MLIA-469 Support batch size > 1 for optimizations
- Add a PruningPolicy to skip layers that are not supported by the Keras
  pruning API
- Make dataset generation more generic to support use-cases beyond
  classification

Signed-off-by: Annie Tallund <annie.tallund@arm.com>
Change-Id: I198dae2b53860f449f2fdbc71575babceed1ffcf
-rw-r--r--  src/mlia/nn/tensorflow/optimizations/pruning.py  60
-rw-r--r--  src/mlia/nn/tensorflow/utils.py                    4
-rw-r--r--  tests/test_nn_tensorflow_utils.py                  8
3 files changed, 51 insertions(+), 21 deletions(-)
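
The core idea of the change is that mock training data is generated directly from the model's input and output shapes, with the leading batch dimension replaced by whatever batch size is requested instead of being forced to 1. A minimal standalone sketch of that approach (make_mock_data is an illustrative helper, not part of this commit):

from __future__ import annotations

import numpy as np
import tensorflow as tf

def make_mock_data(model: tf.keras.Model, batch_size: int) -> tuple[np.ndarray, np.ndarray]:
    # Drop the leading (batch) dimension reported by Keras and substitute the
    # requested batch size, mirroring the updated _mock_train_data below.
    x = np.random.rand(batch_size, *model.input_shape[1:])
    y = np.random.rand(batch_size, *model.output_shape[1:])
    return x, y

# e.g. a tiny built model, given mock data with batch size 8
model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(10,))])
x_train, y_train = make_mock_data(model, batch_size=8)
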
diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py
index 41954b9..2d5ef0e 100644
--- a/src/mlia/nn/tensorflow/optimizations/pruning.py
+++ b/src/mlia/nn/tensorflow/optimizations/pruning.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""
Contains class Pruner to prune a model to a specified sparsity.
@@ -9,19 +9,26 @@ please refer to the documentation for TensorFlow Model Optimization Toolkit.
"""
from __future__ import annotations
+import logging
import typing
from dataclasses import dataclass
+from typing import Any
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.sparsity.keras import ( # pylint: disable=no-name-in-module
+    prune_registry,
+)
+from tensorflow_model_optimization.python.core.sparsity.keras import (  # pylint: disable=no-name-in-module
pruning_wrapper,
)
from mlia.nn.tensorflow.optimizations.common import Optimizer
from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration
+logger = logging.getLogger(__name__)
+
@dataclass
class PruningConfiguration(OptimizerConfiguration):
@@ -43,6 +50,41 @@ class PruningConfiguration(OptimizerConfiguration):
        return self.x_train is not None and self.y_train is not None


+@dataclass
+class PrunableLayerPolicy(tfmot.sparsity.keras.PruningPolicy):
+    """A policy to skip unsupported layers.
+
+    PrunableLayerPolicy makes sure that all layers subject for pruning
+    are compatible with the pruning API, and that the model supports pruning.
+    """
+
+    def allow_pruning(self, layer: tf.keras.layers.Layer) -> Any:
+        """Allow pruning only for layers that are prunable.
+
+        Checks the PruneRegistry in TensorFlow Model Optimization Toolkit.
+        """
+        layer_is_supported = prune_registry.PruneRegistry.supports(layer)
+        if not layer_is_supported:
+            logger.warning(
+                "Layer %s is not supported for pruning, will be skipped.", layer.name
+            )
+
+        return layer_is_supported
+
+    def ensure_model_supports_pruning(self, model: tf.keras.Model) -> None:
+        """Ensure that the model contains only supported layers."""
+        # Check whether the model is a Keras model.
+        if not isinstance(model, tf.keras.Model):
+            raise ValueError(
+                "Models that are not part of the \
+                tf.keras.Model base class \
+                are not supported currently."
+            )
+
+        if not model.built:
+            raise ValueError("Unbuilt models are not supported currently.")
+
+
class Pruner(Optimizer):
    """
    Pruner class. Used to prune a model to a specified sparsity.
@@ -64,7 +106,7 @@ class Pruner(Optimizer):
        self.optimizer_configuration = optimizer_configuration

        if not optimizer_configuration.has_training_data():
-            mock_x_train, mock_y_train = self._mock_train_data()
+            mock_x_train, mock_y_train = self._mock_train_data(1)
            self.optimizer_configuration.x_train = mock_x_train
            self.optimizer_configuration.y_train = mock_y_train
@@ -73,14 +115,10 @@ class Pruner(Optimizer):
"""Return string representation of the optimization config."""
return str(self.optimizer_configuration)
- def _mock_train_data(self) -> tuple[np.ndarray, np.ndarray]:
- # get rid of the batch_size dimension in input and output shape
- input_shape = tuple(x for x in self.model.input_shape if x is not None)
- output_shape = tuple(x for x in self.model.output_shape if x is not None)
-
+ def _mock_train_data(self, batch_size: int) -> tuple[np.ndarray, np.ndarray]:
return (
- np.random.rand(*input_shape),
- np.random.randint(0, output_shape[-1], (output_shape[:-1])),
+ np.random.rand(batch_size, *self.model.input_shape[1:]),
+ np.random.rand(batch_size, *self.model.output_shape[1:]),
)
def _setup_pruning_params(self) -> dict:
@@ -112,7 +150,7 @@ class Pruner(Optimizer):
        if not self.optimizer_configuration.layers_to_optimize:
            pruning_params = self._setup_pruning_params()
            prunable_model = tfmot.sparsity.keras.prune_low_magnitude(
-                self.model, **pruning_params
+                self.model, pruning_policy=PrunableLayerPolicy(), **pruning_params
            )
        else:
            prunable_model = tf.keras.models.clone_model(
@@ -122,7 +160,7 @@ class Pruner(Optimizer):
        self.model = prunable_model

    def _train_pruning(self) -> None:
-        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
+        loss_fn = tf.keras.losses.MeanAbsolutePercentageError()
        self.model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
        # Model callbacks
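
For reference, the new policy is consumed through the pruning_policy argument of tfmot.sparsity.keras.prune_low_magnitude, as in the hunk above. A rough standalone sketch of that wiring (the model and the ConstantSparsity schedule are illustrative; inside MLIA the schedule comes from _setup_pruning_params):

import tensorflow as tf
import tensorflow_model_optimization as tfmot

from mlia.nn.tensorflow.optimizations.pruning import PrunableLayerPolicy

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu", input_shape=(10,)),
    tf.keras.layers.Dense(4),
])

pruning_params = {
    "pruning_schedule": tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0),
}
# Layers rejected by the PruneRegistry are skipped with a warning instead of
# failing the whole run.
prunable_model = tfmot.sparsity.keras.prune_low_magnitude(
    model, pruning_policy=PrunableLayerPolicy(), **pruning_params
)
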
diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py
index d688a63..77ac529 100644
--- a/src/mlia/nn/tensorflow/utils.py
+++ b/src/mlia/nn/tensorflow/utils.py
@@ -21,12 +21,10 @@ def representative_dataset(
    input_shape: Any, sample_count: int = 100, input_dtype: type = np.float32
) -> Callable:
    """Sample dataset used for quantization."""
-    if input_shape[0] != 1:
-        raise ValueError("Only the input batch_size=1 is supported!")

    def dataset() -> Iterable:
        for _ in range(sample_count):
-            data = np.random.rand(*input_shape)
+            data = np.random.rand(1, *input_shape[1:])
            yield [data.astype(input_dtype)]

    return dataset
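
With the batch_size=1 restriction dropped, representative_dataset accepts any input shape and re-emits each random sample with a batch dimension of 1, which is what the quantizer expects. A hedged example of plugging the generator into the TensorFlow Lite converter (the converter settings are illustrative, not taken from this commit):

import tensorflow as tf

from mlia.nn.tensorflow.utils import representative_dataset

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(10,))])

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# A batch dimension of 8 is now accepted; each yielded sample still has shape (1, 10).
converter.representative_dataset = representative_dataset([8, 10], sample_count=10)
tflite_model = converter.convert()
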
diff --git a/tests/test_nn_tensorflow_utils.py b/tests/test_nn_tensorflow_utils.py
index 5131171..14b06c4 100644
--- a/tests/test_nn_tensorflow_utils.py
+++ b/tests/test_nn_tensorflow_utils.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Test for module utils/test_utils."""
from pathlib import Path
@@ -31,12 +31,6 @@ def test_generate_representative_dataset() -> None:
        assert isinstance(ndarray, np.ndarray)


-def test_generate_representative_dataset_wrong_shape() -> None:
-    """Test that only shape with batch size=1 is supported."""
-    with pytest.raises(Exception, match="Only the input batch_size=1 is supported!"):
-        representative_dataset([2, 3, 3], 5)
-
-
def test_convert_saved_model_to_tflite(test_tf_model: Path) -> None:
    """Test converting SavedModel to TensorFlow Lite."""
    result = convert_to_tflite(test_tf_model.as_posix())
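
The deleted wrong-shape test could be complemented by one asserting the relaxed behaviour; a possible sketch, not part of this change, reusing the representative_dataset import already present in the test module:

def test_generate_representative_dataset_batch_size_above_one() -> None:
    """Shapes with batch size > 1 are accepted; samples keep batch size 1."""
    dataset = representative_dataset([4, 3, 3], sample_count=2)
    for sample in dataset():
        assert sample[0].shape == (1, 3, 3)
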