aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn')
-rw-r--r--src/mlia/nn/tensorflow/optimizations/pruning.py4
-rw-r--r--src/mlia/nn/tensorflow/tflite_metrics.py7
2 files changed, 8 insertions, 3 deletions
diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py
index 15c043d..0a3fda5 100644
--- a/src/mlia/nn/tensorflow/optimizations/pruning.py
+++ b/src/mlia/nn/tensorflow/optimizations/pruning.py
@@ -7,6 +7,7 @@ In order to do this, we need to have a base model and corresponding training dat
We also have to specify a subset of layers we want to prune. For more details,
please refer to the documentation for TensorFlow Model Optimization Toolkit.
"""
+import typing
from dataclasses import dataclass
from typing import List
from typing import Optional
@@ -138,6 +139,7 @@ class Pruner(Optimizer):
verbose=0,
)
+ @typing.no_type_check
def _assert_sparsity_reached(self) -> None:
for layer in self.model.layers:
if not isinstance(layer, pruning_wrapper.PruneLowMagnitude):
@@ -154,7 +156,7 @@ class Pruner(Optimizer):
self.optimizer_configuration.optimization_target,
1 - nonzero_weights / all_weights,
significant=2,
- ) # type: ignore
+ )
def _strip_pruning(self) -> None:
self.model = tfmot.sparsity.keras.strip_pruning(self.model)
diff --git a/src/mlia/nn/tensorflow/tflite_metrics.py b/src/mlia/nn/tensorflow/tflite_metrics.py
index 2252c6b..3f41487 100644
--- a/src/mlia/nn/tensorflow/tflite_metrics.py
+++ b/src/mlia/nn/tensorflow/tflite_metrics.py
@@ -9,6 +9,7 @@ These metrics include:
* gzip compression ratio
"""
import os
+import typing
from enum import Enum
from pprint import pprint
from typing import Any
@@ -31,12 +32,13 @@ DEFAULT_IGNORE_LIST = [
]
+@typing.no_type_check
def calculate_num_unique_weights(weights: np.ndarray) -> int:
"""Calculate the number of unique weights in the given weights."""
# Types need to be ignored for this function call because
# np.unique does not have type annotation while the
# current context does.
- num_unique_weights = len(np.unique(weights)) # type: ignore
+ num_unique_weights = len(np.unique(weights))
return num_unique_weights
@@ -207,6 +209,7 @@ class TFLiteMetrics:
return name.split("/", 1)[1]
return name
+ @typing.no_type_check
def summary(
self,
report_sparsity: bool,
@@ -248,7 +251,7 @@ class TFLiteMetrics:
# Types need to be ignored for this function call because
# np.unique does not have type annotation while the
# current context does.
- pprint(np.unique(weights)) # type: ignore
+ pprint(np.unique(weights))
# Add summary/overall values
empty_row = ["" for _ in range(len(header))]
summary_row = empty_row