path: root/src/mlia/nn/rewrite/core/rewrite.py
Age | Commit message | Author
4 days | feat: Enable user parameters for activation functions in conv2d rewrites | Nathan Bailey
    Allow the user to specify an activation function for conv2d rewrites
    Enable automatic detection of most common activation function in rewrite in the case that the user does not specify one
    Resolves: MLIA-1163
    Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
    Change-Id: Icbf6f4c6f8eaba6d78b88bdf62448f1d30aed1ae
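The fallback described in this commit (use the user's activation if given, otherwise pick the most common one) can be sketched as follows. This is a minimal illustration, not the code from the commit: the function names `most_common_activation` and `resolve_activation`, the string representation of activations, and the `"relu"` default are all assumptions.

```python
from collections import Counter
from typing import Iterable, Optional


def most_common_activation(activations: Iterable[str], default: str = "relu") -> str:
    """Return the activation name that occurs most often (hypothetical helper)."""
    counts = Counter(name for name in activations if name)
    if not counts:
        # Nothing to detect from; fall back to an assumed default.
        return default
    return counts.most_common(1)[0][0]


def resolve_activation(user_choice: Optional[str], activations: Iterable[str]) -> str:
    """Prefer the user's explicit choice, otherwise detect one automatically."""
    if user_choice is not None:
        return user_choice
    return most_common_activation(activations)
```

An explicit user parameter always wins in this sketch; detection only runs when nothing was specified.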
6 days | feat: Enable rewrite parameterisation for specific rewrites | Nathan Bailey
    Adds support for rewrite-specific parameters
    Resolves: MLIA-1114
    Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
    Change-Id: I290c326af3356033a916a43b28027819c876c3dd
13 days | fix: Extend docstrings in the rewrite module | Nathan Bailey
    Rework docstrings in rewrite functions based on recent changes
    Resolves MLIA-944
    Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
    Change-Id: I31a37e17a296f8a16d0db408d48c6de65c05300e
13 days | feat: Implement the conv2D rewrites for int8 and fp32 models | Nathan Bailey
    Enable clustering and fully connected rewrites for conv2D layers.
    Resolves: MLIA-1159 and MLIA-1160
    Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
    Change-Id: I640b8a7e79e455b12fb68d02ac1c33213b8de9c6
13 days | feat: CLI and API changes for the conv2d rewrites | Nathan Bailey
    Implements CLI and API changes for the new conv2d rewrite targets
    Resolves: MLIA-1157
    Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
    Change-Id: I03c7a3a536d2f0a805b4689a9d96b95f8b4ab86c
2024-04-16 | feat: Implement the clustering rewrite for int8 | Nathan Bailey
    Implements a clustering rewrite for fully connected layers for int8 models
    Resolves: MLIA-1080
    Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
    Change-Id: If48efb22764187a382e5b84bbb5c3b75a6e71b75
2024-04-16 | feat: Implement the clustering rewrite for fp32 | Nathan Bailey
    Implements a clustering rewrite for fully connected layers for fp32 models
    Resolves: MLIA-1079
    Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
    Change-Id: I4c12f0bf911219b4066f0760976e424ebe900a0b
2024-04-16 | feat: CLI and API changes for the clustering rewrite | Nathan Bailey
    Adds API changes for a fully-connected-clustering rewrite
    Resolves: MLIA-1077
    Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
    Change-Id: I845796a391c5020e66472456b97ecad5ee8139a8
2024-04-04 | feat: Implement int8 sparsity 2:4 rewrite | Madeleine Dunn
    - Implement pruning-preserving quantisation aware training
    - Rework the training logic to avoid duplication
    - Remove the DynamicallyLoadedRewrite class as it is now unused
    Resolves: MLIA-1003
    Signed-off-by: Madeleine Dunn <madeleine.dunn@arm.com>
    Change-Id: Ia7a4acf5f477a27963cffa88180cca085b32ffe4
2024-04-03 | feat: Implement fp32 sparsity 2:4 rewrite | Madeleine Dunn
    - Update the existing placeholder with code to prune the given model
    Resolves: MLIA-1002
    Signed-off-by: Madeleine Dunn <madeleine.dunn@arm.com>
    Change-Id: I76b0e0bfe81be5e57d518cd7bb588eef76a11641
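For readers unfamiliar with the 2:4 pattern named in this commit: in every group of four consecutive weights, at most two are non-zero. The sketch below shows magnitude-based 2:4 pruning on a flat list of weights. It only illustrates the pattern and is not the pruning code from the commit; the function name `prune_2_of_4` and the choice to keep the two largest magnitudes are assumptions.

```python
from typing import List


def prune_2_of_4(weights: List[float]) -> List[float]:
    """Zero the two smallest-magnitude values in each group of four (illustrative)."""
    if len(weights) % 4 != 0:
        raise ValueError("number of weights must be a multiple of 4")
    pruned: List[float] = []
    for start in range(0, len(weights), 4):
        group = weights[start:start + 4]
        # Indices of the two largest magnitudes within this group.
        keep = sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2]
        pruned.extend(w if i in keep else 0.0 for i, w in enumerate(group))
    return pruned
```

In a real model the groups would be taken along a specific weight axis; the flat list here keeps the example short.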
2024-04-03 | feat: Add placeholder rewrite_target for sparsity24 | Madeleine Dunn
    - The placeholder currently duplicates the existing fc target
    Resolves: MLIA-1000
    Signed-off-by: Madeleine Dunn <madeleine.dunn@arm.com>
    Change-Id: I0df5d47e61dafa567e212566bbcb0f1639fe7642
2024-04-03 | feat: Implement sparsity | Madeleine Dunn
    - Add a placeholder file and registry option for sparsity
    Resolves: MLIA-999
    Signed-off-by: Madeleine Dunn <madeleine.dunn@arm.com>
    Change-Id: I273192ba6813309f5226e3d8e0b686ce87ee6b79
2024-03-28 | fix: Check that training checkpoint feature works as expected | Nathan Bailey
    Fixes the checkpoint feature in training and also completes unit tests for it
    Resolves: MLIA-1111
    Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
    Change-Id: Ic2b84b4b045db5ba3cb299fcd137ae9d31df5298
2024-03-28 | feat: Update Vela version | Nathan Bailey
    Updates Vela Version to 3.11.0 and TensorFlow version to 2.15.1
    Required keras import to change: from keras.api._v2 import keras needed instead of calling tf.keras
    Subsequently tf.keras.X needed to change to keras.X
    Resolves: MLIA-1107
    Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
    Change-Id: I53bcaa9cdad58b0e6c311c8c6490393d33cb18bc
2024-03-27 | fix: Update rewrite target name | Madeleine Dunn
    - Rename "fully_connected" to "fully-connected"
    - This will resolve issues with upstreaming rewrite library changes
    Signed-off-by: Madeleine Dunn <madeleine.dunn@arm.com>
    Change-Id: I2f24ae4917a556fd0bd44f0db6ee4e0f7a68cd24
2023-10-11 | Re-factoring of rewrite management & added metrics | Gergely Nagy
    - List available rewrites
    - Refactor/rename 'Rewrite' class to 'RewritingOptimizer'
    - Introduce a registry for rewrite functions
    - Refactor 'Rewriter' to use the registry to look up rewrite functions
    - Remove mentions of hardcoded "fully_connected" from CLI help and error messages, using the registry instead
    - Add unit tests
    - Enable rewrites for all targets: Extract optimization (including rewrite specific code) from the Ethos-U-specific data collector into OptimizingDataCollector. This is reused in other targets' collectors, such as TOSA and Cortex-A.
    - Add more logging for rewrite
    - add display of MAE and NRMSE values for the trained result
    - add total model MAE and NRMSE metric
    Resolves: MLIA-891, MLIA-899, MLIA-906
    Change-Id: Ie798749e1ed60cab14fdb6d9c2271c833960e93f
    Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
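The registry idea from this commit (look rewrites up by name, and build help and error messages from the registered names rather than a hardcoded "fully_connected") can be sketched roughly as below. The class `RewriteRegistry` and its methods are hypothetical and do not reproduce the actual registry in the module.

```python
from typing import Callable, Dict, List


class RewriteRegistry:
    """Minimal name -> rewrite function registry (illustrative only)."""

    def __init__(self) -> None:
        self._items: Dict[str, Callable] = {}

    def register(self, name: str, function: Callable) -> None:
        """Add a rewrite function under a unique name."""
        if name in self._items:
            raise KeyError(f"Rewrite '{name}' is already registered")
        self._items[name] = function

    def names(self) -> List[str]:
        """List the available rewrites, e.g. for CLI help text."""
        return sorted(self._items)

    def lookup(self, name: str) -> Callable:
        """Return the rewrite function, naming the valid options on failure."""
        try:
            return self._items[name]
        except KeyError:
            available = ", ".join(self.names())
            raise KeyError(f"Unknown rewrite '{name}'. Available: {available}") from None
```

Because error messages are derived from `names()`, adding a new rewrite needs no change to CLI strings in this sketch.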
2023-10-11 | Bug-fixes and re-factoring for the rewrite module | Benjamin Klimczak
    - Fix input shape of rewrite replacement: During and after training of the replacement model for a rewrite the Keras model is converted and saved in TensorFlow Lite format. If the input shape does not match the teacher model exactly, e.g. if the batch size is undefined, the TFLiteConverter adds extra operators during conversion.
    - Fix rewritten model output
    - Save the model output with the rewritten operator in the output dir
    - Log MAE and NRMSE of the rewrite
    - Remove 'verbose' flag from rewrite module and rely on the logging mechanism to control verbose output.
    - Re-factor utility classes for rewrites
    - Merge the two TFLiteModel classes
    - Move functionality to load/save TensorFlow Lite flatbuffers to nn/tensorflow/tflite_graph
    - Fix issue with unknown shape in datasets: After upgrading to TensorFlow 2.12 the unknown shape of the TFRecordDataset is causing problems when training the replacement models for rewrites. By explicitly setting the right shape of the tensors we can work around the issue.
    - Adapt default parameters for rewrites. The training steps especially had to be increased significantly to be effective.
    Resolves: MLIA-895, MLIA-907, MLIA-946, MLIA-979
    Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
    Change-Id: I887ad165aed0f2c6e5a0041f64cec5e6c5ab5c5c
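As a reminder of what the two logged metrics measure, here is a small pure-Python sketch of MAE (mean absolute error) and NRMSE (normalised root mean squared error) between a reference output and a rewritten output. The function name `mae_and_nrmse` is hypothetical, and normalising the RMSE by the range of the reference values is an assumption; the module may normalise differently.

```python
import math
from typing import Sequence, Tuple


def mae_and_nrmse(reference: Sequence[float], candidate: Sequence[float]) -> Tuple[float, float]:
    """Return (MAE, NRMSE) of candidate against reference (illustrative)."""
    if not reference or len(reference) != len(candidate):
        raise ValueError("inputs must be non-empty and of equal length")
    diffs = [c - r for r, c in zip(reference, candidate)]
    mae = sum(abs(d) for d in diffs) / len(diffs)
    rmse = math.sqrt(sum(d * d for d in diffs) / len(diffs))
    # Assumed normaliser: the value range of the reference output.
    value_range = max(reference) - min(reference)
    if value_range > 0:
        nrmse = rmse / value_range
    else:
        # Constant reference: avoid dividing by zero.
        nrmse = 0.0 if rmse == 0 else math.inf
    return mae, nrmse
```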
2023-10-11 | Implement first rewrite (proof of concept) | Ruomei Yan
    * Define replacement function for the fully_connected layer
    * Define RewriteConfiguration and Rewriter to integrate rewrite module into mlia optimize command
    * Fix a bug in the ethos_u/data_collection.py file
    * Fix a bug in join.py
    * Remove diff_stats and use diff instead, with related changes to ensure the e2e tests pass
    * Add unit tests for all changes
    * Fix bug in diff_stats function
    * The bug was caused by dividing by a numpy array of all zeros. The previous handling only dealt with partially zero arrays and did not consider the all-zeros case
    * Unit tests added.
    * Fix the bug in rewrite/core/graph_edit/join.py
    * Remove the possibility of passing None to the append_relabel function because it is immutable
    * The bug happened when an empty dictionary was passed to the append_relabel function: the function overwrote the reference to operator_map, so the caller's dictionary was not updated after the function call
    Resolves: MLIA-749, MLIA-864, MLIA-866
    Change-Id: I1ab426996232f182345e6e98033d5dcb32aea08c
    Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
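The operator_map bug described above is an instance of a common Python pitfall: an empty dictionary is falsy, so a "default if missing" idiom silently replaces the caller's dictionary with a new one. The sketch below reconstructs that class of bug under assumed names (`relabel_buggy`, `relabel_fixed`); it is not the actual append_relabel code.

```python
from typing import Dict, Optional


def relabel_buggy(relabel: Dict[int, int], operator_map: Optional[Dict[int, int]] = None) -> None:
    """Rebinds operator_map when it is empty, so the caller never sees the update."""
    operator_map = operator_map or {}  # an empty dict is falsy and gets replaced here
    operator_map.update(relabel)


def relabel_fixed(relabel: Dict[int, int], operator_map: Dict[int, int]) -> None:
    """Requires a real dictionary and mutates it in place."""
    operator_map.update(relabel)


caller_map: Dict[int, int] = {}
relabel_buggy({1: 2}, caller_map)
# caller_map is still empty at this point: the update went to a new dict.
relabel_fixed({1: 2}, caller_map)
# caller_map now contains the new entry.
```

Making the dictionary a required argument, as in the fixed variant, removes the need for any fallback and with it the rebinding.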
2023-10-11 | Adapt rewrite module to MLIA coding standards | Annie Tallund
    - Fix imports
    - Update variable names
    - Refactor helper functions
    - Add licence headers
    - Add docstrings
    - Use f-strings rather than % notation
    - Create type annotations in rewrite module
    - Migrate from tqdm to rich progress bar
    - Use logging module in rewrite module: All print statements are replaced with logging module
    Resolves: MLIA-831, MLIA-842, MLIA-844, MLIA-846
    Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
    Change-Id: Idee37538d72b9f01128a894281a8d10155f7c17c
2023-10-11 | Add a CLI component to enable rewrites | Ruomei Yan
    * Add flags for rewrite (--rewrite, --rewrite-start, --rewrite-end, --rewrite-target)
    * Refactor CLI interfaces to accept tflite models with optimize for rewrite, keras models with optimize for clustering and pruning
    * Refactor and move common.py and select.py out of the folder nn/tensorflow/optimizations
    * Add file nn/rewrite/core/rewrite.py as placeholder
    * Update/add unit tests
    * Refactor OptimizeModel in ethos_u/data_collection.py for accepting tflite model case
    * Extend the logic so that if "--rewrite" is specified, we don't add pruning to also accept TFLite models.
    * Update README.md
    Resolves: MLIA-750, MLIA-854, MLIA-865
    Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
    Change-Id: I67d85f71fa253d2bad4efe304ad8225970b9622c
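To show how the four flags named in this commit might fit together, here is a small argparse sketch. Only the flag names come from the commit message. Everything else is an assumption made for illustration: that --rewrite is a boolean switch, that the other three take string values, that all three are required once --rewrite is set, and the helper names `build_parser` and `parse`. It is not the real mlia command-line definition.

```python
import argparse
from typing import List


def build_parser() -> argparse.ArgumentParser:
    """Build a stand-in parser with the rewrite flags (illustrative only)."""
    parser = argparse.ArgumentParser(prog="optimize-sketch")
    parser.add_argument("model", help="Path to the model file")
    parser.add_argument("--rewrite", action="store_true", help="Enable rewrite")
    parser.add_argument("--rewrite-target", help="Name of the rewrite to apply")
    parser.add_argument("--rewrite-start", help="Tensor where the rewritten section starts")
    parser.add_argument("--rewrite-end", help="Tensor where the rewritten section ends")
    return parser


def parse(argv: List[str]) -> argparse.Namespace:
    """Parse arguments and check the assumed dependency between the flags."""
    parser = build_parser()
    args = parser.parse_args(argv)
    if args.rewrite and not (args.rewrite_target and args.rewrite_start and args.rewrite_end):
        # Assumed rule: a rewrite needs a target and both boundary tensors.
        parser.error("--rewrite requires --rewrite-target, --rewrite-start and --rewrite-end")
    return args
```

In this sketch a call such as `parse(["model.tflite", "--rewrite", "--rewrite-target", "fully-connected", "--rewrite-start", "input_tensor", "--rewrite-end", "output_tensor"])` returns a namespace with `rewrite` set to `True`; the tensor names are placeholders.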