aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/tflite_metrics.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/tensorflow/tflite_metrics.py')
-rw-r--r--src/mlia/nn/tensorflow/tflite_metrics.py14
1 files changed, 11 insertions, 3 deletions
diff --git a/src/mlia/nn/tensorflow/tflite_metrics.py b/src/mlia/nn/tensorflow/tflite_metrics.py
index 9befb2f..0fe36e0 100644
--- a/src/mlia/nn/tensorflow/tflite_metrics.py
+++ b/src/mlia/nn/tensorflow/tflite_metrics.py
@@ -33,7 +33,10 @@ DEFAULT_IGNORE_LIST = [
def calculate_num_unique_weights(weights: np.ndarray) -> int:
"""Calculate the number of unique weights in the given weights."""
- num_unique_weights = len(np.unique(weights))
+ # Types need to be ignored for this function call because
+ # np.unique does not have type annotation while the
+ # current context does.
+ num_unique_weights = len(np.unique(weights)) # type: ignore
return num_unique_weights
@@ -114,7 +117,9 @@ class TFLiteMetrics:
ignore_list = DEFAULT_IGNORE_LIST
self.ignore_list = [ignore.casefold() for ignore in ignore_list]
# Initialize the TFLite interpreter with the model file
- self.interpreter = tf.lite.Interpreter(model_path=tflite_file)
+ self.interpreter = tf.lite.Interpreter(
+ model_path=tflite_file, experimental_preserve_all_tensors=True
+ )
self.interpreter.allocate_tensors()
self.details: dict = {}
@@ -242,7 +247,10 @@ class TFLiteMetrics:
if verbose:
# Print cluster centroids
print("{} cluster centroids:".format(name))
- pprint(np.unique(weights))
+ # Types need to be ignored for this function call because
+ # np.unique does not have type annotation while the
+ # current context does.
+ pprint(np.unique(weights)) # type: ignore
# Add summary/overall values
empty_row = ["" for _ in range(len(header))]
summary_row = empty_row