diff options
author | Raul Farkas <raul.farkas@arm.com> | 2022-07-20 15:57:37 +0100 |
---|---|---|
committer | Raul Farkas <raul.farkas@arm.com> | 2022-07-22 17:08:18 +0100 |
commit | 7899b908c1fe6d86b92a80f3827ddd0ac05b674b (patch) | |
tree | 7c0ca1250a8f1e28808660e9482ec55230601405 /src/mlia/nn/tensorflow/tflite_metrics.py | |
parent | 625c280433fef3c9d1b64f58eab930ba0f89cd82 (diff) | |
download | mlia-7899b908c1fe6d86b92a80f3827ddd0ac05b674b.tar.gz |
MLIA-569 Update TensorFlow to version 2.8
- Update TensorFlow to version 2.8 (now supported by Vela 3.4)
- Adapt existing codebase to preserve intermediary tensors in the interpreter in order to avoid errors when trying to print all of them in the future.
- Ignore types for numpy methods that do not have typing annotations in their definitions. This is needed because otherwise mypy would consider the calling function to also be untyped.
Change-Id: I943ac196fd4e378f5238949b15c23a2d628c8b5e
Diffstat (limited to 'src/mlia/nn/tensorflow/tflite_metrics.py')
-rw-r--r-- | src/mlia/nn/tensorflow/tflite_metrics.py | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/src/mlia/nn/tensorflow/tflite_metrics.py b/src/mlia/nn/tensorflow/tflite_metrics.py index 9befb2f..0fe36e0 100644 --- a/src/mlia/nn/tensorflow/tflite_metrics.py +++ b/src/mlia/nn/tensorflow/tflite_metrics.py @@ -33,7 +33,10 @@ DEFAULT_IGNORE_LIST = [ def calculate_num_unique_weights(weights: np.ndarray) -> int: """Calculate the number of unique weights in the given weights.""" - num_unique_weights = len(np.unique(weights)) + # Types need to be ignored for this function call because + # np.unique does not have type annotation while the + # current context does. + num_unique_weights = len(np.unique(weights)) # type: ignore return num_unique_weights @@ -114,7 +117,9 @@ class TFLiteMetrics: ignore_list = DEFAULT_IGNORE_LIST self.ignore_list = [ignore.casefold() for ignore in ignore_list] # Initialize the TFLite interpreter with the model file - self.interpreter = tf.lite.Interpreter(model_path=tflite_file) + self.interpreter = tf.lite.Interpreter( + model_path=tflite_file, experimental_preserve_all_tensors=True + ) self.interpreter.allocate_tensors() self.details: dict = {} @@ -242,7 +247,10 @@ class TFLiteMetrics: if verbose: # Print cluster centroids print("{} cluster centroids:".format(name)) - pprint(np.unique(weights)) + # Types need to be ignored for this function call because + # np.unique does not have type annotation while the + # current context does. + pprint(np.unique(weights)) # type: ignore # Add summary/overall values empty_row = ["" for _ in range(len(header))] summary_row = empty_row |