aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/graph_edit
diff options
context:
space:
mode:
authorRuomei Yan <ruomei.yan@arm.com>2023-04-20 09:51:20 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-10-11 15:44:51 +0100
commitf3e6597dd50ec70f043d692b773f2d9fd31519ae (patch)
tree322ccb75e0cc594c57308288cae333a72401979e /src/mlia/nn/rewrite/core/graph_edit
parent867f37d643e66c0223457c28f5345f2f21db97f2 (diff)
downloadmlia-f3e6597dd50ec70f043d692b773f2d9fd31519ae.tar.gz
Implement first rewrite (proof of concept)
* Define replacement function fully_connected layer * Define RewriteConfiguration and Rewriter to integrate rewrite module into mlia optimize command * Fix a bug in the ethos_u/data_collection.py file * Fix a bug in join.py * Remove diff_stats and use diff instead, added related changes around this to ensure e2e tests passing * Add unit tests for all changes * Fix bug in diff_stats function * The bug was caused by a dividing by numpy array of all zeros. The previous way of handling it did not consider the all zeros case but only dealt with partially zeros * unit tests added. * Fix the bug in rewrite/core/graph_edit/join.py * Remove the possibility of passing None to append_relabel function because it is immutable * The bug happened when empty dictionary was passed in the append_relabel function and the function overwrites the reference of operator_map which caused the dictionary was not updated after the function call Resolves: MLIA-749, MLIA-864, MLIA-866 Change-Id: I1ab426996232f182345e6e98033d5dcb32aea08c Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit')
-rw-r--r--src/mlia/nn/rewrite/core/graph_edit/diff.py23
-rw-r--r--src/mlia/nn/rewrite/core/graph_edit/join.py14
2 files changed, 23 insertions, 14 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/diff.py b/src/mlia/nn/rewrite/core/graph_edit/diff.py
index 198e47e..7fa2a72 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/diff.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/diff.py
@@ -31,6 +31,21 @@ def add_total(name: str, key: str, values: list, totals: dict) -> None:
totals[name][key] += values
+def _handle_zeros_in_denominator(denominator: np.ndarray) -> np.ndarray:
+ """Handle zeros in the denominator in nrmse to avoid dividing by zero(s)."""
+ denominator[denominator == 0.0] = 1.0
+ return denominator
+
+
+def calc_nrmse(rmse: dict, dataset1_var: dict) -> dict:
+ """Divide rmse by target standard deviation."""
+ nrmse = {
+ k: v / _handle_zeros_in_denominator(np.sqrt(dataset1_var[k]))
+ for k, v in rmse.items()
+ }
+ return nrmse
+
+
def diff_stats(
file1: str | Path, file2: str | Path, per_tensor_and_channel: bool = False
) -> tuple:
@@ -80,14 +95,8 @@ def diff_stats(
mse = per_tensor_mean("se")
rmse = {k: np.sqrt(v) for k, v in mse.items()}
dataset1_var = per_tensor_mean("dataset1_variance")
- is_nonzero = {k: dataset1_var[k] > 0 for k in dataset1_var}
- # Divide by target standard deviation to get the per-channel nrmse for each
- # tensor where possible
- nrmse = {
- k: v[is_nonzero[k]] / np.sqrt(dataset1_var[k][is_nonzero[k]])
- for k, v in rmse.items()
- }
+ nrmse = calc_nrmse(rmse, dataset1_var)
if per_tensor_and_channel:
return mae, nrmse
diff --git a/src/mlia/nn/rewrite/core/graph_edit/join.py b/src/mlia/nn/rewrite/core/graph_edit/join.py
index 14a7347..2530ec8 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/join.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/join.py
@@ -22,8 +22,8 @@ def join_models(
input_src: str | Path,
input_dst: str | Path,
output_file: str | Path,
- subgraph_src: SubGraphT = 0,
- subgraph_dst: SubGraphT = 0,
+ subgraph_src: int = 0,
+ subgraph_dst: int = 0,
) -> None:
"""Join two models and save the result into a given model file path."""
src_model = load(input_src)
@@ -150,12 +150,12 @@ def join_subgraphs(
dst_subgraph.outputs = list(set(src_subgraph.outputs).union(dst_subgraph.outputs))
-def append_relabel(src: list, dst: list, operator_map: dict | None = None) -> dict:
- """Return a map over relabeled tensors in a subgraph."""
- if not operator_map:
- operator_map = {}
+def append_relabel(src: list, dst: list, operator_map: dict) -> None:
+ """Update the operator map over relabeled tensors in a subgraph."""
+ if operator_map is None:
+ raise ValueError("The input operator map cannot be None!")
+
for i, x in enumerate(src): # pylint: disable=invalid-name
if i not in operator_map:
operator_map[i] = len(dst)
dst.append(x)
- return operator_map