aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/graph_edit/diff.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit/diff.py')
-rw-r--r--src/mlia/nn/rewrite/core/graph_edit/diff.py23
1 files changed, 16 insertions, 7 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/diff.py b/src/mlia/nn/rewrite/core/graph_edit/diff.py
index 198e47e..7fa2a72 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/diff.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/diff.py
@@ -31,6 +31,21 @@ def add_total(name: str, key: str, values: list, totals: dict) -> None:
totals[name][key] += values
+def _handle_zeros_in_denominator(denominator: np.ndarray) -> np.ndarray:
+ """Handle zeros in the denominator in nrmse to avoid dividing by zero(s)."""
+ denominator[denominator == 0.0] = 1.0
+ return denominator
+
+
+def calc_nrmse(rmse: dict, dataset1_var: dict) -> dict:
+ """Divide rmse by target standard deviation."""
+ nrmse = {
+ k: v / _handle_zeros_in_denominator(np.sqrt(dataset1_var[k]))
+ for k, v in rmse.items()
+ }
+ return nrmse
+
+
def diff_stats(
file1: str | Path, file2: str | Path, per_tensor_and_channel: bool = False
) -> tuple:
@@ -80,14 +95,8 @@ def diff_stats(
mse = per_tensor_mean("se")
rmse = {k: np.sqrt(v) for k, v in mse.items()}
dataset1_var = per_tensor_mean("dataset1_variance")
- is_nonzero = {k: dataset1_var[k] > 0 for k in dataset1_var}
- # Divide by target standard deviation to get the per-channel nrmse for each
- # tensor where possible
- nrmse = {
- k: v[is_nonzero[k]] / np.sqrt(dataset1_var[k][is_nonzero[k]])
- for k, v in rmse.items()
- }
+ nrmse = calc_nrmse(rmse, dataset1_var)
if per_tensor_and_channel:
return mae, nrmse