diff options
author | Ruomei Yan <ruomei.yan@arm.com> | 2022-08-16 09:06:55 +0100 |
---|---|---|
committer | Ruomei Yan <ruomei.yan@arm.com> | 2022-08-31 12:06:43 +0100 |
commit | 5f58ee6582030c6317e86e42fbe193dc78cc5619 (patch) | |
tree | adffd7b20ec22e8ee5afd73eb8079990a9ee0e96 /src/mlia/nn/tensorflow/optimizations | |
parent | 088303393df8bf49e4ea2958e88ff05aa50dd1ec (diff) | |
download | mlia-5f58ee6582030c6317e86e42fbe193dc78cc5619.tar.gz |
MLIA-599 Enable testing for aarch64: unit tests
- mypy issue: to make the comment #type: ignore platform specific,
flags like platform.machine() cannot be recognized by mypy, so we
cannot isolate the specific lines of code that fail mypy tests
- numpy issue: for numpy version < 1.20, the function np.unique
has not been type annotated, which caused mypy throwing the error
when we run our unit tests in aarch64
- because of the above two reasons, we use function decorator to
turn off type checking for entire functions to remove all
annotations so that the mypy error for certain lines can be silented
Change-Id: Id91e65ef7677b78b4c9c85b8412229e3672e3a66
Diffstat (limited to 'src/mlia/nn/tensorflow/optimizations')
-rw-r--r-- | src/mlia/nn/tensorflow/optimizations/pruning.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py index 15c043d..0a3fda5 100644 --- a/src/mlia/nn/tensorflow/optimizations/pruning.py +++ b/src/mlia/nn/tensorflow/optimizations/pruning.py @@ -7,6 +7,7 @@ In order to do this, we need to have a base model and corresponding training dat We also have to specify a subset of layers we want to prune. For more details, please refer to the documentation for TensorFlow Model Optimization Toolkit. """ +import typing from dataclasses import dataclass from typing import List from typing import Optional @@ -138,6 +139,7 @@ class Pruner(Optimizer): verbose=0, ) + @typing.no_type_check def _assert_sparsity_reached(self) -> None: for layer in self.model.layers: if not isinstance(layer, pruning_wrapper.PruneLowMagnitude): @@ -154,7 +156,7 @@ class Pruner(Optimizer): self.optimizer_configuration.optimization_target, 1 - nonzero_weights / all_weights, significant=2, - ) # type: ignore + ) def _strip_pruning(self) -> None: self.model = tfmot.sparsity.keras.strip_pruning(self.model) |