summaryrefslogtreecommitdiff
path: root/model_conditioning_examples/weight_pruning.py
diff options
context:
space:
mode:
author Alex Tawse <alex.tawse@arm.com> 2023-09-29 15:55:38 +0100
committer Richard <richard.burton@arm.com> 2023-10-26 12:35:48 +0000
commit daba3cf2e3633cbd0e4f8aabe7578b97e88deee1 (patch)
tree 51024b8025e28ecb2aecd67246e189e25f5a6e6c /model_conditioning_examples/weight_pruning.py
parent a11976fb866f77305708f832e603b963969e6a14 (diff)
download ml-embedded-evaluation-kit-daba3cf2e3633cbd0e4f8aabe7578b97e88deee1.tar.gz
MLECO-3995: Pylint + Shellcheck compatibility
* All Python scripts updated to abide by Pylint rules
* good-names updated to permit short variable names: i, j, k, f, g, ex
* ignore-long-lines regex updated to allow long lines for licence headers
* Shell scripts now compliant with Shellcheck

Signed-off-by: Alex Tawse <Alex.Tawse@arm.com>
Change-Id: I8d5af8279bc08bb8acfe8f6ee7df34965552bbe5
Diffstat (limited to 'model_conditioning_examples/weight_pruning.py')
-rw-r--r-- model_conditioning_examples/weight_pruning.py 75
1 files changed, 54 insertions, 21 deletions
diff --git a/model_conditioning_examples/weight_pruning.py b/model_conditioning_examples/weight_pruning.py
index cbf9cf9..303b6df 100644
--- a/model_conditioning_examples/weight_pruning.py
+++ b/model_conditioning_examples/weight_pruning.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,23 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
-This script will provide you with a short example of how to perform magnitude-based weight pruning in TensorFlow
+This script will provide you with a short example of how to perform
+magnitude-based weight pruning in TensorFlow
using the TensorFlow Model Optimization Toolkit.
-The output from this example will be a TensorFlow Lite model file where ~75% percent of the weights have been 'pruned' to the
+The output from this example will be a TensorFlow Lite model file
+where ~75% percent of the weights have been 'pruned' to the
value 0 during training - quantization has then been applied on top of this.
-By pruning the model we can improve compression of the model file. This can be essential for deploying certain models
-on systems with limited resources - such as embedded systems using Arm Ethos NPU. Also, if the pruned model is run
-on an Arm Ethos NPU then this pruning can improve the execution time of the model.
+By pruning the model we can improve compression of the model file.
+This can be essential for deploying certain models on systems
+with limited resources - such as embedded systems using Arm Ethos NPU.
+Also, if the pruned model is run on an Arm Ethos NPU then
+this pruning can improve the execution time of the model.
-After pruning is complete we do post-training quantization to quantize the model and then generate a TensorFlow Lite file.
+After pruning is complete we do post-training quantization
+to quantize the model and then generate a TensorFlow Lite file.
-If you are targetting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela
+If you are targeting an Arm Ethos-U55 NPU then the output
+TensorFlow Lite file will also need to be passed through the Vela
compiler for further optimizations before it can be used.
-For more information on using Vela see: https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
-For more information on weight pruning see: https://www.tensorflow.org/model_optimization/guide/pruning
+For more information on using Vela see:
+ https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
+For more information on weight pruning see:
+ https://www.tensorflow.org/model_optimization/guide/pruning
"""
import pathlib
@@ -43,13 +51,20 @@ from post_training_quantization import post_training_quantize, evaluate_tflite_m
def prepare_for_pruning(keras_model):
"""Prepares a Keras model for pruning."""
- # We use a constant sparsity schedule so the amount of sparsity in the model is kept at the same percent throughout
- # training. An alternative is PolynomialDecay where sparsity can be gradually increased during training.
+ # We use a constant sparsity schedule so the amount of sparsity
+ # in the model is kept at the same percent throughout training.
+ # An alternative is PolynomialDecay where sparsity
+ # can be gradually increased during training.
pruning_schedule = tfmot.sparsity.keras.ConstantSparsity(target_sparsity=0.75, begin_step=0)
- # Apply the pruning wrapper to the whole model so weights in every layer will get pruned. You may find that to avoid
- # too much accuracy loss only certain non-critical layers in your model should be pruned.
- pruning_ready_model = tfmot.sparsity.keras.prune_low_magnitude(keras_model, pruning_schedule=pruning_schedule)
+ # Apply the pruning wrapper to the whole model
+ # so weights in every layer will get pruned.
+ # You may find that to avoid too much accuracy loss only
+ # certain non-critical layers in your model should be pruned.
+ pruning_ready_model = tfmot.sparsity.keras.prune_low_magnitude(
+ keras_model,
+ pruning_schedule=pruning_schedule
+ )
# We must recompile the model after making it ready for pruning.
pruning_ready_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
@@ -60,11 +75,15 @@ def prepare_for_pruning(keras_model):
def main():
+ """
+ Run weight pruning
+ """
x_train, y_train, x_test, y_test = get_data()
model = create_model()
# Compile and train the model first.
- # In general it is easier to do pruning as a fine-tuning step after the model is fully trained.
+ # In general, it is easier to do pruning as a fine-tuning step
+ # after the model is fully trained.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
loss=tf.keras.losses.sparse_categorical_crossentropy,
metrics=['accuracy'])
@@ -72,7 +91,7 @@ def main():
model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True)
# Test the trained model accuracy.
- test_loss, test_acc = model.evaluate(x_test, y_test)
+ test_loss, test_acc = model.evaluate(x_test, y_test) # pylint: disable=unused-variable
print(f"Test accuracy before pruning: {test_acc:.3f}")
# Prepare the model for pruning and add the pruning update callback needed in training.
@@ -80,14 +99,23 @@ def main():
callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]
# Continue training the model but now with pruning applied - remember to pass in the callbacks!
- pruned_model.fit(x=x_train, y=y_train, batch_size=128, epochs=1, verbose=1, shuffle=True, callbacks=callbacks)
+ pruned_model.fit(
+ x=x_train,
+ y=y_train,
+ batch_size=128,
+ epochs=1,
+ verbose=1,
+ shuffle=True,
+ callbacks=callbacks
+ )
test_loss, test_acc = pruned_model.evaluate(x_test, y_test)
print(f"Test accuracy after pruning: {test_acc:.3f}")
# Remove all variables that pruning only needed in the training phase.
model_for_export = tfmot.sparsity.keras.strip_pruning(pruned_model)
- # Apply post-training quantization on top of the pruning and save the resulting TensorFlow Lite model to file.
+ # Apply post-training quantization on top of the pruning
+ # and save the resulting TensorFlow Lite model to file.
tflite_model = post_training_quantize(model_for_export, x_train)
tflite_models_dir = pathlib.Path('./conditioned_models/')
@@ -97,9 +125,14 @@ def main():
with open(pruned_quant_model_save_path, 'wb') as f:
f.write(tflite_model)
- # Test the pruned quantized model accuracy. Save time by only testing a subset of the whole data.
+ # Test the pruned quantized model accuracy.
+ # Save time by only testing a subset of the whole data.
num_test_samples = 1000
- evaluate_tflite_model(pruned_quant_model_save_path, x_test[0:num_test_samples], y_test[0:num_test_samples])
+ evaluate_tflite_model(
+ pruned_quant_model_save_path,
+ x_test[0:num_test_samples],
+ y_test[0:num_test_samples]
+ )
if __name__ == "__main__":