aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/graph_edit/diff.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit/diff.py')
-rw-r--r--src/mlia/nn/rewrite/core/graph_edit/diff.py102
1 files changed, 59 insertions, 43 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/diff.py b/src/mlia/nn/rewrite/core/graph_edit/diff.py
index 0829f0a..198e47e 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/diff.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/diff.py
@@ -1,63 +1,79 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Diff module: compare subgraph outputs."""
+# pylint: disable=too-many-locals
+from __future__ import annotations
+
import os
from collections import defaultdict
+from pathlib import Path
+from typing import Any
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import numpy as np
import tensorflow as tf
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-import numpy as np
-from tensorflow.lite.python import interpreter as interpreter_wrapper
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFReader, NumpyTFWriter
+def dict_mean(mean_dict: dict) -> Any:
+ """Return the mean of values in a given dict."""
+ return np.mean(list(mean_dict.values()))
-def diff_stats(file1, file2, per_tensor_and_channel=False):
- dataset1 = NumpyTFReader(file1)
- dataset2 = NumpyTFReader(file2)
- totals = defaultdict(dict)
+def add_total(name: str, key: str, values: list, totals: dict) -> None:
+ """Append values to dict totals."""
+ if key not in totals[name]:
+ totals[name][key] = values
+ else:
+ totals[name][key] += values
+
- def add_total(name, key, values):
- if not key in totals[name]:
- totals[name][key] = values
- else:
- totals[name][key] += values
+def diff_stats(
+ file1: str | Path, file2: str | Path, per_tensor_and_channel: bool = False
+) -> tuple:
+ """Compare the statistics of outputs between two subgraphs."""
+ dataset1 = numpytf_read(file1)
+ dataset2 = numpytf_read(file2)
- # First iterate through dataset1 and calculate per-channel total for each tensor
+ totals: dict = defaultdict(dict)
+
+ # First iterate through dataset and calculate per-channel total for each tensor
count = 0
- for d in dataset1:
+ for data in dataset1:
count += 1
- for k, v in d.items():
- value = v.numpy().astype(np.double)
- add_total("dataset1_total", k, value)
+ for key, val in data.items():
+ value = val.numpy().astype(np.double)
+ add_total("dataset1_total", key, value, totals)
# Use this to calculate per-channel mean for each tensor
- per_tensor_mean = lambda name: {
- k: total / count for k, total in totals[name].items()
- }
+ def per_tensor_mean(name: str) -> dict:
+ return {k: total / count for k, total in totals[name].items()}
+
dataset1_mean = per_tensor_mean("dataset1_total")
- # Next iterate through both datasets and calculate per-channel total squared error
- # between them for each tensor and dataset1 variance for each tensor using the mean from above
- for i, (x1, x2) in enumerate(zip(dataset1, dataset2)):
- assert x1.keys() == x2.keys(), (
- "At input %d the files have different sets of tensors.\n%s: %s\n%s: %s\n"
- % (
- i,
- file1,
- ", ".join(x1.keys()),
- file2,
- ", ".join(x2.keys()),
- )
+ # Next iterate through both datasets and calculate per-channel total squared
+ # error between them for each tensor and dataset1 variance for each tensor
+ # using the mean from above
+ for i, (ds1, ds2) in enumerate(zip(dataset1, dataset2)):
+ assert ds1.keys() == ds2.keys(), (
+ f"At input {i} the files have different sets of tensors.\n"
+ f"{file1}: {', '.join(ds1.keys())}\n"
+ f"{file2}: {', '.join(ds2.keys())}\n"
)
- for k in x1.keys():
- v1 = x1[k].numpy().astype(np.double)
- v2 = x2[k].numpy().astype(np.double)
- add_total("ae", k, abs(v1 - v2))
- add_total("se", k, (v1 - v2) ** 2)
- add_total("dataset1_variance", k, (v1 - dataset1_mean[k]) ** 2)
+ for key in ds1.keys():
+ tensor1 = ds1[key].numpy().astype(np.double)
+ tensor2 = ds2[key].numpy().astype(np.double)
+ add_total("ae", key, abs(tensor1 - tensor2), totals)
+ add_total("se", key, (tensor1 - tensor2) ** 2, totals)
+ add_total(
+ "dataset1_variance",
+ key,
+ (tensor1 - dataset1_mean[key]) ** 2,
+ totals,
+ )
# Finally average over number of inputs to get the rmse and the dataset1 variance
mae = per_tensor_mean("ae")
@@ -66,7 +82,8 @@ def diff_stats(file1, file2, per_tensor_and_channel=False):
dataset1_var = per_tensor_mean("dataset1_variance")
is_nonzero = {k: dataset1_var[k] > 0 for k in dataset1_var}
- # Divide by target standard deviation to get the per-channel nrmse for each tensor where possible
+ # Divide by target standard deviation to get the per-channel nrmse for each
+ # tensor where possible
nrmse = {
k: v[is_nonzero[k]] / np.sqrt(dataset1_var[k][is_nonzero[k]])
for k, v in rmse.items()
@@ -74,6 +91,5 @@ def diff_stats(file1, file2, per_tensor_and_channel=False):
if per_tensor_and_channel:
return mae, nrmse
- else:
- dict_mean = lambda d: np.mean(list(d.values()))
- return dict_mean(mae), dict_mean(nrmse)
+
+ return dict_mean(mae), dict_mean(nrmse)