diff options
author | Ruomei Yan <ruomei.yan@arm.com> | 2023-04-20 09:51:20 +0100 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-10-11 15:44:51 +0100 |
commit | f3e6597dd50ec70f043d692b773f2d9fd31519ae (patch) | |
tree | 322ccb75e0cc594c57308288cae333a72401979e /src/mlia/nn/rewrite/core/graph_edit/diff.py | |
parent | 867f37d643e66c0223457c28f5345f2f21db97f2 (diff) | |
download | mlia-f3e6597dd50ec70f043d692b773f2d9fd31519ae.tar.gz |
Implement first rewrite (proof of concept)
* Define replacement function fully_connected layer
* Define RewriteConfiguration and Rewriter to integrate
rewrite module into mlia optimize command
* Fix a bug in the ethos_u/data_collection.py file
* Fix a bug in join.py
* Remove diff_stats and use diff instead, added related
changes around this to ensure e2e tests passing
* Add unit tests for all changes
* Fix bug in diff_stats function
* The bug was caused by a dividing by numpy array
of all zeros. The previous way of handling it
did not consider the all zeros case but only
dealt with partially zeros
* unit tests added.
* Fix the bug in rewrite/core/graph_edit/join.py
* Remove the possibility of passing None to append_relabel
function because it is immutable
* The bug happened when empty dictionary was passed in the
append_relabel function and the function overwrites the
reference of operator_map which caused the dictionary
was not updated after the function call
Resolves: MLIA-749, MLIA-864, MLIA-866
Change-Id: I1ab426996232f182345e6e98033d5dcb32aea08c
Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit/diff.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/diff.py | 23 |
1 files changed, 16 insertions, 7 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/diff.py b/src/mlia/nn/rewrite/core/graph_edit/diff.py index 198e47e..7fa2a72 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/diff.py +++ b/src/mlia/nn/rewrite/core/graph_edit/diff.py @@ -31,6 +31,21 @@ def add_total(name: str, key: str, values: list, totals: dict) -> None: totals[name][key] += values +def _handle_zeros_in_denominator(denominator: np.ndarray) -> np.ndarray: + """Handle zeros in the denominator in nrmse to avoid dividing by zero(s).""" + denominator[denominator == 0.0] = 1.0 + return denominator + + +def calc_nrmse(rmse: dict, dataset1_var: dict) -> dict: + """Divide rmse by target standard deviation.""" + nrmse = { + k: v / _handle_zeros_in_denominator(np.sqrt(dataset1_var[k])) + for k, v in rmse.items() + } + return nrmse + + def diff_stats( file1: str | Path, file2: str | Path, per_tensor_and_channel: bool = False ) -> tuple: @@ -80,14 +95,8 @@ def diff_stats( mse = per_tensor_mean("se") rmse = {k: np.sqrt(v) for k, v in mse.items()} dataset1_var = per_tensor_mean("dataset1_variance") - is_nonzero = {k: dataset1_var[k] > 0 for k in dataset1_var} - # Divide by target standard deviation to get the per-channel nrmse for each - # tensor where possible - nrmse = { - k: v[is_nonzero[k]] / np.sqrt(dataset1_var[k][is_nonzero[k]]) - for k, v in rmse.items() - } + nrmse = calc_nrmse(rmse, dataset1_var) if per_tensor_and_channel: return mae, nrmse |