diff options
Diffstat (limited to 'src/mlia/nn/tensorflow')
-rw-r--r-- | src/mlia/nn/tensorflow/optimizations/pruning.py | 4 | ||||
-rw-r--r-- | src/mlia/nn/tensorflow/tflite_metrics.py | 7 |
2 files changed, 8 insertions, 3 deletions
diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py index 15c043d..0a3fda5 100644 --- a/src/mlia/nn/tensorflow/optimizations/pruning.py +++ b/src/mlia/nn/tensorflow/optimizations/pruning.py @@ -7,6 +7,7 @@ In order to do this, we need to have a base model and corresponding training dat We also have to specify a subset of layers we want to prune. For more details, please refer to the documentation for TensorFlow Model Optimization Toolkit. """ +import typing from dataclasses import dataclass from typing import List from typing import Optional @@ -138,6 +139,7 @@ class Pruner(Optimizer): verbose=0, ) + @typing.no_type_check def _assert_sparsity_reached(self) -> None: for layer in self.model.layers: if not isinstance(layer, pruning_wrapper.PruneLowMagnitude): @@ -154,7 +156,7 @@ class Pruner(Optimizer): self.optimizer_configuration.optimization_target, 1 - nonzero_weights / all_weights, significant=2, - ) # type: ignore + ) def _strip_pruning(self) -> None: self.model = tfmot.sparsity.keras.strip_pruning(self.model) diff --git a/src/mlia/nn/tensorflow/tflite_metrics.py b/src/mlia/nn/tensorflow/tflite_metrics.py index 2252c6b..3f41487 100644 --- a/src/mlia/nn/tensorflow/tflite_metrics.py +++ b/src/mlia/nn/tensorflow/tflite_metrics.py @@ -9,6 +9,7 @@ These metrics include: * gzip compression ratio """ import os +import typing from enum import Enum from pprint import pprint from typing import Any @@ -31,12 +32,13 @@ DEFAULT_IGNORE_LIST = [ ] +@typing.no_type_check def calculate_num_unique_weights(weights: np.ndarray) -> int: """Calculate the number of unique weights in the given weights.""" # Types need to be ignored for this function call because # np.unique does not have type annotation while the # current context does. - num_unique_weights = len(np.unique(weights)) # type: ignore + num_unique_weights = len(np.unique(weights)) return num_unique_weights @@ -207,6 +209,7 @@ class TFLiteMetrics: return name.split("/", 1)[1] return name + @typing.no_type_check def summary( self, report_sparsity: bool, @@ -248,7 +251,7 @@ class TFLiteMetrics: # Types need to be ignored for this function call because # np.unique does not have type annotation while the # current context does. - pprint(np.unique(weights)) # type: ignore + pprint(np.unique(weights)) # Add summary/overall values empty_row = ["" for _ in range(len(header))] summary_row = empty_row |