author     Alex Tawse <alex.tawse@arm.com>   2023-09-29 15:55:38 +0100
committer  Richard <richard.burton@arm.com>  2023-10-26 12:35:48 +0000
commit     daba3cf2e3633cbd0e4f8aabe7578b97e88deee1 (patch)
tree       51024b8025e28ecb2aecd67246e189e25f5a6e6c /model_conditioning_examples
parent     a11976fb866f77305708f832e603b963969e6a14 (diff)
download   ml-embedded-evaluation-kit-daba3cf2e3633cbd0e4f8aabe7578b97e88deee1.tar.gz
MLECO-3995: Pylint + Shellcheck compatibility
* All Python scripts updated to abide by Pylint rules
* good-names updated to permit short variable names: i, j, k, f, g, ex
* ignore-long-lines regex updated to allow long lines for licence headers
* Shell scripts now compliant with Shellcheck

Signed-off-by: Alex Tawse <Alex.Tawse@arm.com>
Change-Id: I8d5af8279bc08bb8acfe8f6ee7df34965552bbe5
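The block-scoped and inline suppressions used throughout this patch follow standard Pylint pragma syntax. A minimal illustrative module (the names and logic here are invented for the example and are not taken from the commit):

    """Illustrative module (not from the commit) showing the Pylint pragmas used below."""

    # Block-scoped suppression: everything between the disable/enable pair is
    # exempt from the duplicate-code check, mirroring the pattern in the patch.
    # pylint: disable=duplicate-code
    def mean(values):
        """Return the arithmetic mean of a sequence of numbers."""
        total = sum(values)
        count = len(values)
        return total / count
    # pylint: enable=duplicate-code


    def report(values):
        """Print the mean; the unused local is silenced inline, as in the patch."""
        spread, avg = max(values) - min(values), mean(values)  # pylint: disable=unused-variable
        print(f"mean: {avg:.3f}")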
Diffstat (limited to 'model_conditioning_examples')
-rw-r--r--   model_conditioning_examples/post_training_quantization.py    61
-rw-r--r--   model_conditioning_examples/quantization_aware_training.py   68
-rw-r--r--   model_conditioning_examples/setup.sh                           9
-rw-r--r--   model_conditioning_examples/training_utils.py                  5
-rw-r--r--   model_conditioning_examples/weight_clustering.py              87
-rw-r--r--   model_conditioning_examples/weight_pruning.py                 75
6 files changed, 205 insertions, 100 deletions
diff --git a/model_conditioning_examples/post_training_quantization.py b/model_conditioning_examples/post_training_quantization.py
index a39be0e..42069f5 100644
--- a/model_conditioning_examples/post_training_quantization.py
+++ b/model_conditioning_examples/post_training_quantization.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,28 +13,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
-This script will provide you with an example of how to perform post-training quantization in TensorFlow.
+This script will provide you with an example of how to perform
+post-training quantization in TensorFlow.
-The output from this example will be a TensorFlow Lite model file where weights and activations are quantized to 8bit
-integer values.
+The output from this example will be a TensorFlow Lite model file
+where weights and activations are quantized to 8bit integer values.
-Quantization helps reduce the size of your models and is necessary for running models on certain hardware such as Arm
-Ethos NPU.
+Quantization helps reduce the size of your models and is necessary
+for running models on certain hardware such as Arm Ethos NPU.
-In addition to quantizing weights, post-training quantization uses a calibration dataset to
-capture the minimum and maximum values of all variable tensors in your model.
-By capturing these ranges it is possible to fully quantize not just the weights of the model but also the activations.
+In addition to quantizing weights, post-training quantization uses
+a calibration dataset to capture the minimum and maximum values of
+all variable tensors in your model. By capturing these ranges it
+is possible to fully quantize not just the weights of the model
+but also the activations.
-Depending on the model you are quantizing there may be some accuracy loss, but for a lot of models the loss should
-be minimal.
+Depending on the model you are quantizing there may be some accuracy loss,
+but for a lot of models the loss should be minimal.
-If you are targetting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela
+If you are targeting an Arm Ethos-U55 NPU then the output
+TensorFlow Lite file will also need to be passed through the Vela
compiler for further optimizations before it can be used.
-For more information on using Vela see: https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
-For more information on post-training quantization
-see: https://www.tensorflow.org/lite/performance/post_training_integer_quant
+For more information on using Vela see:
+ https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
+For more information on post-training quantization see:
+ https://www.tensorflow.org/lite/performance/post_training_integer_quant
"""
+
import pathlib
import numpy as np
@@ -44,7 +50,8 @@ from training_utils import get_data, create_model
def post_training_quantize(keras_model, sample_data):
- """Quantize Keras model using post-training quantization with some sample data.
+ """
+ Quantize Keras model using post-training quantization with some sample data.
TensorFlow Lite will have fp32 inputs/outputs and the model will handle quantizing/dequantizing.
@@ -76,8 +83,14 @@ def post_training_quantize(keras_model, sample_data):
return tflite_model
-def evaluate_tflite_model(tflite_save_path, x_test, y_test):
- """Calculate the accuracy of a TensorFlow Lite model using TensorFlow Lite interpreter.
+# pylint: disable=duplicate-code
+def evaluate_tflite_model(
+ tflite_save_path: pathlib.Path,
+ x_test: np.ndarray,
+ y_test: np.ndarray
+):
+ """
+ Calculate the accuracy of a TensorFlow Lite model using TensorFlow Lite interpreter.
Args:
tflite_save_path: Path to TensorFlow Lite model to test.
@@ -106,6 +119,9 @@ def evaluate_tflite_model(tflite_save_path, x_test, y_test):
def main():
+ """
+ Run post-training quantization
+ """
x_train, y_train, x_test, y_test = get_data()
model = create_model()
@@ -117,7 +133,7 @@ def main():
model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True)
# Test the fp32 model accuracy.
- test_loss, test_acc = model.evaluate(x_test, y_test)
+ test_loss, test_acc = model.evaluate(x_test, y_test) # pylint: disable=unused-variable
print(f"Test accuracy float: {test_acc:.3f}")
# Quantize and export the resulting TensorFlow Lite model to file.
@@ -132,7 +148,12 @@ def main():
# Test the quantized model accuracy. Save time by only testing a subset of the whole data.
num_test_samples = 1000
- evaluate_tflite_model(quant_model_save_path, x_test[0:num_test_samples], y_test[0:num_test_samples])
+ evaluate_tflite_model(
+ quant_model_save_path,
+ x_test[0:num_test_samples],
+ y_test[0:num_test_samples]
+ )
+# pylint: enable=duplicate-code
if __name__ == "__main__":
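For orientation, the flow post_training_quantization.py implements can be condensed to the sketch below, assuming only the standard TensorFlow Lite converter and interpreter APIs; the helper names are illustrative rather than taken from the script:

    import numpy as np
    import tensorflow as tf

    def quantize_sketch(keras_model, sample_data):
        """Convert a trained Keras model to an int8-quantized TensorFlow Lite flatbuffer."""
        converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
        converter.optimizations = [tf.lite.Optimize.DEFAULT]

        def representative_dataset():
            # Calibration data lets the converter record min/max ranges for
            # activations so they can be quantized along with the weights.
            for sample in sample_data[:100]:
                yield [np.expand_dims(sample, axis=0).astype(np.float32)]

        converter.representative_dataset = representative_dataset
        return converter.convert()

    def tflite_accuracy_sketch(tflite_path, x_test, y_test):
        """Measure accuracy by running the .tflite file through the TFLite interpreter."""
        interpreter = tf.lite.Interpreter(model_path=str(tflite_path))
        interpreter.allocate_tensors()
        input_index = interpreter.get_input_details()[0]['index']
        output_index = interpreter.get_output_details()[0]['index']
        correct = 0
        for sample, label in zip(x_test, y_test):
            # Inputs/outputs stay fp32; quantize/dequantize happens inside the model.
            interpreter.set_tensor(input_index, np.expand_dims(sample, axis=0).astype(np.float32))
            interpreter.invoke()
            correct += int(np.argmax(interpreter.get_tensor(output_index)) == label)
        return correct / len(x_test)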
diff --git a/model_conditioning_examples/quantization_aware_training.py b/model_conditioning_examples/quantization_aware_training.py
index 3d492a7..d590763 100644
--- a/model_conditioning_examples/quantization_aware_training.py
+++ b/model_conditioning_examples/quantization_aware_training.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,31 +13,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
-This script will provide you with a short example of how to perform quantization aware training in TensorFlow using the
+This script will provide you with a short example of how to perform
+quantization aware training in TensorFlow using the
TensorFlow Model Optimization Toolkit.
-The output from this example will be a TensorFlow Lite model file where weights and activations are quantized to 8bit
-integer values.
+The output from this example will be a TensorFlow Lite model file
+where weights and activations are quantized to 8bit integer values.
-Quantization helps reduce the size of your models and is necessary for running models on certain hardware such as Arm
-Ethos NPU.
+Quantization helps reduce the size of your models and is necessary
+for running models on certain hardware such as Arm Ethos NPU.
-In quantization aware training (QAT), the error introduced with quantizing from fp32 to int8 is simulated using
-fake quantization nodes. By simulating this quantization error when training, the model can learn better adapted
-weights and minimize accuracy losses caused by the reduced precision.
+In quantization aware training (QAT), the error introduced with
+quantizing from fp32 to int8 is simulated using fake quantization nodes.
+By simulating this quantization error when training,
+the model can learn better adapted weights and minimize accuracy losses
+caused by the reduced precision.
-Minimum and maximum values for activations are also captured during training so activations for every layer can be
-quantized along with the weights later.
+Minimum and maximum values for activations are also captured
+during training so activations for every layer can be quantized
+along with the weights later.
-Quantization is only simulated during training and the training backward passes are still performed in full float
-precision. Actual quantization happens when generating a TensorFlow Lite model.
+Quantization is only simulated during training and the
+training backward passes are still performed in full float precision.
+Actual quantization happens when generating a TensorFlow Lite model.
-If you are targetting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela
+If you are targeting an Arm Ethos-U55 NPU then the output
+TensorFlow Lite file will also need to be passed through the Vela
compiler for further optimizations before it can be used.
-For more information on using vela see: https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
-For more information on quantization aware training
-see: https://www.tensorflow.org/model_optimization/guide/quantization/training
+For more information on using vela see:
+ https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
+For more information on quantization aware training see:
+ https://www.tensorflow.org/model_optimization/guide/quantization/training
"""
import pathlib
@@ -64,13 +71,15 @@ def quantize_and_convert_to_tflite(keras_model):
# After doing quantization aware training all the information for creating a fully quantized
# TensorFlow Lite model is already within the quantization aware Keras model.
- # This means we only need to call convert with default optimizations to generate the quantized TensorFlow Lite model.
+ # This means we only need to call convert with default optimizations to
+ # generate the quantized TensorFlow Lite model.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
return tflite_model
+# pylint: disable=duplicate-code
def evaluate_tflite_model(tflite_save_path, x_test, y_test):
"""Calculate the accuracy of a TensorFlow Lite model using TensorFlow Lite interpreter.
@@ -101,13 +110,19 @@ def evaluate_tflite_model(tflite_save_path, x_test, y_test):
def main():
+ """
+ Run quantization aware training
+ """
x_train, y_train, x_test, y_test = get_data()
model = create_model()
- # When working with the TensorFlow Keras API and the TF Model Optimization Toolkit we can make our
- # model quantization aware in one line. Once this is done we compile the model and train as normal.
- # It is important to note that the model is only quantization aware and is not quantized yet. The weights are
- # still floating point and will only be converted to int8 when we generate the TensorFlow Lite model later on.
+ # When working with the TensorFlow Keras API and the TF Model Optimization Toolkit
+ # we can make our model quantization aware in one line.
+ # Once this is done we compile the model and train as normal.
+ # It is important to note that the model is only quantization aware
+ # and is not quantized yet.
+ # The weights are still floating point and will only be converted
+ # to int8 when we generate the TensorFlow Lite model later on.
quant_aware_model = tfmot.quantization.keras.quantize_model(model)
quant_aware_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
@@ -117,7 +132,7 @@ def main():
quant_aware_model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True)
# Test the quantization aware model accuracy.
- test_loss, test_acc = quant_aware_model.evaluate(x_test, y_test)
+ test_loss, test_acc = quant_aware_model.evaluate(x_test, y_test) # pylint: disable=unused-variable
print(f"Test accuracy quant aware: {test_acc:.3f}")
# Quantize and save the resulting TensorFlow Lite model to file.
@@ -132,7 +147,12 @@ def main():
# Test quantized model accuracy. Save time by only testing a subset of the whole data.
num_test_samples = 1000
- evaluate_tflite_model(quant_model_save_path, x_test[0:num_test_samples], y_test[0:num_test_samples])
+ evaluate_tflite_model(
+ quant_model_save_path,
+ x_test[0:num_test_samples],
+ y_test[0:num_test_samples]
+ )
+# pylint: enable=duplicate-code
if __name__ == "__main__":
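The quantization-aware-training flow in this file boils down to roughly the following sketch, using the public tfmot.quantization.keras API; the helper name is illustrative:

    import tensorflow as tf
    import tensorflow_model_optimization as tfmot

    def qat_sketch(keras_model, x_train, y_train):
        """Make a model quantization aware, fine-tune it, then convert to TFLite."""
        # Insert fake-quantization nodes; weights remain fp32 until conversion.
        quant_aware_model = tfmot.quantization.keras.quantize_model(keras_model)
        quant_aware_model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
            loss=tf.keras.losses.sparse_categorical_crossentropy,
            metrics=['accuracy']
        )
        quant_aware_model.fit(x=x_train, y=y_train, batch_size=128, epochs=5)

        # The quantization ranges learned during training are already in the model,
        # so the default optimizations are enough to emit a fully quantized file.
        converter = tf.lite.TFLiteConverter.from_keras_model(quant_aware_model)
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        return converter.convert()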
diff --git a/model_conditioning_examples/setup.sh b/model_conditioning_examples/setup.sh
index 92de78a..678f9d3 100644
--- a/model_conditioning_examples/setup.sh
+++ b/model_conditioning_examples/setup.sh
@@ -1,5 +1,7 @@
+#!/bin/bash
+
#----------------------------------------------------------------------------
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,8 +16,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#----------------------------------------------------------------------------
-#!/bin/bash
+
python3 -m venv ./env
+# shellcheck disable=SC1091
source ./env/bin/activate
pip install -U pip
-pip install -r requirements.txt
\ No newline at end of file
+pip install -r requirements.txt
diff --git a/model_conditioning_examples/training_utils.py b/model_conditioning_examples/training_utils.py
index a022bd1..2ce94b8 100644
--- a/model_conditioning_examples/training_utils.py
+++ b/model_conditioning_examples/training_utils.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -49,7 +49,8 @@ def create_model():
"""
keras_model = tf.keras.models.Sequential([
- tf.keras.layers.Conv2D(32, 3, padding='same', input_shape=(28, 28, 1), activation=tf.nn.relu),
+ tf.keras.layers.Conv2D(32, 3, padding='same',
+ input_shape=(28, 28, 1), activation=tf.nn.relu),
tf.keras.layers.Conv2D(32, 3, padding='same', activation=tf.nn.relu),
tf.keras.layers.MaxPool2D(),
tf.keras.layers.Conv2D(32, 3, padding='same', activation=tf.nn.relu),
diff --git a/model_conditioning_examples/weight_clustering.py b/model_conditioning_examples/weight_clustering.py
index 6672d53..e966336 100644
--- a/model_conditioning_examples/weight_clustering.py
+++ b/model_conditioning_examples/weight_clustering.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,22 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
-This script will provide you with a short example of how to perform clustering of weights (weight sharing) in
-TensorFlow using the TensorFlow Model Optimization Toolkit.
+This script will provide you with a short example of how to perform
+clustering of weights (weight sharing) in TensorFlow
+using the TensorFlow Model Optimization Toolkit.
-The output from this example will be a TensorFlow Lite model file where weights in each layer have been 'clustered' into
-16 clusters during training - quantization has then been applied on top of this.
+The output from this example will be a TensorFlow Lite model file
+where weights in each layer have been 'clustered' into 16 clusters
+during training - quantization has then been applied on top of this.
-By clustering the model we can improve compression of the model file. This can be essential for deploying certain
-models on systems with limited resources - such as embedded systems using an Arm Ethos NPU.
+By clustering the model we can improve compression of the model file.
+This can be essential for deploying certain models on systems with
+limited resources - such as embedded systems using an Arm Ethos NPU.
-After performing clustering we do post-training quantization to quantize the model and then generate a TensorFlow Lite file.
+After performing clustering we do post-training quantization
+to quantize the model and then generate a TensorFlow Lite file.
-If you are targetting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela
+If you are targeting an Arm Ethos-U55 NPU then the output
+TensorFlow Lite file will also need to be passed through the Vela
compiler for further optimizations before it can be used.
-For more information on using Vela see: https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
-For more information on clustering see: https://www.tensorflow.org/model_optimization/guide/clustering
+For more information on using Vela see:
+ https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
+For more information on clustering see:
+ https://www.tensorflow.org/model_optimization/guide/clustering
"""
import pathlib
@@ -42,39 +49,52 @@ from post_training_quantization import post_training_quantize, evaluate_tflite_m
def prepare_for_clustering(keras_model):
"""Prepares a Keras model for clustering."""
- # Choose the number of clusters to use and how to initialize them. Using more clusters will generally
- # reduce accuracy so you will need to find the optimal number for your use-case.
+ # Choose the number of clusters to use and how to initialize them.
+ # Using more clusters will generally reduce accuracy,
+ # so you will need to find the optimal number for your use-case.
number_of_clusters = 16
cluster_centroids_init = tfmot.clustering.keras.CentroidInitialization.LINEAR
- # Apply the clustering wrapper to the whole model so weights in every layer will get clustered. You may find that
- # to avoid too much accuracy loss only certain non-critical layers in your model should be clustered.
- clustering_ready_model = tfmot.clustering.keras.cluster_weights(keras_model,
- number_of_clusters=number_of_clusters,
- cluster_centroids_init=cluster_centroids_init)
+ # Apply the clustering wrapper to the whole model so weights in
+ # every layer will get clustered. You may find that to avoid
+ # too much accuracy loss only certain non-critical layers in
+ # your model should be clustered.
+ clustering_ready_model = tfmot.clustering.keras.cluster_weights(
+ keras_model,
+ number_of_clusters=number_of_clusters,
+ cluster_centroids_init=cluster_centroids_init
+ )
# We must recompile the model after making it ready for clustering.
- clustering_ready_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
- loss=tf.keras.losses.sparse_categorical_crossentropy,
- metrics=['accuracy'])
+ clustering_ready_model.compile(
+ optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
+ loss=tf.keras.losses.sparse_categorical_crossentropy,
+ metrics=['accuracy']
+ )
return clustering_ready_model
def main():
+ """
+ Run weight clustering
+ """
x_train, y_train, x_test, y_test = get_data()
model = create_model()
# Compile and train the model first.
- # In general it is easier to do clustering as a fine-tuning step after the model is fully trained.
- model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
- loss=tf.keras.losses.sparse_categorical_crossentropy,
- metrics=['accuracy'])
+ # In general, it is easier to do clustering as a
+ # fine-tuning step after the model is fully trained.
+ model.compile(
+ optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
+ loss=tf.keras.losses.sparse_categorical_crossentropy,
+ metrics=['accuracy']
+ )
model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True)
# Test the trained model accuracy.
- test_loss, test_acc = model.evaluate(x_test, y_test)
+ test_loss, test_acc = model.evaluate(x_test, y_test) # pylint: disable=unused-variable
print(f"Test accuracy before clustering: {test_acc:.3f}")
# Prepare the model for clustering.
@@ -88,19 +108,26 @@ def main():
# Remove all variables that clustering only needed in the training phase.
model_for_export = tfmot.clustering.keras.strip_clustering(clustered_model)
- # Apply post-training quantization on top of the clustering and save the resulting TensorFlow Lite model to file.
+ # Apply post-training quantization on top of the clustering
+ # and save the resulting TensorFlow Lite model to file.
tflite_model = post_training_quantize(model_for_export, x_train)
tflite_models_dir = pathlib.Path('./conditioned_models/')
tflite_models_dir.mkdir(exist_ok=True, parents=True)
- clustered_quant_model_save_path = tflite_models_dir / 'clustered_post_training_quant_model.tflite'
+ clustered_quant_model_save_path = \
+ tflite_models_dir / 'clustered_post_training_quant_model.tflite'
with open(clustered_quant_model_save_path, 'wb') as f:
f.write(tflite_model)
- # Test the clustered quantized model accuracy. Save time by only testing a subset of the whole data.
+ # Test the clustered quantized model accuracy.
+ # Save time by only testing a subset of the whole data.
num_test_samples = 1000
- evaluate_tflite_model(clustered_quant_model_save_path, x_test[0:num_test_samples], y_test[0:num_test_samples])
+ evaluate_tflite_model(
+ clustered_quant_model_save_path,
+ x_test[0:num_test_samples],
+ y_test[0:num_test_samples]
+ )
if __name__ == "__main__":
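Condensed, the clustering flow in this file looks roughly like the sketch below (public tfmot.clustering.keras API; the helper name is illustrative):

    import tensorflow as tf
    import tensorflow_model_optimization as tfmot

    def cluster_sketch(trained_model, x_train, y_train):
        """Cluster the weights of an already-trained model, then strip the wrappers."""
        clustered = tfmot.clustering.keras.cluster_weights(
            trained_model,
            number_of_clusters=16,
            cluster_centroids_init=tfmot.clustering.keras.CentroidInitialization.LINEAR
        )
        # Re-compile after wrapping, then fine-tune briefly so the centroids settle.
        clustered.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
            loss=tf.keras.losses.sparse_categorical_crossentropy,
            metrics=['accuracy']
        )
        clustered.fit(x=x_train, y=y_train, batch_size=128, epochs=1)
        # Drop the clustering-only training variables before quantization/export.
        return tfmot.clustering.keras.strip_clustering(clustered)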
diff --git a/model_conditioning_examples/weight_pruning.py b/model_conditioning_examples/weight_pruning.py
index cbf9cf9..303b6df 100644
--- a/model_conditioning_examples/weight_pruning.py
+++ b/model_conditioning_examples/weight_pruning.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,23 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
-This script will provide you with a short example of how to perform magnitude-based weight pruning in TensorFlow
+This script will provide you with a short example of how to perform
+magnitude-based weight pruning in TensorFlow
using the TensorFlow Model Optimization Toolkit.
-The output from this example will be a TensorFlow Lite model file where ~75% percent of the weights have been 'pruned' to the
+The output from this example will be a TensorFlow Lite model file
+where ~75% of the weights have been 'pruned' to the
value 0 during training - quantization has then been applied on top of this.
-By pruning the model we can improve compression of the model file. This can be essential for deploying certain models
-on systems with limited resources - such as embedded systems using Arm Ethos NPU. Also, if the pruned model is run
-on an Arm Ethos NPU then this pruning can improve the execution time of the model.
+By pruning the model we can improve compression of the model file.
+This can be essential for deploying certain models on systems
+with limited resources - such as embedded systems using Arm Ethos NPU.
+Also, if the pruned model is run on an Arm Ethos NPU then
+this pruning can improve the execution time of the model.
-After pruning is complete we do post-training quantization to quantize the model and then generate a TensorFlow Lite file.
+After pruning is complete we do post-training quantization
+to quantize the model and then generate a TensorFlow Lite file.
-If you are targetting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela
+If you are targeting an Arm Ethos-U55 NPU then the output
+TensorFlow Lite file will also need to be passed through the Vela
compiler for further optimizations before it can be used.
-For more information on using Vela see: https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
-For more information on weight pruning see: https://www.tensorflow.org/model_optimization/guide/pruning
+For more information on using Vela see:
+ https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
+For more information on weight pruning see:
+ https://www.tensorflow.org/model_optimization/guide/pruning
"""
import pathlib
@@ -43,13 +51,20 @@ from post_training_quantization import post_training_quantize, evaluate_tflite_m
def prepare_for_pruning(keras_model):
"""Prepares a Keras model for pruning."""
- # We use a constant sparsity schedule so the amount of sparsity in the model is kept at the same percent throughout
- # training. An alternative is PolynomialDecay where sparsity can be gradually increased during training.
+ # We use a constant sparsity schedule so the amount of sparsity
+ # in the model is kept at the same percent throughout training.
+ # An alternative is PolynomialDecay where sparsity
+ # can be gradually increased during training.
pruning_schedule = tfmot.sparsity.keras.ConstantSparsity(target_sparsity=0.75, begin_step=0)
- # Apply the pruning wrapper to the whole model so weights in every layer will get pruned. You may find that to avoid
- # too much accuracy loss only certain non-critical layers in your model should be pruned.
- pruning_ready_model = tfmot.sparsity.keras.prune_low_magnitude(keras_model, pruning_schedule=pruning_schedule)
+ # Apply the pruning wrapper to the whole model
+ # so weights in every layer will get pruned.
+ # You may find that to avoid too much accuracy loss only
+ # certain non-critical layers in your model should be pruned.
+ pruning_ready_model = tfmot.sparsity.keras.prune_low_magnitude(
+ keras_model,
+ pruning_schedule=pruning_schedule
+ )
# We must recompile the model after making it ready for pruning.
pruning_ready_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
@@ -60,11 +75,15 @@ def prepare_for_pruning(keras_model):
def main():
+ """
+ Run weight pruning
+ """
x_train, y_train, x_test, y_test = get_data()
model = create_model()
# Compile and train the model first.
- # In general it is easier to do pruning as a fine-tuning step after the model is fully trained.
+ # In general, it is easier to do pruning as a fine-tuning step
+ # after the model is fully trained.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
loss=tf.keras.losses.sparse_categorical_crossentropy,
metrics=['accuracy'])
@@ -72,7 +91,7 @@ def main():
model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True)
# Test the trained model accuracy.
- test_loss, test_acc = model.evaluate(x_test, y_test)
+ test_loss, test_acc = model.evaluate(x_test, y_test) # pylint: disable=unused-variable
print(f"Test accuracy before pruning: {test_acc:.3f}")
# Prepare the model for pruning and add the pruning update callback needed in training.
@@ -80,14 +99,23 @@ def main():
callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]
# Continue training the model but now with pruning applied - remember to pass in the callbacks!
- pruned_model.fit(x=x_train, y=y_train, batch_size=128, epochs=1, verbose=1, shuffle=True, callbacks=callbacks)
+ pruned_model.fit(
+ x=x_train,
+ y=y_train,
+ batch_size=128,
+ epochs=1,
+ verbose=1,
+ shuffle=True,
+ callbacks=callbacks
+ )
test_loss, test_acc = pruned_model.evaluate(x_test, y_test)
print(f"Test accuracy after pruning: {test_acc:.3f}")
# Remove all variables that pruning only needed in the training phase.
model_for_export = tfmot.sparsity.keras.strip_pruning(pruned_model)
- # Apply post-training quantization on top of the pruning and save the resulting TensorFlow Lite model to file.
+ # Apply post-training quantization on top of the pruning
+ # and save the resulting TensorFlow Lite model to file.
tflite_model = post_training_quantize(model_for_export, x_train)
tflite_models_dir = pathlib.Path('./conditioned_models/')
@@ -97,9 +125,14 @@ def main():
with open(pruned_quant_model_save_path, 'wb') as f:
f.write(tflite_model)
- # Test the pruned quantized model accuracy. Save time by only testing a subset of the whole data.
+ # Test the pruned quantized model accuracy.
+ # Save time by only testing a subset of the whole data.
num_test_samples = 1000
- evaluate_tflite_model(pruned_quant_model_save_path, x_test[0:num_test_samples], y_test[0:num_test_samples])
+ evaluate_tflite_model(
+ pruned_quant_model_save_path,
+ x_test[0:num_test_samples],
+ y_test[0:num_test_samples]
+ )
if __name__ == "__main__":
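Condensed, the pruning flow in this file looks roughly like the sketch below (public tfmot.sparsity.keras API; the helper name is illustrative):

    import tensorflow as tf
    import tensorflow_model_optimization as tfmot

    def prune_sketch(trained_model, x_train, y_train):
        """Prune ~75% of the weights of a trained model, then strip the wrappers."""
        schedule = tfmot.sparsity.keras.ConstantSparsity(target_sparsity=0.75, begin_step=0)
        pruned = tfmot.sparsity.keras.prune_low_magnitude(trained_model, pruning_schedule=schedule)
        # Re-compile after wrapping, then fine-tune with the pruning callback.
        pruned.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
            loss=tf.keras.losses.sparse_categorical_crossentropy,
            metrics=['accuracy']
        )
        # UpdatePruningStep is required so the pruning masks are applied each step.
        pruned.fit(
            x=x_train, y=y_train, batch_size=128, epochs=1,
            callbacks=[tfmot.sparsity.keras.UpdatePruningStep()]
        )
        # Drop the pruning-only training variables before quantization/export.
        return tfmot.sparsity.keras.strip_pruning(pruned)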