aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/graph_edit/diff.py
blob: 7fa2a72633f4d426f60588cf7696d8b88038b215 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Diff module: compare subgraph outputs."""
# pylint: disable=too-many-locals
from __future__ import annotations

import os
from collections import defaultdict
from pathlib import Path
from typing import Any

import numpy as np
import tensorflow as tf

from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)


def dict_mean(mean_dict: dict) -> Any:
    """Return the mean of values in a given dict."""
    return np.mean(list(mean_dict.values()))


def add_total(name: str, key: str, values: list, totals: dict) -> None:
    """Append values to dict totals."""
    if key not in totals[name]:
        totals[name][key] = values
    else:
        totals[name][key] += values


def _handle_zeros_in_denominator(denominator: np.ndarray) -> np.ndarray:
    """Handle zeros in the denominator in nrmse to avoid dividing by zero(s)."""
    denominator[denominator == 0.0] = 1.0
    return denominator


def calc_nrmse(rmse: dict, dataset1_var: dict) -> dict:
    """Divide rmse by target standard deviation."""
    nrmse = {
        k: v / _handle_zeros_in_denominator(np.sqrt(dataset1_var[k]))
        for k, v in rmse.items()
    }
    return nrmse


def diff_stats(
    file1: str | Path, file2: str | Path, per_tensor_and_channel: bool = False
) -> tuple:
    """Compare the statistics of outputs between two subgraphs."""
    dataset1 = numpytf_read(file1)
    dataset2 = numpytf_read(file2)

    totals: dict = defaultdict(dict)

    # First iterate through dataset and calculate per-channel total for each tensor
    count = 0
    for data in dataset1:
        count += 1
        for key, val in data.items():
            value = val.numpy().astype(np.double)
            add_total("dataset1_total", key, value, totals)

    # Use this to calculate per-channel mean for each tensor
    def per_tensor_mean(name: str) -> dict:
        return {k: total / count for k, total in totals[name].items()}

    dataset1_mean = per_tensor_mean("dataset1_total")

    # Next iterate through both datasets and calculate per-channel total squared
    # error between them for each tensor and dataset1 variance for each tensor
    # using the mean from above
    for i, (ds1, ds2) in enumerate(zip(dataset1, dataset2)):
        assert ds1.keys() == ds2.keys(), (
            f"At input {i} the files have different sets of tensors.\n"
            f"{file1}: {', '.join(ds1.keys())}\n"
            f"{file2}: {', '.join(ds2.keys())}\n"
        )
        for key in ds1.keys():
            tensor1 = ds1[key].numpy().astype(np.double)
            tensor2 = ds2[key].numpy().astype(np.double)
            add_total("ae", key, abs(tensor1 - tensor2), totals)
            add_total("se", key, (tensor1 - tensor2) ** 2, totals)
            add_total(
                "dataset1_variance",
                key,
                (tensor1 - dataset1_mean[key]) ** 2,
                totals,
            )

    # Finally average over number of inputs to get the rmse and the dataset1 variance
    mae = per_tensor_mean("ae")
    mse = per_tensor_mean("se")
    rmse = {k: np.sqrt(v) for k, v in mse.items()}
    dataset1_var = per_tensor_mean("dataset1_variance")

    nrmse = calc_nrmse(rmse, dataset1_var)

    if per_tensor_and_channel:
        return mae, nrmse

    return dict_mean(mae), dict_mean(nrmse)