path: root/src/mlia/nn/tensorflow/optimizations
Age | Commit message | Author
2023-12-04 | Update to Vela 3.10, TensorFlow 2.14, Python 3.9 | Benjamin Klimczak
Update to Vela 3.10, which requires TensorFlow 2.14, which in turn requires Python 3.9 (dropping support for Python 3.8).
Resolves: MLIA-997
Change-Id: Id60bd08f7156a8efa204ef71ba81590edf0e3b28
Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
2023-10-11 | Enable rewrites for quantized input models | Benjamin Klimczak
If the input model for rewriting is quantized:
- Record de-quantized TFRecords
  - Enable writing de-quantized calibration data for the training
  - Re-generate augmented training data, if needed
- Use quantization-aware training (QAT) to train the replacement models
- Check whether the replacement model is quantized: if the source model is quantized, make sure the rewrite's output model is quantized too. Only int8 is currently supported, so an error is raised if any other data type is present in the output.
Resolves: MLIA-907, MLIA-908, MLIA-927
Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Change-Id: Icb4070a9e6f1fdb5ce36120d73823986e89ac955
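The int8 check in the last item can be sketched as follows. This is a minimal illustration, not MLIA's code: `check_int8_only` is a hypothetical name, and its input is assumed to be a list of tensor-detail dicts like those returned by `tf.lite.Interpreter.get_output_details()` (each with `"name"` and `"dtype"` keys).

```python
import numpy as np


def check_int8_only(tensor_details: list[dict]) -> None:
    """Raise ValueError if any given tensor is not int8 (hypothetical helper).

    Assumes each entry is a dict with "name" and "dtype" keys, in the shape of
    tf.lite.Interpreter.get_output_details(); TensorFlow itself is not needed.
    """
    offending = [
        (detail["name"], np.dtype(detail["dtype"]).name)
        for detail in tensor_details
        if np.dtype(detail["dtype"]) != np.int8
    ]
    if offending:
        # Only int8 is supported, so report every tensor with another dtype.
        raise ValueError(f"Rewrite output is not int8-quantized: {offending}")
```

Fully int8-quantized TFLite models usually keep int32 bias tensors, so a check over every tensor in a model would have to allow those; the sketch therefore only inspects the tensor details it is given, such as the model's inputs and outputs.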
2023-10-11 | Add a CLI component to enable rewrites | Ruomei Yan
* Add flags for rewrite (--rewrite, --rewrite-start, --rewrite-end, --rewrite-target)
* Refactor the CLI interfaces to accept TFLite models when optimizing with rewrite, and Keras models when optimizing with clustering and pruning
* Refactor common.py and select.py and move them out of the folder nn/tensorflow/optimizations
* Add file nn/rewrite/core/rewrite.py as a placeholder
* Update/add unit tests
* Refactor OptimizeModel in ethos_u/data_collection.py to accept TFLite models
* Extend the logic so that, if "--rewrite" is specified, pruning is not added, which allows TFLite models to be accepted as well
* Update README.md
Resolves: MLIA-750, MLIA-854, MLIA-865
Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Change-Id: I67d85f71fa253d2bad4efe304ad8225970b9622c
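A rough sketch of how the four rewrite flags and the "no pruning with --rewrite" rule might fit together, using the standard library's argparse. Only the flag names come from the commit; the string types of the three value flags, the `.tflite` extension test, the pruning fallback, and the helpers `build_parser` and `select_optimizations` are assumptions for illustration, not MLIA's actual CLI.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Parser exposing the rewrite flags named in the commit (types assumed)."""
    parser = argparse.ArgumentParser(prog="rewrite-sketch")
    parser.add_argument("model", help="Path to the input model")
    parser.add_argument("--rewrite", action="store_true", help="Enable rewrite")
    # Assumption: the three options below take free-form strings.
    parser.add_argument("--rewrite-target", help="Kind of replacement to train")
    parser.add_argument("--rewrite-start", help="Where the rewritten subgraph starts")
    parser.add_argument("--rewrite-end", help="Where the rewritten subgraph ends")
    return parser


def select_optimizations(args: argparse.Namespace) -> list[str]:
    """Hypothetical helper: choose optimizations from the flags and model type."""
    is_tflite = args.model.endswith(".tflite")
    if args.rewrite:
        if not is_tflite:
            raise ValueError("--rewrite expects a TensorFlow Lite (.tflite) model")
        # With --rewrite, pruning is not added, so a TFLite model is accepted.
        return ["rewrite"]
    if is_tflite:
        raise ValueError("pruning and clustering expect a Keras model")
    # Assumed default when --rewrite is absent: add pruning.
    return ["pruning"]
```

For example, parsing `["model.tflite", "--rewrite"]` and passing the result to `select_optimizations` yields `["rewrite"]`, while `["model.h5"]` falls back to `["pruning"]`.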
2023-09-26 | MLIA-469 Support batch size > 1 for optimizations | Annie Tallund
- Add a PruningPolicy to skip layers that are not supported by the Keras pruning API
- Make dataset generation more generic to support use cases beyond classification
Signed-off-by: Annie Tallund <annie.tallund@arm.com>
Change-Id: I198dae2b53860f449f2fdbc71575babceed1ffcf
2023-09-05 | MLIA-961 Update tox dependencies | Benjamin Klimczak
- Update version dependencies in tox.ini
- Fix linter issues
Change-Id: I04c3a841ee2646a865dab037701d66c28792f2a4
Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
2022-09-09 | MLIA-386 Simplify typing in the source code | Dmitrii Agibov
- Enable deferred evaluation of annotations
- Use builtin types for type hints whenever possible
- Use the | syntax for union types
- Rename mlia.core._typing to mlia.core.typing
Change-Id: I3f6ffc02fa069c589bdd9e8bddbccd504285427a
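The first three items combine as in the sketch below (the function `first_index` is hypothetical). With `from __future__ import annotations` (PEP 563), annotations are stored as strings instead of being evaluated when the function is defined, so builtin generics such as `list[str]` and the `X | Y` union syntax can be used in hints even on Python versions that could not evaluate them at runtime.

```python
from __future__ import annotations  # deferred evaluation of annotations


def first_index(items: list[str], target: str) -> int | None:
    """Return the index of target in items, or None if it is absent.

    `list[str]` uses a builtin type instead of typing.List, and `int | None`
    replaces typing.Optional[int]; both are kept as unevaluated strings.
    """
    for index, item in enumerate(items):
        if item == target:
            return index
    return None
```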
2022-08-31 | MLIA-599 Enable testing for aarch64: unit tests | Ruomei Yan
- mypy issue: making a "# type: ignore" comment platform specific would require mypy to understand checks such as platform.machine(), which it does not, so the specific lines that fail mypy cannot be isolated
- numpy issue: in numpy versions < 1.20, the function np.unique has no type annotations, which caused mypy to raise an error when running our unit tests on aarch64
- For these two reasons, a function decorator is used to turn off type checking (effectively ignoring all annotations) for entire functions, so that the mypy errors on those lines are silenced
Change-Id: Id91e65ef7677b78b4c9c85b8412229e3672e3a66
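The commit does not name the decorator it uses. A likely candidate is the standard library's `typing.no_type_check`, which mypy honors by not type checking the decorated function; the sketch below assumes that decorator, and `count_unique` is a hypothetical helper wrapping `np.unique`.

```python
import typing

import numpy as np


@typing.no_type_check  # mypy skips checking this function; its hints are ignored
def count_unique(values: np.ndarray) -> dict:
    """Map each distinct value to the number of times it occurs."""
    uniques, counts = np.unique(values, return_counts=True)
    # tolist() converts numpy scalars to plain Python ints for the dict.
    return dict(zip(uniques.tolist(), counts.tolist()))
```

At runtime the decorator only marks the function (`__no_type_check__ = True`), which also makes `typing.get_type_hints` return an empty dict for it; the behavior of `count_unique` itself is unchanged.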
2022-07-22 | MLIA-569 Update TensorFlow to version 2.8 | Raul Farkas
- Update TensorFlow to version 2.8 (now supported by Vela 3.4)
- Adapt the existing codebase to preserve intermediate tensors in the interpreter, to avoid errors when trying to print all of them in the future.
- Ignore types for numpy methods that lack type annotations in their definitions. This is needed because otherwise mypy would also consider the calling function to be untyped.
Change-Id: I943ac196fd4e378f5238949b15c23a2d628c8b5e
2022-07-22 | MLIA-507 Upgrade Vela version | Raul Farkas
Upgrade the Vela version from 3.3.0 to 3.4.0.
- Adapt code to the new typing notation by replacing `numpy.array` with `numpy.ndarray` where necessary.
Change-Id: I035e9564d448652aa09a52d79c71ef09663ea776
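`numpy.array` is a factory function, while `numpy.ndarray` is the array type, so only the latter is valid in a type hint (mypy rejects a function used as a type). A minimal sketch with a hypothetical `normalize` helper:

```python
import numpy as np


def normalize(values: np.ndarray) -> np.ndarray:
    """Min-max scale values into [0, 1].

    The hints name np.ndarray (the type); np.array only creates arrays.
    """
    values = np.asarray(values, dtype=np.float64)
    span = values.max() - values.min()
    if span == 0:
        # Constant input: avoid division by zero and return all zeros.
        return np.zeros_like(values)
    return (values - values.min()) / span
```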
2022-05-30 | Add MLIA codebase (tag: 0.3.0-rc.1) | Diego Russo
Add MLIA codebase including sources and tests.
Change-Id: Id41707559bd721edd114793618d12ccd188d8dbd