path: root/src/mlia/nn/tensorflow/optimizations/pruning.py
Diffstat (limited to 'src/mlia/nn/tensorflow/optimizations/pruning.py')
-rw-r--r--  src/mlia/nn/tensorflow/optimizations/pruning.py  31
1 file changed, 15 insertions, 16 deletions
diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py
index a30b301..866e209 100644
--- a/src/mlia/nn/tensorflow/optimizations/pruning.py
+++ b/src/mlia/nn/tensorflow/optimizations/pruning.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""
Contains class Pruner to prune a model to a specified sparsity.
@@ -15,8 +15,8 @@ from dataclasses import dataclass
from typing import Any
import numpy as np
-import tensorflow as tf
import tensorflow_model_optimization as tfmot
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from tensorflow_model_optimization.python.core.sparsity.keras import ( # pylint: disable=no-name-in-module
prune_registry,
)
@@ -27,6 +27,7 @@ from tensorflow_model_optimization.python.core.sparsity.keras import ( # pylint
from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
+
logger = logging.getLogger(__name__)
@@ -58,7 +59,7 @@ class PrunableLayerPolicy(tfmot.sparsity.keras.PruningPolicy):
are compatible with the pruning API, and that the model supports pruning.
"""
- def allow_pruning(self, layer: tf.keras.layers.Layer) -> Any:
+ def allow_pruning(self, layer: keras.layers.Layer) -> Any:
"""Allow pruning only for layers that are prunable.
Checks the PruneRegistry in TensorFlow Model Optimization Toolkit.
@@ -71,13 +72,13 @@ class PrunableLayerPolicy(tfmot.sparsity.keras.PruningPolicy):
return layer_is_supported
- def ensure_model_supports_pruning(self, model: tf.keras.Model) -> None:
+ def ensure_model_supports_pruning(self, model: keras.Model) -> None:
"""Ensure that the model contains only supported layers."""
# Check whether the model is a Keras model.
- if not isinstance(model, tf.keras.Model):
+ if not isinstance(model, keras.Model):
raise ValueError(
"Models that are not part of the \
- tf.keras.Model base class \
+ keras.Model base class \
are not supported currently."
)
@@ -99,7 +100,7 @@ class Pruner(Optimizer):
"""
def __init__(
- self, model: tf.keras.Model, optimizer_configuration: PruningConfiguration
+ self, model: keras.Model, optimizer_configuration: PruningConfiguration
):
"""Init Pruner instance."""
self.model = model
@@ -132,9 +133,7 @@ class Pruner(Optimizer):
),
}
- def _apply_pruning_to_layer(
- self, layer: tf.keras.layers.Layer
- ) -> tf.keras.layers.Layer:
+ def _apply_pruning_to_layer(self, layer: keras.layers.Layer) -> keras.layers.Layer:
layers_to_optimize = self.optimizer_configuration.layers_to_optimize
assert layers_to_optimize, "List of the layers to optimize is empty"
@@ -145,7 +144,7 @@ class Pruner(Optimizer):
return tfmot.sparsity.keras.prune_low_magnitude(layer, **pruning_params)
def _init_for_pruning(self) -> None:
- # Use `tf.keras.models.clone_model` to apply `apply_pruning_to_layer`
+ # Use `keras.models.clone_model` to apply `apply_pruning_to_layer`
# to the layers of the model
if not self.optimizer_configuration.layers_to_optimize:
pruning_params = self._setup_pruning_params()
@@ -153,14 +152,14 @@ class Pruner(Optimizer):
self.model, pruning_policy=PrunableLayerPolicy(), **pruning_params
)
else:
- prunable_model = tf.keras.models.clone_model(
+ prunable_model = keras.models.clone_model(
self.model, clone_function=self._apply_pruning_to_layer
)
self.model = prunable_model
def _train_pruning(self) -> None:
- loss_fn = tf.keras.losses.MeanAbsolutePercentageError()
+ loss_fn = keras.losses.MeanAbsolutePercentageError()
self.model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
# Model callbacks
@@ -183,8 +182,8 @@ class Pruner(Optimizer):
continue
for weight in layer.layer.get_prunable_weights():
- nonzero_weights = np.count_nonzero(tf.keras.backend.get_value(weight))
- all_weights = tf.keras.backend.get_value(weight).size
+ nonzero_weights = np.count_nonzero(keras.backend.get_value(weight))
+ all_weights = keras.backend.get_value(weight).size
# Types need to be ignored for this function call because
# np.testing.assert_approx_equal does not have type annotation while the
@@ -205,6 +204,6 @@ class Pruner(Optimizer):
self._assert_sparsity_reached()
self._strip_pruning()
- def get_model(self) -> tf.keras.Model:
+ def get_model(self) -> keras.Model:
"""Get model."""
return self.model
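
For context, below is a minimal, self-contained sketch of the pruning flow that this file wraps, using the same `from keras.api._v2 import keras` workaround the patch introduces (MLIA-1107). It is not part of the patch: the toy model, data, and sparsity target are illustrative only, and it assumes TensorFlow 2.x with tensorflow_model_optimization installed and that the private `keras.api._v2` import path is still available (the in-code comment calls it a temporary workaround).

# Illustrative sketch only, not part of the patch above. Names and values are
# examples; assumes TF 2.x, tensorflow_model_optimization, and the temporary
# keras.api._v2 import path referenced by MLIA-1107.
import numpy as np
import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras  # same temporary workaround as in pruning.py

# Toy Keras model with prunable Dense layers.
model = keras.Sequential(
    [
        keras.layers.Dense(16, activation="relu", input_shape=(8,)),
        keras.layers.Dense(1),
    ]
)

# Constant 50% sparsity, analogous to the schedule Pruner sets up.
pruning_params = {
    "pruning_schedule": tfmot.sparsity.keras.ConstantSparsity(
        target_sparsity=0.5, begin_step=0, frequency=10
    )
}
pruned = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)

# Short fine-tuning run so the pruning wrappers can apply their masks
# (cf. _train_pruning, which also compiles with MeanAbsolutePercentageError).
pruned.compile(optimizer="adam", loss=keras.losses.MeanAbsolutePercentageError())
x = np.random.rand(64, 8).astype("float32")
y = np.random.rand(64, 1).astype("float32")
pruned.fit(
    x,
    y,
    epochs=2,
    batch_size=16,
    callbacks=[tfmot.sparsity.keras.UpdatePruningStep()],
    verbose=0,
)

# Remove the pruning wrappers while keeping the sparse weights (cf. _strip_pruning).
final_model = tfmot.sparsity.keras.strip_pruning(pruned)
final_model.summary()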