diff options
-rw-r--r-- | setup.cfg | 2 | ||||
-rw-r--r-- | src/mlia/nn/tensorflow/optimizations/pruning.py | 5 | ||||
-rw-r--r-- | src/mlia/nn/tensorflow/tflite_metrics.py | 14 |
3 files changed, 16 insertions, 5 deletions
@@ -29,7 +29,7 @@ package_dir = = src packages = find: install_requires = - tensorflow~=2.7.1 + tensorflow~=2.8.2 tensorflow-model-optimization~=0.7.2 ethos-u-vela~=3.4.0 requests diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py index f1e2976..15c043d 100644 --- a/src/mlia/nn/tensorflow/optimizations/pruning.py +++ b/src/mlia/nn/tensorflow/optimizations/pruning.py @@ -147,11 +147,14 @@ class Pruner(Optimizer): nonzero_weights = np.count_nonzero(tf.keras.backend.get_value(weight)) all_weights = tf.keras.backend.get_value(weight).size + # Types need to be ignored for this function call because + # np.testing.assert_approx_equal does not have type annotation while the + # current context does. np.testing.assert_approx_equal( self.optimizer_configuration.optimization_target, 1 - nonzero_weights / all_weights, significant=2, - ) + ) # type: ignore def _strip_pruning(self) -> None: self.model = tfmot.sparsity.keras.strip_pruning(self.model) diff --git a/src/mlia/nn/tensorflow/tflite_metrics.py b/src/mlia/nn/tensorflow/tflite_metrics.py index 9befb2f..0fe36e0 100644 --- a/src/mlia/nn/tensorflow/tflite_metrics.py +++ b/src/mlia/nn/tensorflow/tflite_metrics.py @@ -33,7 +33,10 @@ DEFAULT_IGNORE_LIST = [ def calculate_num_unique_weights(weights: np.ndarray) -> int: """Calculate the number of unique weights in the given weights.""" - num_unique_weights = len(np.unique(weights)) + # Types need to be ignored for this function call because + # np.unique does not have type annotation while the + # current context does. + num_unique_weights = len(np.unique(weights)) # type: ignore return num_unique_weights @@ -114,7 +117,9 @@ class TFLiteMetrics: ignore_list = DEFAULT_IGNORE_LIST self.ignore_list = [ignore.casefold() for ignore in ignore_list] # Initialize the TFLite interpreter with the model file - self.interpreter = tf.lite.Interpreter(model_path=tflite_file) + self.interpreter = tf.lite.Interpreter( + model_path=tflite_file, experimental_preserve_all_tensors=True + ) self.interpreter.allocate_tensors() self.details: dict = {} @@ -242,7 +247,10 @@ class TFLiteMetrics: if verbose: # Print cluster centroids print("{} cluster centroids:".format(name)) - pprint(np.unique(weights)) + # Types need to be ignored for this function call because + # np.unique does not have type annotation while the + # current context does. + pprint(np.unique(weights)) # type: ignore # Add summary/overall values empty_row = ["" for _ in range(len(header))] summary_row = empty_row |