diff options
Diffstat (limited to 'src/mlia/nn/tensorflow/optimizations/pruning.py')
-rw-r--r-- | src/mlia/nn/tensorflow/optimizations/pruning.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py index 15c043d..0a3fda5 100644 --- a/src/mlia/nn/tensorflow/optimizations/pruning.py +++ b/src/mlia/nn/tensorflow/optimizations/pruning.py @@ -7,6 +7,7 @@ In order to do this, we need to have a base model and corresponding training dat We also have to specify a subset of layers we want to prune. For more details, please refer to the documentation for TensorFlow Model Optimization Toolkit. """ +import typing from dataclasses import dataclass from typing import List from typing import Optional @@ -138,6 +139,7 @@ class Pruner(Optimizer): verbose=0, ) + @typing.no_type_check def _assert_sparsity_reached(self) -> None: for layer in self.model.layers: if not isinstance(layer, pruning_wrapper.PruneLowMagnitude): @@ -154,7 +156,7 @@ class Pruner(Optimizer): self.optimizer_configuration.optimization_target, 1 - nonzero_weights / all_weights, significant=2, - ) # type: ignore + ) def _strip_pruning(self) -> None: self.model = tfmot.sparsity.keras.strip_pruning(self.model) |