aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/graph_edit/join.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit/join.py')
-rw-r--r--src/mlia/nn/rewrite/core/graph_edit/join.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/join.py b/src/mlia/nn/rewrite/core/graph_edit/join.py
index 14a7347..2530ec8 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/join.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/join.py
@@ -22,8 +22,8 @@ def join_models(
input_src: str | Path,
input_dst: str | Path,
output_file: str | Path,
- subgraph_src: SubGraphT = 0,
- subgraph_dst: SubGraphT = 0,
+ subgraph_src: int = 0,
+ subgraph_dst: int = 0,
) -> None:
"""Join two models and save the result into a given model file path."""
src_model = load(input_src)
@@ -150,12 +150,12 @@ def join_subgraphs(
dst_subgraph.outputs = list(set(src_subgraph.outputs).union(dst_subgraph.outputs))
-def append_relabel(src: list, dst: list, operator_map: dict | None = None) -> dict:
- """Return a map over relabeled tensors in a subgraph."""
- if not operator_map:
- operator_map = {}
+def append_relabel(src: list, dst: list, operator_map: dict) -> None:
+ """Update the operator map over relabeled tensors in a subgraph."""
+ if operator_map is None:
+ raise ValueError("The input operator map cannot be None!")
+
for i, x in enumerate(src): # pylint: disable=invalid-name
if i not in operator_map:
operator_map[i] = len(dst)
dst.append(x)
- return operator_map