From 7899b908c1fe6d86b92a80f3827ddd0ac05b674b Mon Sep 17 00:00:00 2001 From: Raul Farkas Date: Wed, 20 Jul 2022 15:57:37 +0100 Subject: MLIA-569 Update TensorFlow to version 2.8 - Update TensorFlow to version 2.8 (now supported by Vela 3.4) - Adapt existing codebase to preserve intermediary tensors in the interpreter in order to avoid errors when trying to print all of them in the future. - Ignore types for numpy methods that do not have typing annotations in their definitions. This is needed because otherwise mypy would consider the calling function to also be untyped. Change-Id: I943ac196fd4e378f5238949b15c23a2d628c8b5e --- setup.cfg | 2 +- src/mlia/nn/tensorflow/optimizations/pruning.py | 5 ++++- src/mlia/nn/tensorflow/tflite_metrics.py | 14 +++++++++++--- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/setup.cfg b/setup.cfg index dbed6f7..49b5ce5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -29,7 +29,7 @@ package_dir = = src packages = find: install_requires = - tensorflow~=2.7.1 + tensorflow~=2.8.2 tensorflow-model-optimization~=0.7.2 ethos-u-vela~=3.4.0 requests diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py index f1e2976..15c043d 100644 --- a/src/mlia/nn/tensorflow/optimizations/pruning.py +++ b/src/mlia/nn/tensorflow/optimizations/pruning.py @@ -147,11 +147,14 @@ class Pruner(Optimizer): nonzero_weights = np.count_nonzero(tf.keras.backend.get_value(weight)) all_weights = tf.keras.backend.get_value(weight).size + # Types need to be ignored for this function call because + # np.testing.assert_approx_equal does not have type annotation while the + # current context does. np.testing.assert_approx_equal( self.optimizer_configuration.optimization_target, 1 - nonzero_weights / all_weights, significant=2, - ) + ) # type: ignore def _strip_pruning(self) -> None: self.model = tfmot.sparsity.keras.strip_pruning(self.model) diff --git a/src/mlia/nn/tensorflow/tflite_metrics.py b/src/mlia/nn/tensorflow/tflite_metrics.py index 9befb2f..0fe36e0 100644 --- a/src/mlia/nn/tensorflow/tflite_metrics.py +++ b/src/mlia/nn/tensorflow/tflite_metrics.py @@ -33,7 +33,10 @@ DEFAULT_IGNORE_LIST = [ def calculate_num_unique_weights(weights: np.ndarray) -> int: """Calculate the number of unique weights in the given weights.""" - num_unique_weights = len(np.unique(weights)) + # Types need to be ignored for this function call because + # np.unique does not have type annotation while the + # current context does. + num_unique_weights = len(np.unique(weights)) # type: ignore return num_unique_weights @@ -114,7 +117,9 @@ class TFLiteMetrics: ignore_list = DEFAULT_IGNORE_LIST self.ignore_list = [ignore.casefold() for ignore in ignore_list] # Initialize the TFLite interpreter with the model file - self.interpreter = tf.lite.Interpreter(model_path=tflite_file) + self.interpreter = tf.lite.Interpreter( + model_path=tflite_file, experimental_preserve_all_tensors=True + ) self.interpreter.allocate_tensors() self.details: dict = {} @@ -242,7 +247,10 @@ class TFLiteMetrics: if verbose: # Print cluster centroids print("{} cluster centroids:".format(name)) - pprint(np.unique(weights)) + # Types need to be ignored for this function call because + # np.unique does not have type annotation while the + # current context does. + pprint(np.unique(weights)) # type: ignore # Add summary/overall values empty_row = ["" for _ in range(len(header))] summary_row = empty_row -- cgit v1.2.1