#  SPDX-FileCopyrightText:  Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#  SPDX-License-Identifier: Apache-2.0
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""
This script will provide you with a short example of how to perform
quantization aware training in TensorFlow using the
TensorFlow Model Optimization Toolkit.

The output from this example will be a TensorFlow Lite model file
where weights and activations are quantized to 8-bit integer values.

Quantization helps reduce the size of your models and is necessary
for running models on certain hardware, such as an Arm Ethos NPU.

In quantization aware training (QAT), the error introduced by
quantizing from fp32 to int8 is simulated using fake quantization nodes.
By simulating this quantization error when training,
the model can learn better adapted weights and minimize accuracy losses
caused by the reduced precision.

Minimum and maximum values for activations are also captured
during training so activations for every layer can be quantized
along with the weights later.

Quantization is only simulated during training and the
training backward passes are still performed in full float precision.
Actual quantization happens when generating a TensorFlow Lite model.

If you are targeting an Arm Ethos-U55 NPU then the output
TensorFlow Lite file will also need to be passed through the Vela
compiler for further optimizations before it can be used.

For more information on using Vela, see:
    https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
For more information on quantization aware training see:
    https://www.tensorflow.org/model_optimization/guide/quantization/training
"""
import pathlib

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

from training_utils import get_data, create_model


def quantize_and_convert_to_tflite(keras_model):
    """Quantize and convert Keras model trained with QAT to TensorFlow Lite.

    The TensorFlow Lite model will have fp32 inputs/outputs and will handle quantizing/dequantizing internally.

    Args:
        keras_model: Keras model trained with quantization aware training.

    Returns:
        Quantized TensorFlow Lite model.
    """

    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)

    # After doing quantization aware training all the information for creating a fully quantized
    # TensorFlow Lite model is already within the quantization aware Keras model.
    # This means we only need to call convert with default optimizations to
    # generate the quantized TensorFlow Lite model.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
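
    # Optional, shown here only as a sketch: if your deployment flow needs integer
    # inputs and outputs as well (rather than fp32 in/out with internal quantization),
    # the standard converter settings below can be uncommented. They are not required
    # for this example.
    # converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    # converter.inference_input_type = tf.int8
    # converter.inference_output_type = tf.int8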
    tflite_model = converter.convert()

    return tflite_model


# pylint: disable=duplicate-code
def evaluate_tflite_model(tflite_save_path, x_test, y_test):
    """Calculate the accuracy of a TensorFlow Lite model using TensorFlow Lite interpreter.

    Args:
        tflite_save_path: Path to TensorFlow Lite model to test.
        x_test: numpy array of testing data.
        y_test: numpy array of testing labels (sparse categorical).
    """

    interpreter = tf.lite.Interpreter(model_path=str(tflite_save_path))

    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
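
    # The converter kept fp32 inputs/outputs (see quantize_and_convert_to_tflite),
    # so the float32 test images can be fed to the interpreter directly; quantization
    # and dequantization happen inside the model.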

    accuracy_count = 0
    num_test_images = len(y_test)

    for i in range(num_test_images):
        interpreter.set_tensor(input_details[0]['index'], x_test[i][np.newaxis, ...])
        interpreter.invoke()
        output_data = interpreter.get_tensor(output_details[0]['index'])

        if np.argmax(output_data) == y_test[i]:
            accuracy_count += 1

    print(f"Test accuracy quantized: {accuracy_count / num_test_images:.3f}")


def main():
    """
    Run quantization aware training
    """
    x_train, y_train, x_test, y_test = get_data()
    model = create_model()

    # When working with the TensorFlow Keras API and the TF Model Optimization Toolkit
    # we can make our model quantization aware in one line.
    # Once this is done we compile the model and train as normal.
    # It is important to note that the model is only quantization aware
    # and is not quantized yet.
    # The weights are still floating point and will only be converted
    # to int8 when we generate the TensorFlow Lite model later on.
    quant_aware_model = tfmot.quantization.keras.quantize_model(model)
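
    # As a side note: if only part of the network should be quantization aware, the
    # toolkit also offers quantize_annotate_layer/quantize_apply. A rough sketch
    # (not used in this example; the Dense-layer filter is purely illustrative):
    # annotate = tfmot.quantization.keras.quantize_annotate_layer
    # annotated_model = tf.keras.models.clone_model(
    #     model,
    #     clone_function=lambda layer: annotate(layer)
    #     if isinstance(layer, tf.keras.layers.Dense) else layer,
    # )
    # quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)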

    quant_aware_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                              loss=tf.keras.losses.sparse_categorical_crossentropy,
                              metrics=['accuracy'])

    quant_aware_model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True)

    # Test the quantization aware model accuracy.
    test_loss, test_acc = quant_aware_model.evaluate(x_test, y_test)  # pylint: disable=unused-variable
    print(f"Test accuracy quant aware: {test_acc:.3f}")

    # Quantize and save the resulting TensorFlow Lite model to file.
    tflite_model = quantize_and_convert_to_tflite(quant_aware_model)

    tflite_models_dir = pathlib.Path('./conditioned_models/')
    tflite_models_dir.mkdir(exist_ok=True, parents=True)

    quant_model_save_path = tflite_models_dir / 'qat_quant_model.tflite'
    with open(quant_model_save_path, 'wb') as f:
        f.write(tflite_model)

    # Test quantized model accuracy. Save time by only testing a subset of the whole data.
    num_test_samples = 1000
    evaluate_tflite_model(
        quant_model_save_path,
        x_test[0:num_test_samples],
        y_test[0:num_test_samples]
    )
# pylint: enable=duplicate-code


if __name__ == "__main__":
    main()