diff options
Diffstat (limited to 'src/mlia/nn/tensorflow/optimizations/pruning.py')
-rw-r--r-- | src/mlia/nn/tensorflow/optimizations/pruning.py | 31 |
1 files changed, 15 insertions, 16 deletions
diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py index a30b301..866e209 100644 --- a/src/mlia/nn/tensorflow/optimizations/pruning.py +++ b/src/mlia/nn/tensorflow/optimizations/pruning.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """ Contains class Pruner to prune a model to a specified sparsity. @@ -15,8 +15,8 @@ from dataclasses import dataclass from typing import Any import numpy as np -import tensorflow as tf import tensorflow_model_optimization as tfmot +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from tensorflow_model_optimization.python.core.sparsity.keras import ( # pylint: disable=no-name-in-module prune_registry, ) @@ -27,6 +27,7 @@ from tensorflow_model_optimization.python.core.sparsity.keras import ( # pylint from mlia.nn.common import Optimizer from mlia.nn.common import OptimizerConfiguration + logger = logging.getLogger(__name__) @@ -58,7 +59,7 @@ class PrunableLayerPolicy(tfmot.sparsity.keras.PruningPolicy): are compatible with the pruning API, and that the model supports pruning. """ - def allow_pruning(self, layer: tf.keras.layers.Layer) -> Any: + def allow_pruning(self, layer: keras.layers.Layer) -> Any: """Allow pruning only for layers that are prunable. Checks the PruneRegistry in TensorFlow Model Optimization Toolkit. @@ -71,13 +72,13 @@ class PrunableLayerPolicy(tfmot.sparsity.keras.PruningPolicy): return layer_is_supported - def ensure_model_supports_pruning(self, model: tf.keras.Model) -> None: + def ensure_model_supports_pruning(self, model: keras.Model) -> None: """Ensure that the model contains only supported layers.""" # Check whether the model is a Keras model. - if not isinstance(model, tf.keras.Model): + if not isinstance(model, keras.Model): raise ValueError( "Models that are not part of the \ - tf.keras.Model base class \ + keras.Model base class \ are not supported currently." ) @@ -99,7 +100,7 @@ class Pruner(Optimizer): """ def __init__( - self, model: tf.keras.Model, optimizer_configuration: PruningConfiguration + self, model: keras.Model, optimizer_configuration: PruningConfiguration ): """Init Pruner instance.""" self.model = model @@ -132,9 +133,7 @@ class Pruner(Optimizer): ), } - def _apply_pruning_to_layer( - self, layer: tf.keras.layers.Layer - ) -> tf.keras.layers.Layer: + def _apply_pruning_to_layer(self, layer: keras.layers.Layer) -> keras.layers.Layer: layers_to_optimize = self.optimizer_configuration.layers_to_optimize assert layers_to_optimize, "List of the layers to optimize is empty" @@ -145,7 +144,7 @@ class Pruner(Optimizer): return tfmot.sparsity.keras.prune_low_magnitude(layer, **pruning_params) def _init_for_pruning(self) -> None: - # Use `tf.keras.models.clone_model` to apply `apply_pruning_to_layer` + # Use `keras.models.clone_model` to apply `apply_pruning_to_layer` # to the layers of the model if not self.optimizer_configuration.layers_to_optimize: pruning_params = self._setup_pruning_params() @@ -153,14 +152,14 @@ class Pruner(Optimizer): self.model, pruning_policy=PrunableLayerPolicy(), **pruning_params ) else: - prunable_model = tf.keras.models.clone_model( + prunable_model = keras.models.clone_model( self.model, clone_function=self._apply_pruning_to_layer ) self.model = prunable_model def _train_pruning(self) -> None: - loss_fn = tf.keras.losses.MeanAbsolutePercentageError() + loss_fn = keras.losses.MeanAbsolutePercentageError() self.model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"]) # Model callbacks @@ -183,8 +182,8 @@ class Pruner(Optimizer): continue for weight in layer.layer.get_prunable_weights(): - nonzero_weights = np.count_nonzero(tf.keras.backend.get_value(weight)) - all_weights = tf.keras.backend.get_value(weight).size + nonzero_weights = np.count_nonzero(keras.backend.get_value(weight)) + all_weights = keras.backend.get_value(weight).size # Types need to be ignored for this function call because # np.testing.assert_approx_equal does not have type annotation while the @@ -205,6 +204,6 @@ class Pruner(Optimizer): self._assert_sparsity_reached() self._strip_pruning() - def get_model(self) -> tf.keras.Model: + def get_model(self) -> keras.Model: """Get model.""" return self.model |