From 3c79893217bc632c9b0efa815091bef3c779490c Mon Sep 17 00:00:00 2001 From: alexander Date: Fri, 26 Mar 2021 21:42:19 +0000 Subject: Opensource ML embedded evaluation kit Change-Id: I12e807f19f5cacad7cef82572b6dd48252fd61fd --- model_conditioning_examples/Readme.md | 173 +++++++++++++++++++++ .../post_training_quantization.py | 139 +++++++++++++++++ .../quantization_aware_training.py | 139 +++++++++++++++++ model_conditioning_examples/requirements.txt | 3 + model_conditioning_examples/setup.sh | 21 +++ model_conditioning_examples/training_utils.py | 61 ++++++++ model_conditioning_examples/weight_clustering.py | 107 +++++++++++++ model_conditioning_examples/weight_pruning.py | 106 +++++++++++++ 8 files changed, 749 insertions(+) create mode 100644 model_conditioning_examples/Readme.md create mode 100644 model_conditioning_examples/post_training_quantization.py create mode 100644 model_conditioning_examples/quantization_aware_training.py create mode 100644 model_conditioning_examples/requirements.txt create mode 100644 model_conditioning_examples/setup.sh create mode 100644 model_conditioning_examples/training_utils.py create mode 100644 model_conditioning_examples/weight_clustering.py create mode 100644 model_conditioning_examples/weight_pruning.py (limited to 'model_conditioning_examples') diff --git a/model_conditioning_examples/Readme.md b/model_conditioning_examples/Readme.md new file mode 100644 index 0000000..ede2c24 --- /dev/null +++ b/model_conditioning_examples/Readme.md @@ -0,0 +1,173 @@ +# Model conditioning examples + +- [Introduction](#introduction) + - [How to run](#how-to-run) +- [Quantization](#quantization) + - [Post-training quantization](#post-training-quantization) + - [Quantization aware training](#quantization-aware-training) +- [Weight pruning](#weight-pruning) +- [Weight clustering](#weight-clustering) +- [References](#references) + +## Introduction + +This folder contains short example scripts that demonstrate some methods available in TensorFlow to condition your model +in preparation for deployment on Arm Ethos NPU. + +These scripts will cover three main topics: + +- Quantization +- Weight clustering +- Weight pruning + +The objective of these scripts is not to be a single source of knowledge on everything related to model conditioning. +Instead the aim is to provide the reader with a quick starting point that demonstrates some commonly used tools that +will enable models to run on Arm Ethos NPU and also optimize them to enable maximum performance from the Arm Ethos NPU. + +Links to more in-depth guides available on the TensorFlow website are provided in the [references](#references) section +in this Readme. + +### How to run + +From the `model_conditioning_examples` folder run the following command: + +```commandline +./setup.sh +``` + +This will create a Python virtual environment and install the required versions of TensorFlow and TensorFlow model +optimization toolkit to run the examples scripts. + +If the virtual environment has not been activated you can do so by running: + +```commandline +source ./env/bin/activate +``` + +You can then run the examples from the command line. For example to run the post-training quantization example: + +```commandline +python ./post_training_quantization.py +``` + +The produced TensorFlow Lite model files will be saved in a `conditioned_models` sub-folder. + +## Quantization + +Most machine learning models are trained using 32bit floating point precision. 
+However, Arm Ethos NPU performs calculations in 8bit integer precision. As a result, any model you wish to deploy on
+Arm Ethos NPU must first be fully quantized to 8 bits.
+
+TensorFlow provides two methods of quantization and the scripts in this folder demonstrate both:
+
+- [Post-training quantization](./post_training_quantization.py)
+- [Quantization aware training](./quantization_aware_training.py)
+
+Both of these techniques quantize not only the weights of the model but also the variable tensors, such as the model
+input and output and the activations of each intermediate layer.
+
+For details on the quantization specification used by TensorFlow, please see
+[here](https://www.tensorflow.org/lite/performance/quantization_spec).
+
+In both methods, scale and zero point values are chosen so that the floating point weights are maximally represented
+in the reduced precision. Quantization is performed per-axis, meaning a different scale and zero point is used for
+each channel of a layer's weights.
+
+### Post-training quantization
+
+The first of the quantization methods covered is post-training quantization. As the name suggests, this form of
+quantization takes place after training of your model is complete. It is also the simpler of the two methods shown
+here.
+
+In post-training quantization, the weights of the model are first quantized to 8bit integer values. After this the
+variable tensors, such as layer activations, are quantized. To do this we need to calculate the potential range of
+values that all these tensors can take.
+
+Calculating these ranges requires a small dataset that is representative of what you expect your model to see when
+it is deployed. Model inference is then performed using this representative dataset and the resulting minimum and
+maximum values of the variable tensors are recorded.
+
+Only a small number of samples is needed for this calibration dataset (around 100-500 should be enough), and they can
+be taken from the training or validation sets.
+
+Quantizing a model can cause a drop in accuracy, but for many use cases the drop from post-training quantization is
+minimal. After post-training quantization is complete you will have a fully quantized TensorFlow Lite model.
+
+If you are targeting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through
+the Vela compiler for further optimizations before it can be used.
+
+### Quantization aware training
+
+Depending on the model, post-training quantization can result in an accuracy drop that is too large to be acceptable.
+This is where quantization aware training can help. Quantization aware training simulates the quantization of weights
+and activations during the inference stage of training using fake quantization nodes.
+
+By simulating quantization during training, the model weights are adjusted in the backward pass so that they are
+better suited to the reduced precision of quantization. It is this simulation of quantization and adjustment of
+weights that minimizes the accuracy loss incurred when quantizing. Note that quantization is only simulated at this
+stage and the backward passes of training are still performed in full floating point precision.
+
+Importantly, with quantization aware training you do not have to train your model from scratch. Instead, you can
+train it normally (not quantization aware) and, after training is complete, fine-tune it using quantization aware
+training. By only fine-tuning you can save a lot of training time.
+
+As well as simulating quantization and adjusting weights, the ranges of variable tensors are captured so that the
+model can be fully quantized afterwards. Once quantization aware training is finished, the TensorFlow Lite converter
+is used to produce a fully quantized TensorFlow Lite model.
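+The snippet below is a condensed sketch of this workflow. It reuses the `get_data` and `create_model` helpers from
+[training_utils.py](./training_utils.py) and the same TensorFlow Model Optimization Toolkit calls as the full
+[quantization aware training script](./quantization_aware_training.py) in this folder; the batch size and the single
+fine-tuning epoch are illustrative values only.
+
+```python
+import tensorflow as tf
+import tensorflow_model_optimization as tfmot
+
+from training_utils import get_data, create_model
+
+x_train, y_train, x_test, y_test = get_data()
+model = create_model()
+# Normally you would compile and train `model` in fp32 here first, then fine-tune it below.
+
+# Wrap the model so that quantization of weights and activations is simulated during training.
+quant_aware_model = tfmot.quantization.keras.quantize_model(model)
+
+# The wrapped model must be recompiled before training or fine-tuning it.
+quant_aware_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
+                          loss=tf.keras.losses.sparse_categorical_crossentropy,
+                          metrics=['accuracy'])
+quant_aware_model.fit(x=x_train, y=y_train, batch_size=128, epochs=1)
+
+# The quantization aware model already carries the range information it needs, so the
+# default optimizations are enough to produce a fully quantized TensorFlow Lite model.
+converter = tf.lite.TFLiteConverter.from_keras_model(quant_aware_model)
+converter.optimizations = [tf.lite.Optimize.DEFAULT]
+tflite_model = converter.convert()
+```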
+If you are targeting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through
+the Vela compiler for further optimizations before it can be used.
+
+## Weight pruning
+
+After you have trained your deep learning model, it is common to find that many of its weights have the value 0 and
+that many more are very close to 0. These weights have very little effect on network calculations, so they are safe
+to remove or 'prune' from the model. This is accomplished by setting all of these weight values to 0, resulting in a
+sparse model.
+
+Compression algorithms can then take advantage of this to reduce model size in memory, which can be very important
+when deploying on small embedded systems. Moreover, Arm Ethos NPU can take advantage of model sparsity to further
+accelerate execution of a model.
+
+Training with weight pruning forces your model to have a certain percentage of its weights set (or 'pruned') to 0
+during the training phase. This is done by forcing the weights that are closest to 0 to become exactly 0. Doing this
+during training guarantees your model will reach the chosen level of sparsity, and the remaining weights can adapt to
+that sparsity level, so accuracy loss should be minimized even when a large pruning percentage is desired.
+
+Weight pruning can be combined with quantization so you have a model that is both pruned and quantized, meaning that
+the memory-saving effects of both can be combined. Quantization then allows the model to be used with Arm Ethos NPU.
+
+If you are targeting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through
+the Vela compiler for further optimizations before it can be used.
+
+## Weight clustering
+
+Another method of model conditioning is weight clustering (also called weight sharing). With this technique, a fixed
+number of values (cluster centers) is used in each layer of a model to represent all the possible values that the
+layer's weights take. The weights in a layer then use the value of their closest cluster center. By restricting the
+number of distinct weight values in this way, weight clustering reduces the amount of memory needed to store all the
+weight values in a model.
+
+Depending on the model and the number of clusters chosen, this technique can have a negative effect on accuracy. To
+reduce the impact on accuracy you can introduce clustering during training, so the model's weights can be better
+adjusted to the reduced precision.
+
+Weight clustering can be combined with quantization so you have a model that is both clustered and quantized, meaning
+that the memory-saving effects of both can be combined. Quantization then allows the model to be used with
+Arm Ethos NPU.
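+The snippet below is a condensed sketch of this clustering-then-quantization flow. It reuses the helpers from
+[training_utils.py](./training_utils.py), the `post_training_quantize` function from
+[post_training_quantization.py](./post_training_quantization.py) and the same Model Optimization Toolkit calls as the
+full [weight clustering script](./weight_clustering.py) in this folder; the cluster count and the single fine-tuning
+epoch are illustrative values only.
+
+```python
+import tensorflow as tf
+import tensorflow_model_optimization as tfmot
+
+from training_utils import get_data, create_model
+from post_training_quantization import post_training_quantize
+
+x_train, y_train, x_test, y_test = get_data()
+model = create_model()
+# Normally you would compile and train `model` in fp32 here first, then fine-tune it below.
+
+# Wrap the model so that every layer's weights are clustered into 16 clusters during fine-tuning.
+clustered_model = tfmot.clustering.keras.cluster_weights(
+    model,
+    number_of_clusters=16,
+    cluster_centroids_init=tfmot.clustering.keras.CentroidInitialization.LINEAR)
+
+# The wrapped model must be recompiled before fine-tuning with clustering applied.
+clustered_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
+                        loss=tf.keras.losses.sparse_categorical_crossentropy,
+                        metrics=['accuracy'])
+clustered_model.fit(x=x_train, y=y_train, batch_size=128, epochs=1)
+
+# Strip the variables that clustering only needed during training, then apply post-training
+# quantization on top to produce the final fully quantized TensorFlow Lite model.
+model_for_export = tfmot.clustering.keras.strip_clustering(clustered_model)
+tflite_model = post_training_quantize(model_for_export, x_train)
+```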
+If you are targeting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through
+the Vela compiler for further optimizations before it can be used (see
+[Optimize model with Vela compiler](./building.md#optimize-custom-model-with-vela-compiler)).
+
+## References
+
+- [TensorFlow Model Optimization Toolkit](https://www.tensorflow.org/model_optimization)
+- [Post-training quantization](https://www.tensorflow.org/lite/performance/post_training_integer_quant)
+- [Quantization aware training](https://www.tensorflow.org/model_optimization/guide/quantization/training)
+- [Weight pruning](https://www.tensorflow.org/model_optimization/guide/pruning)
+- [Weight clustering](https://www.tensorflow.org/model_optimization/guide/clustering)
+- [Vela](https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/)
diff --git a/model_conditioning_examples/post_training_quantization.py b/model_conditioning_examples/post_training_quantization.py
new file mode 100644
index 0000000..ab535ac
--- /dev/null
+++ b/model_conditioning_examples/post_training_quantization.py
@@ -0,0 +1,139 @@
+# Copyright (c) 2021 Arm Limited. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script will provide you with an example of how to perform post-training quantization in TensorFlow.
+
+The output from this example will be a TensorFlow Lite model file where weights and activations are quantized to 8bit
+integer values.
+
+Quantization helps reduce the size of your models and is necessary for running models on certain hardware such as Arm
+Ethos NPU.
+
+In addition to quantizing weights, post-training quantization uses a calibration dataset to
+capture the minimum and maximum values of all variable tensors in your model.
+By capturing these ranges it is possible to fully quantize not just the weights of the model but also the activations.
+
+Depending on the model you are quantizing, there may be some accuracy loss, but for a lot of models the loss should
+be minimal.
+
+If you are targeting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through
+the Vela compiler for further optimizations before it can be used.
+
+For more information on using Vela see: https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
+For more information on post-training quantization
+see: https://www.tensorflow.org/lite/performance/post_training_integer_quant
+"""
+import pathlib
+
+import numpy as np
+import tensorflow as tf
+
+from training_utils import get_data, create_model
+
+
+def post_training_quantize(keras_model, sample_data):
+    """Quantize Keras model using post-training quantization with some sample data.
+
+    TensorFlow Lite will have fp32 inputs/outputs and the model will handle quantizing/dequantizing.
+
+    Args:
+        keras_model: Keras model to quantize.
+        sample_data: A numpy array of data to use as a representative dataset.
+
+    Returns:
+        Quantized TensorFlow Lite model.
+ """ + + converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) + + # We set the following converter options to ensure our model is fully quantized. + # An error should get thrown if there is any ops that can't be quantized. + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + + # To use post training quantization we must provide some sample data that will be used to + # calculate activation ranges for quantization. This data should be representative of the data + # we expect to feed the model and must be provided by a generator function. + def generate_repr_dataset(): + for i in range(100): # 100 samples is all we should need in this example. + yield [np.expand_dims(sample_data[i], axis=0)] + + converter.representative_dataset = generate_repr_dataset + tflite_model = converter.convert() + + return tflite_model + + +def evaluate_tflite_model(tflite_save_path, x_test, y_test): + """Calculate the accuracy of a TensorFlow Lite model using TensorFlow Lite interpreter. + + Args: + tflite_save_path: Path to TensorFlow Lite model to test. + x_test: numpy array of testing data. + y_test: numpy array of testing labels (sparse categorical). + """ + + interpreter = tf.lite.Interpreter(model_path=str(tflite_save_path)) + + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + accuracy_count = 0 + num_test_images = len(y_test) + + for i in range(num_test_images): + interpreter.set_tensor(input_details[0]['index'], x_test[i][np.newaxis, ...]) + interpreter.invoke() + output_data = interpreter.get_tensor(output_details[0]['index']) + + if np.argmax(output_data) == y_test[i]: + accuracy_count += 1 + + print(f"Test accuracy quantized: {accuracy_count / num_test_images:.3f}") + + +def main(): + x_train, y_train, x_test, y_test = get_data() + model = create_model() + + # Compile and train the model in fp32 as normal. + model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), + loss=tf.keras.losses.sparse_categorical_crossentropy, + metrics=['accuracy']) + + model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True) + + # Test the fp32 model accuracy. + test_loss, test_acc = model.evaluate(x_test, y_test) + print(f"Test accuracy float: {test_acc:.3f}") + + # Quantize and export the resulting TensorFlow Lite model to file. + tflite_model = post_training_quantize(model, x_train) + + tflite_models_dir = pathlib.Path('./conditioned_models/') + tflite_models_dir.mkdir(exist_ok=True, parents=True) + + quant_model_save_path = tflite_models_dir / 'post_training_quant_model.tflite' + with open(quant_model_save_path, 'wb') as f: + f.write(tflite_model) + + # Test the quantized model accuracy. Save time by only testing a subset of the whole data. + num_test_samples = 1000 + evaluate_tflite_model(quant_model_save_path, x_test[0:num_test_samples], y_test[0:num_test_samples]) + + +if __name__ == "__main__": + main() diff --git a/model_conditioning_examples/quantization_aware_training.py b/model_conditioning_examples/quantization_aware_training.py new file mode 100644 index 0000000..acb768c --- /dev/null +++ b/model_conditioning_examples/quantization_aware_training.py @@ -0,0 +1,139 @@ +# Copyright (c) 2021 Arm Limited. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script will provide you with a short example of how to perform quantization aware training in TensorFlow using the +TensorFlow Model Optimization Toolkit. + +The output from this example will be a TensorFlow Lite model file where weights and activations are quantized to 8bit +integer values. + +Quantization helps reduce the size of your models and is necessary for running models on certain hardware such as Arm +Ethos NPU. + +In quantization aware training (QAT), the error introduced with quantizing from fp32 to int8 is simulated using +fake quantization nodes. By simulating this quantization error when training, the model can learn better adapted +weights and minimize accuracy losses caused by the reduced precision. + +Minimum and maximum values for activations are also captured during training so activations for every layer can be +quantized along with the weights later. + +Quantization is only simulated during training and the training backward passes are still performed in full float +precision. Actual quantization happens when generating a TensorFlow Lite model. + +If you are targetting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela +compiler for further optimizations before it can be used. + +For more information on using vela see: https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/ +For more information on quantization aware training +see: https://www.tensorflow.org/model_optimization/guide/quantization/training +""" +import pathlib + +import numpy as np +import tensorflow as tf +import tensorflow_model_optimization as tfmot + +from training_utils import get_data, create_model + + +def quantize_and_convert_to_tflite(keras_model): + """Quantize and convert Keras model trained with QAT to TensorFlow Lite. + + TensorFlow Lite will have fp32 inputs/outputs and the model will handle quantizing/dequantizing. + + Args: + keras_model: Keras model trained with quantization aware training. + + Returns: + Quantized TensorFlow Lite model. + """ + + converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) + + # After doing quantization aware training all the information for creating a fully quantized + # TensorFlow Lite model is already within the quantization aware Keras model. + # This means we only need to call convert with default optimizations to generate the quantized TensorFlow Lite model. + converter.optimizations = [tf.lite.Optimize.DEFAULT] + tflite_model = converter.convert() + + return tflite_model + + +def evaluate_tflite_model(tflite_save_path, x_test, y_test): + """Calculate the accuracy of a TensorFlow Lite model using TensorFlow Lite interpreter. + + Args: + tflite_save_path: Path to TensorFlow Lite model to test. + x_test: numpy array of testing data. + y_test: numpy array of testing labels (sparse categorical). 
+ """ + + interpreter = tf.lite.Interpreter(model_path=str(tflite_save_path)) + + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + accuracy_count = 0 + num_test_images = len(y_test) + + for i in range(num_test_images): + interpreter.set_tensor(input_details[0]['index'], x_test[i][np.newaxis, ...]) + interpreter.invoke() + output_data = interpreter.get_tensor(output_details[0]['index']) + + if np.argmax(output_data) == y_test[i]: + accuracy_count += 1 + + print(f"Test accuracy quantized: {accuracy_count / num_test_images:.3f}") + + +def main(): + x_train, y_train, x_test, y_test = get_data() + model = create_model() + + # When working with the TensorFlow Keras API and the TF Model Optimization Toolkit we can make our + # model quantization aware in one line. Once this is done we compile the model and train as normal. + # It is important to note that the model is only quantization aware and is not quantized yet. The weights are + # still floating point and will only be converted to int8 when we generate the TensorFlow Lite model later on. + quant_aware_model = tfmot.quantization.keras.quantize_model(model) + + quant_aware_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), + loss=tf.keras.losses.sparse_categorical_crossentropy, + metrics=['accuracy']) + + quant_aware_model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True) + + # Test the quantization aware model accuracy. + test_loss, test_acc = quant_aware_model.evaluate(x_test, y_test) + print(f"Test accuracy quant aware: {test_acc:.3f}") + + # Quantize and save the resulting TensorFlow Lite model to file. + tflite_model = quantize_and_convert_to_tflite(quant_aware_model) + + tflite_models_dir = pathlib.Path('./conditioned_models/') + tflite_models_dir.mkdir(exist_ok=True, parents=True) + + quant_model_save_path = tflite_models_dir / 'qat_quant_model.tflite' + with open(quant_model_save_path, 'wb') as f: + f.write(tflite_model) + + # Test quantized model accuracy. Save time by only testing a subset of the whole data. + num_test_samples = 1000 + evaluate_tflite_model(quant_model_save_path, x_test[0:num_test_samples], y_test[0:num_test_samples]) + + +if __name__ == "__main__": + main() diff --git a/model_conditioning_examples/requirements.txt b/model_conditioning_examples/requirements.txt new file mode 100644 index 0000000..96e15a3 --- /dev/null +++ b/model_conditioning_examples/requirements.txt @@ -0,0 +1,3 @@ +tensorflow==2.4.0 +tensorflow-model-optimization==0.5.0 +numpy==1.19.5 \ No newline at end of file diff --git a/model_conditioning_examples/setup.sh b/model_conditioning_examples/setup.sh new file mode 100644 index 0000000..f552662 --- /dev/null +++ b/model_conditioning_examples/setup.sh @@ -0,0 +1,21 @@ +#---------------------------------------------------------------------------- +# Copyright (c) 2021 Arm Limited. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +#---------------------------------------------------------------------------- +#!/bin/bash +python3 -m venv ./env +source ./env/bin/activate +pip install -U pip +pip install -r requirements.txt \ No newline at end of file diff --git a/model_conditioning_examples/training_utils.py b/model_conditioning_examples/training_utils.py new file mode 100644 index 0000000..3467b2a --- /dev/null +++ b/model_conditioning_examples/training_utils.py @@ -0,0 +1,61 @@ +# Copyright (c) 2021 Arm Limited. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Utility functions related to data and models that are common to all the model conditioning examples. +""" +import tensorflow as tf +import numpy as np + + +def get_data(): + """Downloads and returns the pre-processed data and labels for training and testing. + + Returns: + Tuple of: (train data, train labels, test data, test labels) + """ + + # To save time we use the MNIST dataset for this example. + (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() + + # Convolution operations require data to have 4 dimensions. + # We divide by 255 to help training and cast to float32 for TensorFlow. + x_train = (x_train[..., np.newaxis] / 255.0).astype(np.float32) + x_test = (x_test[..., np.newaxis] / 255.0).astype(np.float32) + + return x_train, y_train, x_test, y_test + + +def create_model(): + """Create and returns a simple Keras model for training MNIST. + + We will use a simple convolutional neural network for this example, + but the model optimization methods employed should be compatible with a + wide variety of CNN architectures such as Mobilenet and Inception etc. + + Returns: + Uncompiled Keras model. + """ + + keras_model = tf.keras.models.Sequential([ + tf.keras.layers.Conv2D(32, 3, padding='same', input_shape=(28, 28, 1), activation=tf.nn.relu), + tf.keras.layers.Conv2D(32, 3, padding='same', activation=tf.nn.relu), + tf.keras.layers.MaxPool2D(), + tf.keras.layers.Conv2D(32, 3, padding='same', activation=tf.nn.relu), + tf.keras.layers.MaxPool2D(), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(units=10, activation=tf.nn.softmax) + ]) + + return keras_model diff --git a/model_conditioning_examples/weight_clustering.py b/model_conditioning_examples/weight_clustering.py new file mode 100644 index 0000000..54f241c --- /dev/null +++ b/model_conditioning_examples/weight_clustering.py @@ -0,0 +1,107 @@ +# Copyright (c) 2021 Arm Limited. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script will provide you with a short example of how to perform clustering of weights (weight sharing) in +TensorFlow using the TensorFlow Model Optimization Toolkit. + +The output from this example will be a TensorFlow Lite model file where weights in each layer have been 'clustered' into +16 clusters during training - quantization has then been applied on top of this. + +By clustering the model we can improve compression of the model file. This can be essential for deploying certain +models on systems with limited resources - such as embedded systems using an Arm Ethos NPU. + +After performing clustering we do post-training quantization to quantize the model and then generate a TensorFlow Lite file. + +If you are targetting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela +compiler for further optimizations before it can be used. + +For more information on using Vela see: https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/ +For more information on clustering see: https://www.tensorflow.org/model_optimization/guide/clustering +""" +import pathlib + +import tensorflow as tf +import tensorflow_model_optimization as tfmot + +from training_utils import get_data, create_model +from post_training_quantization import post_training_quantize, evaluate_tflite_model + + +def prepare_for_clustering(keras_model): + """Prepares a Keras model for clustering.""" + + # Choose the number of clusters to use and how to initialize them. Using more clusters will generally + # reduce accuracy so you will need to find the optimal number for your use-case. + number_of_clusters = 16 + cluster_centroids_init = tfmot.clustering.keras.CentroidInitialization.LINEAR + + # Apply the clustering wrapper to the whole model so weights in every layer will get clustered. You may find that + # to avoid too much accuracy loss only certain non-critical layers in your model should be clustered. + clustering_ready_model = tfmot.clustering.keras.cluster_weights(keras_model, + number_of_clusters=number_of_clusters, + cluster_centroids_init=cluster_centroids_init) + + # We must recompile the model after making it ready for clustering. + clustering_ready_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), + loss=tf.keras.losses.sparse_categorical_crossentropy, + metrics=['accuracy']) + + return clustering_ready_model + + +def main(): + x_train, y_train, x_test, y_test = get_data() + model = create_model() + + # Compile and train the model first. + # In general it is easier to do clustering as a fine-tuning step after the model is fully trained. + model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), + loss=tf.keras.losses.sparse_categorical_crossentropy, + metrics=['accuracy']) + + model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True) + + # Test the trained model accuracy. + test_loss, test_acc = model.evaluate(x_test, y_test) + print(f"Test accuracy before clustering: {test_acc:.3f}") + + # Prepare the model for clustering. 
+ clustered_model = prepare_for_clustering(model) + + # Continue training the model but now with clustering applied. + clustered_model.fit(x=x_train, y=y_train, batch_size=128, epochs=1, verbose=1, shuffle=True) + test_loss, test_acc = clustered_model.evaluate(x_test, y_test) + print(f"Test accuracy after clustering: {test_acc:.3f}") + + # Remove all variables that clustering only needed in the training phase. + model_for_export = tfmot.clustering.keras.strip_clustering(clustered_model) + + # Apply post-training quantization on top of the clustering and save the resulting TensorFlow Lite model to file. + tflite_model = post_training_quantize(model_for_export, x_train) + + tflite_models_dir = pathlib.Path('./conditioned_models/') + tflite_models_dir.mkdir(exist_ok=True, parents=True) + + clustered_quant_model_save_path = tflite_models_dir / 'clustered_post_training_quant_model.tflite' + with open(clustered_quant_model_save_path, 'wb') as f: + f.write(tflite_model) + + # Test the clustered quantized model accuracy. Save time by only testing a subset of the whole data. + num_test_samples = 1000 + evaluate_tflite_model(clustered_quant_model_save_path, x_test[0:num_test_samples], y_test[0:num_test_samples]) + + +if __name__ == "__main__": + main() diff --git a/model_conditioning_examples/weight_pruning.py b/model_conditioning_examples/weight_pruning.py new file mode 100644 index 0000000..bf26f1f --- /dev/null +++ b/model_conditioning_examples/weight_pruning.py @@ -0,0 +1,106 @@ +# Copyright (c) 2021 Arm Limited. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script will provide you with a short example of how to perform magnitude-based weight pruning in TensorFlow +using the TensorFlow Model Optimization Toolkit. + +The output from this example will be a TensorFlow Lite model file where ~75% percent of the weights have been 'pruned' to the +value 0 during training - quantization has then been applied on top of this. + +By pruning the model we can improve compression of the model file. This can be essential for deploying certain models +on systems with limited resources - such as embedded systems using Arm Ethos NPU. Also, if the pruned model is run +on an Arm Ethos NPU then this pruning can improve the execution time of the model. + +After pruning is complete we do post-training quantization to quantize the model and then generate a TensorFlow Lite file. + +If you are targetting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela +compiler for further optimizations before it can be used. 
+ +For more information on using Vela see: https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/ +For more information on weight pruning see: https://www.tensorflow.org/model_optimization/guide/pruning +""" +import pathlib + +import tensorflow as tf +import tensorflow_model_optimization as tfmot + +from training_utils import get_data, create_model +from post_training_quantization import post_training_quantize, evaluate_tflite_model + + +def prepare_for_pruning(keras_model): + """Prepares a Keras model for pruning.""" + + # We use a constant sparsity schedule so the amount of sparsity in the model is kept at the same percent throughout + # training. An alternative is PolynomialDecay where sparsity can be gradually increased during training. + pruning_schedule = tfmot.sparsity.keras.ConstantSparsity(target_sparsity=0.75, begin_step=0) + + # Apply the pruning wrapper to the whole model so weights in every layer will get pruned. You may find that to avoid + # too much accuracy loss only certain non-critical layers in your model should be pruned. + pruning_ready_model = tfmot.sparsity.keras.prune_low_magnitude(keras_model, pruning_schedule=pruning_schedule) + + # We must recompile the model after making it ready for pruning. + pruning_ready_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), + loss=tf.keras.losses.sparse_categorical_crossentropy, + metrics=['accuracy']) + + return pruning_ready_model + + +def main(): + x_train, y_train, x_test, y_test = get_data() + model = create_model() + + # Compile and train the model first. + # In general it is easier to do pruning as a fine-tuning step after the model is fully trained. + model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), + loss=tf.keras.losses.sparse_categorical_crossentropy, + metrics=['accuracy']) + + model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True) + + # Test the trained model accuracy. + test_loss, test_acc = model.evaluate(x_test, y_test) + print(f"Test accuracy before pruning: {test_acc:.3f}") + + # Prepare the model for pruning and add the pruning update callback needed in training. + pruned_model = prepare_for_pruning(model) + callbacks = [tfmot.sparsity.keras.UpdatePruningStep()] + + # Continue training the model but now with pruning applied - remember to pass in the callbacks! + pruned_model.fit(x=x_train, y=y_train, batch_size=128, epochs=1, verbose=1, shuffle=True, callbacks=callbacks) + test_loss, test_acc = pruned_model.evaluate(x_test, y_test) + print(f"Test accuracy after pruning: {test_acc:.3f}") + + # Remove all variables that pruning only needed in the training phase. + model_for_export = tfmot.sparsity.keras.strip_pruning(pruned_model) + + # Apply post-training quantization on top of the pruning and save the resulting TensorFlow Lite model to file. + tflite_model = post_training_quantize(model_for_export, x_train) + + tflite_models_dir = pathlib.Path('./conditioned_models/') + tflite_models_dir.mkdir(exist_ok=True, parents=True) + + pruned_quant_model_save_path = tflite_models_dir / 'pruned_post_training_quant_model.tflite' + with open(pruned_quant_model_save_path, 'wb') as f: + f.write(tflite_model) + + # Test the pruned quantized model accuracy. Save time by only testing a subset of the whole data. + num_test_samples = 1000 + evaluate_tflite_model(pruned_quant_model_save_path, x_test[0:num_test_samples], y_test[0:num_test_samples]) + + +if __name__ == "__main__": + main() -- cgit v1.2.1