aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/optimizations/pruning.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/tensorflow/optimizations/pruning.py')
-rw-r--r--src/mlia/nn/tensorflow/optimizations/pruning.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py
index 15c043d..0a3fda5 100644
--- a/src/mlia/nn/tensorflow/optimizations/pruning.py
+++ b/src/mlia/nn/tensorflow/optimizations/pruning.py
@@ -7,6 +7,7 @@ In order to do this, we need to have a base model and corresponding training dat
We also have to specify a subset of layers we want to prune. For more details,
please refer to the documentation for TensorFlow Model Optimization Toolkit.
"""
+import typing
from dataclasses import dataclass
from typing import List
from typing import Optional
@@ -138,6 +139,7 @@ class Pruner(Optimizer):
verbose=0,
)
+ @typing.no_type_check
def _assert_sparsity_reached(self) -> None:
for layer in self.model.layers:
if not isinstance(layer, pruning_wrapper.PruneLowMagnitude):
@@ -154,7 +156,7 @@ class Pruner(Optimizer):
self.optimizer_configuration.optimization_target,
1 - nonzero_weights / all_weights,
significant=2,
- ) # type: ignore
+ )
def _strip_pruning(self) -> None:
self.model = tfmot.sparsity.keras.strip_pruning(self.model)