aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/optimizations
diff options
context:
space:
mode:
authorRaul Farkas <raul.farkas@arm.com>2022-07-20 15:57:37 +0100
committerRaul Farkas <raul.farkas@arm.com>2022-07-22 17:08:18 +0100
commit7899b908c1fe6d86b92a80f3827ddd0ac05b674b (patch)
tree7c0ca1250a8f1e28808660e9482ec55230601405 /src/mlia/nn/tensorflow/optimizations
parent625c280433fef3c9d1b64f58eab930ba0f89cd82 (diff)
downloadmlia-7899b908c1fe6d86b92a80f3827ddd0ac05b674b.tar.gz
MLIA-569 Update TensorFlow to version 2.8
- Update TensorFlow to version 2.8 (now supported by Vela 3.4) - Adapt existing codebase to preserve intermediary tensors in the interpreter in order to avoid errors when trying to print all of them in the future. - Ignore types for numpy methods that do not have typing annotations in their definitions. This is needed because otherwise mypy would consider the calling function to also be untyped. Change-Id: I943ac196fd4e378f5238949b15c23a2d628c8b5e
Diffstat (limited to 'src/mlia/nn/tensorflow/optimizations')
-rw-r--r--src/mlia/nn/tensorflow/optimizations/pruning.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py
index f1e2976..15c043d 100644
--- a/src/mlia/nn/tensorflow/optimizations/pruning.py
+++ b/src/mlia/nn/tensorflow/optimizations/pruning.py
@@ -147,11 +147,14 @@ class Pruner(Optimizer):
nonzero_weights = np.count_nonzero(tf.keras.backend.get_value(weight))
all_weights = tf.keras.backend.get_value(weight).size
+ # Types need to be ignored for this function call because
+ # np.testing.assert_approx_equal does not have type annotation while the
+ # current context does.
np.testing.assert_approx_equal(
self.optimizer_configuration.optimization_target,
1 - nonzero_weights / all_weights,
significant=2,
- )
+ ) # type: ignore
def _strip_pruning(self) -> None:
self.model = tfmot.sparsity.keras.strip_pruning(self.model)