diff options
Diffstat (limited to 'src/mlia/nn/tensorflow/tflite_metrics.py')
-rw-r--r-- | src/mlia/nn/tensorflow/tflite_metrics.py | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/src/mlia/nn/tensorflow/tflite_metrics.py b/src/mlia/nn/tensorflow/tflite_metrics.py index 9befb2f..0fe36e0 100644 --- a/src/mlia/nn/tensorflow/tflite_metrics.py +++ b/src/mlia/nn/tensorflow/tflite_metrics.py @@ -33,7 +33,10 @@ DEFAULT_IGNORE_LIST = [ def calculate_num_unique_weights(weights: np.ndarray) -> int: """Calculate the number of unique weights in the given weights.""" - num_unique_weights = len(np.unique(weights)) + # Types need to be ignored for this function call because + # np.unique does not have type annotation while the + # current context does. + num_unique_weights = len(np.unique(weights)) # type: ignore return num_unique_weights @@ -114,7 +117,9 @@ class TFLiteMetrics: ignore_list = DEFAULT_IGNORE_LIST self.ignore_list = [ignore.casefold() for ignore in ignore_list] # Initialize the TFLite interpreter with the model file - self.interpreter = tf.lite.Interpreter(model_path=tflite_file) + self.interpreter = tf.lite.Interpreter( + model_path=tflite_file, experimental_preserve_all_tensors=True + ) self.interpreter.allocate_tensors() self.details: dict = {} @@ -242,7 +247,10 @@ class TFLiteMetrics: if verbose: # Print cluster centroids print("{} cluster centroids:".format(name)) - pprint(np.unique(weights)) + # Types need to be ignored for this function call because + # np.unique does not have type annotation while the + # current context does. + pprint(np.unique(weights)) # type: ignore # Add summary/overall values empty_row = ["" for _ in range(len(header))] summary_row = empty_row |