diff options
Diffstat (limited to 'src/mlia/nn/tensorflow/tflite_metrics.py')
-rw-r--r-- | src/mlia/nn/tensorflow/tflite_metrics.py | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/src/mlia/nn/tensorflow/tflite_metrics.py b/src/mlia/nn/tensorflow/tflite_metrics.py index 2252c6b..3f41487 100644 --- a/src/mlia/nn/tensorflow/tflite_metrics.py +++ b/src/mlia/nn/tensorflow/tflite_metrics.py @@ -9,6 +9,7 @@ These metrics include: * gzip compression ratio """ import os +import typing from enum import Enum from pprint import pprint from typing import Any @@ -31,12 +32,13 @@ DEFAULT_IGNORE_LIST = [ ] +@typing.no_type_check def calculate_num_unique_weights(weights: np.ndarray) -> int: """Calculate the number of unique weights in the given weights.""" # Types need to be ignored for this function call because # np.unique does not have type annotation while the # current context does. - num_unique_weights = len(np.unique(weights)) # type: ignore + num_unique_weights = len(np.unique(weights)) return num_unique_weights @@ -207,6 +209,7 @@ class TFLiteMetrics: return name.split("/", 1)[1] return name + @typing.no_type_check def summary( self, report_sparsity: bool, @@ -248,7 +251,7 @@ class TFLiteMetrics: # Types need to be ignored for this function call because # np.unique does not have type annotation while the # current context does. - pprint(np.unique(weights)) # type: ignore + pprint(np.unique(weights)) # Add summary/overall values empty_row = ["" for _ in range(len(header))] summary_row = empty_row |