diff options
Diffstat (limited to 'src/mlia/nn/tensorflow/optimizations/pruning.py')
-rw-r--r-- | src/mlia/nn/tensorflow/optimizations/pruning.py | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py index f1e2976..15c043d 100644 --- a/src/mlia/nn/tensorflow/optimizations/pruning.py +++ b/src/mlia/nn/tensorflow/optimizations/pruning.py @@ -147,11 +147,14 @@ class Pruner(Optimizer): nonzero_weights = np.count_nonzero(tf.keras.backend.get_value(weight)) all_weights = tf.keras.backend.get_value(weight).size + # Types need to be ignored for this function call because + # np.testing.assert_approx_equal does not have type annotation while the + # current context does. np.testing.assert_approx_equal( self.optimizer_configuration.optimization_target, 1 - nonzero_weights / all_weights, significant=2, - ) + ) # type: ignore def _strip_pruning(self) -> None: self.model = tfmot.sparsity.keras.strip_pruning(self.model) |