path: root/model_conditioning_examples/quantization_aware_training.py
Diffstat (limited to 'model_conditioning_examples/quantization_aware_training.py')
-rw-r--r--  model_conditioning_examples/quantization_aware_training.py  68
1 file changed, 44 insertions(+), 24 deletions(-)
diff --git a/model_conditioning_examples/quantization_aware_training.py b/model_conditioning_examples/quantization_aware_training.py
index 3d492a7..d590763 100644
--- a/model_conditioning_examples/quantization_aware_training.py
+++ b/model_conditioning_examples/quantization_aware_training.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,31 +13,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
-This script will provide you with a short example of how to perform quantization aware training in TensorFlow using the
+This script will provide you with a short example of how to perform
+quantization aware training in TensorFlow using the
TensorFlow Model Optimization Toolkit.
-The output from this example will be a TensorFlow Lite model file where weights and activations are quantized to 8bit
-integer values.
+The output from this example will be a TensorFlow Lite model file
+where weights and activations are quantized to 8-bit integer values.
-Quantization helps reduce the size of your models and is necessary for running models on certain hardware such as Arm
-Ethos NPU.
+Quantization helps reduce the size of your models and is necessary
+for running models on certain hardware such as Arm Ethos NPU.
-In quantization aware training (QAT), the error introduced with quantizing from fp32 to int8 is simulated using
-fake quantization nodes. By simulating this quantization error when training, the model can learn better adapted
-weights and minimize accuracy losses caused by the reduced precision.
+In quantization aware training (QAT), the error introduced with
+quantizing from fp32 to int8 is simulated using fake quantization nodes.
+By simulating this quantization error when training,
+the model can learn better adapted weights and minimize accuracy losses
+caused by the reduced precision.
-Minimum and maximum values for activations are also captured during training so activations for every layer can be
-quantized along with the weights later.
+Minimum and maximum values for activations are also captured
+during training so activations for every layer can be quantized
+along with the weights later.
-Quantization is only simulated during training and the training backward passes are still performed in full float
-precision. Actual quantization happens when generating a TensorFlow Lite model.
+Quantization is only simulated during training and the
+training backward passes are still performed in full float precision.
+Actual quantization happens when generating a TensorFlow Lite model.
-If you are targetting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela
+If you are targeting an Arm Ethos-U55 NPU then the output
+TensorFlow Lite file will also need to be passed through the Vela
compiler for further optimizations before it can be used.
-For more information on using vela see: https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
-For more information on quantization aware training
-see: https://www.tensorflow.org/model_optimization/guide/quantization/training
+For more information on using vela see:
+ https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
+For more information on quantization aware training see:
+ https://www.tensorflow.org/model_optimization/guide/quantization/training
"""
import pathlib
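The fake quantization described in the docstring above is inserted automatically by the TensorFlow Model Optimization Toolkit. A minimal sketch, separate from the patched script and using an illustrative toy model in place of the script's create_model(), of what the one-line quantize_model call produces:

import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Illustrative stand-in for the model built by create_model() in this script.
base_model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])

# Wraps every supported layer with fake quantization nodes that simulate
# int8 rounding in the forward pass; backward passes stay in float32.
qat_model = tfmot.quantization.keras.quantize_model(base_model)
qat_model.summary()  # wrapped layers appear with a "quant_" prefix

Training then proceeds with the usual compile/fit calls, exactly as the patched main() does further down.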
@@ -64,13 +71,15 @@ def quantize_and_convert_to_tflite(keras_model):
# After doing quantization aware training all the information for creating a fully quantized
# TensorFlow Lite model is already within the quantization aware Keras model.
- # This means we only need to call convert with default optimizations to generate the quantized TensorFlow Lite model.
+ # This means we only need to call convert with default optimizations to
+ # generate the quantized TensorFlow Lite model.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
return tflite_model
+# pylint: disable=duplicate-code
def evaluate_tflite_model(tflite_save_path, x_test, y_test):
"""Calculate the accuracy of a TensorFlow Lite model using TensorFlow Lite interpreter.
@@ -101,13 +110,19 @@ def evaluate_tflite_model(tflite_save_path, x_test, y_test):
def main():
+ """
+ Run quantization aware training
+ """
x_train, y_train, x_test, y_test = get_data()
model = create_model()
- # When working with the TensorFlow Keras API and the TF Model Optimization Toolkit we can make our
- # model quantization aware in one line. Once this is done we compile the model and train as normal.
- # It is important to note that the model is only quantization aware and is not quantized yet. The weights are
- # still floating point and will only be converted to int8 when we generate the TensorFlow Lite model later on.
+    # When working with the TensorFlow Keras API and the TF Model Optimization Toolkit
+ # we can make our model quantization aware in one line.
+ # Once this is done we compile the model and train as normal.
+ # It is important to note that the model is only quantization aware
+ # and is not quantized yet.
+ # The weights are still floating point and will only be converted
+ # to int8 when we generate the TensorFlow Lite model later on.
quant_aware_model = tfmot.quantization.keras.quantize_model(model)
quant_aware_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
@@ -117,7 +132,7 @@ def main():
quant_aware_model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True)
# Test the quantization aware model accuracy.
- test_loss, test_acc = quant_aware_model.evaluate(x_test, y_test)
+ test_loss, test_acc = quant_aware_model.evaluate(x_test, y_test) # pylint: disable=unused-variable
print(f"Test accuracy quant aware: {test_acc:.3f}")
# Quantize and save the resulting TensorFlow Lite model to file.
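The code that writes the file sits between this hunk and the next one. Since the script imports pathlib, a minimal sketch of how the flatbuffer returned by converter.convert() can be written to disk; the directory and file names below are illustrative only:

save_dir = pathlib.Path("./conditioned_models")
save_dir.mkdir(exist_ok=True)
quant_model_save_path = save_dir / "qat_quantized_model.tflite"
quant_model_save_path.write_bytes(tflite_model)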
@@ -132,7 +147,12 @@ def main():
# Test quantized model accuracy. Save time by only testing a subset of the whole data.
num_test_samples = 1000
- evaluate_tflite_model(quant_model_save_path, x_test[0:num_test_samples], y_test[0:num_test_samples])
+ evaluate_tflite_model(
+ quant_model_save_path,
+ x_test[0:num_test_samples],
+ y_test[0:num_test_samples]
+ )
+# pylint: enable=duplicate-code
if __name__ == "__main__":
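As the docstring notes, a model destined for an Arm Ethos-U55 NPU still has to be passed through the Vela compiler. Depending on the TensorFlow version and deployment flow, it can also be useful to request int8 input and output tensors at conversion time. A hedged sketch, not part of this patch, with the exact Vela command line depending on your Vela version and target configuration:

converter = tf.lite.TFLiteConverter.from_keras_model(quant_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Ask for int8 tensors at the model boundary so no float32 quantize/dequantize
# ops are left at the input and output.
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()

# The saved .tflite file can then be compiled for the NPU, for example:
#   vela qat_quantized_model.tflite --accelerator-config ethos-u55-128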