aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/tflite_metrics.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/tensorflow/tflite_metrics.py')
-rw-r--r--src/mlia/nn/tensorflow/tflite_metrics.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/src/mlia/nn/tensorflow/tflite_metrics.py b/src/mlia/nn/tensorflow/tflite_metrics.py
index 2252c6b..3f41487 100644
--- a/src/mlia/nn/tensorflow/tflite_metrics.py
+++ b/src/mlia/nn/tensorflow/tflite_metrics.py
@@ -9,6 +9,7 @@ These metrics include:
* gzip compression ratio
"""
import os
+import typing
from enum import Enum
from pprint import pprint
from typing import Any
@@ -31,12 +32,13 @@ DEFAULT_IGNORE_LIST = [
]
+@typing.no_type_check
def calculate_num_unique_weights(weights: np.ndarray) -> int:
"""Calculate the number of unique weights in the given weights."""
# Types need to be ignored for this function call because
# np.unique does not have type annotation while the
# current context does.
- num_unique_weights = len(np.unique(weights)) # type: ignore
+ num_unique_weights = len(np.unique(weights))
return num_unique_weights
@@ -207,6 +209,7 @@ class TFLiteMetrics:
return name.split("/", 1)[1]
return name
+ @typing.no_type_check
def summary(
self,
report_sparsity: bool,
@@ -248,7 +251,7 @@ class TFLiteMetrics:
# Types need to be ignored for this function call because
# np.unique does not have type annotation while the
# current context does.
- pprint(np.unique(weights)) # type: ignore
+ pprint(np.unique(weights))
# Add summary/overall values
empty_row = ["" for _ in range(len(header))]
summary_row = empty_row