aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/optimizations/pruning.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/tensorflow/optimizations/pruning.py')
-rw-r--r--src/mlia/nn/tensorflow/optimizations/pruning.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py
index f1e2976..15c043d 100644
--- a/src/mlia/nn/tensorflow/optimizations/pruning.py
+++ b/src/mlia/nn/tensorflow/optimizations/pruning.py
@@ -147,11 +147,14 @@ class Pruner(Optimizer):
nonzero_weights = np.count_nonzero(tf.keras.backend.get_value(weight))
all_weights = tf.keras.backend.get_value(weight).size
+ # Types need to be ignored for this function call because
+ # np.testing.assert_approx_equal does not have type annotation while the
+ # current context does.
np.testing.assert_approx_equal(
self.optimizer_configuration.optimization_target,
1 - nonzero_weights / all_weights,
significant=2,
- )
+ ) # type: ignore
def _strip_pruning(self) -> None:
self.model = tfmot.sparsity.keras.strip_pruning(self.model)