aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRaul Farkas <raul.farkas@arm.com>2022-07-20 15:57:37 +0100
committerRaul Farkas <raul.farkas@arm.com>2022-07-22 17:08:18 +0100
commit7899b908c1fe6d86b92a80f3827ddd0ac05b674b (patch)
tree7c0ca1250a8f1e28808660e9482ec55230601405
parent625c280433fef3c9d1b64f58eab930ba0f89cd82 (diff)
downloadmlia-7899b908c1fe6d86b92a80f3827ddd0ac05b674b.tar.gz
MLIA-569 Update TensorFlow to version 2.8
- Update TensorFlow to version 2.8 (now supported by Vela 3.4) - Adapt existing codebase to preserve intermediary tensors in the interpreter in order to avoid errors when trying to print all of them in the future. - Ignore types for numpy methods that do not have typing annotations in their definitions. This is needed because otherwise mypy would consider the calling function to also be untyped. Change-Id: I943ac196fd4e378f5238949b15c23a2d628c8b5e
-rw-r--r--setup.cfg2
-rw-r--r--src/mlia/nn/tensorflow/optimizations/pruning.py5
-rw-r--r--src/mlia/nn/tensorflow/tflite_metrics.py14
3 files changed, 16 insertions, 5 deletions
diff --git a/setup.cfg b/setup.cfg
index dbed6f7..49b5ce5 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -29,7 +29,7 @@ package_dir =
= src
packages = find:
install_requires =
- tensorflow~=2.7.1
+ tensorflow~=2.8.2
tensorflow-model-optimization~=0.7.2
ethos-u-vela~=3.4.0
requests
diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py
index f1e2976..15c043d 100644
--- a/src/mlia/nn/tensorflow/optimizations/pruning.py
+++ b/src/mlia/nn/tensorflow/optimizations/pruning.py
@@ -147,11 +147,14 @@ class Pruner(Optimizer):
nonzero_weights = np.count_nonzero(tf.keras.backend.get_value(weight))
all_weights = tf.keras.backend.get_value(weight).size
+ # Types need to be ignored for this function call because
+ # np.testing.assert_approx_equal does not have type annotation while the
+ # current context does.
np.testing.assert_approx_equal(
self.optimizer_configuration.optimization_target,
1 - nonzero_weights / all_weights,
significant=2,
- )
+ ) # type: ignore
def _strip_pruning(self) -> None:
self.model = tfmot.sparsity.keras.strip_pruning(self.model)
diff --git a/src/mlia/nn/tensorflow/tflite_metrics.py b/src/mlia/nn/tensorflow/tflite_metrics.py
index 9befb2f..0fe36e0 100644
--- a/src/mlia/nn/tensorflow/tflite_metrics.py
+++ b/src/mlia/nn/tensorflow/tflite_metrics.py
@@ -33,7 +33,10 @@ DEFAULT_IGNORE_LIST = [
def calculate_num_unique_weights(weights: np.ndarray) -> int:
"""Calculate the number of unique weights in the given weights."""
- num_unique_weights = len(np.unique(weights))
+ # Types need to be ignored for this function call because
+ # np.unique does not have type annotation while the
+ # current context does.
+ num_unique_weights = len(np.unique(weights)) # type: ignore
return num_unique_weights
@@ -114,7 +117,9 @@ class TFLiteMetrics:
ignore_list = DEFAULT_IGNORE_LIST
self.ignore_list = [ignore.casefold() for ignore in ignore_list]
# Initialize the TFLite interpreter with the model file
- self.interpreter = tf.lite.Interpreter(model_path=tflite_file)
+ self.interpreter = tf.lite.Interpreter(
+ model_path=tflite_file, experimental_preserve_all_tensors=True
+ )
self.interpreter.allocate_tensors()
self.details: dict = {}
@@ -242,7 +247,10 @@ class TFLiteMetrics:
if verbose:
# Print cluster centroids
print("{} cluster centroids:".format(name))
- pprint(np.unique(weights))
+ # Types need to be ignored for this function call because
+ # np.unique does not have type annotation while the
+ # current context does.
+ pprint(np.unique(weights)) # type: ignore
# Add summary/overall values
empty_row = ["" for _ in range(len(header))]
summary_row = empty_row