diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit/join.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/join.py | 14 |
1 files changed, 7 insertions, 7 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/join.py b/src/mlia/nn/rewrite/core/graph_edit/join.py index 14a7347..2530ec8 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/join.py +++ b/src/mlia/nn/rewrite/core/graph_edit/join.py @@ -22,8 +22,8 @@ def join_models( input_src: str | Path, input_dst: str | Path, output_file: str | Path, - subgraph_src: SubGraphT = 0, - subgraph_dst: SubGraphT = 0, + subgraph_src: int = 0, + subgraph_dst: int = 0, ) -> None: """Join two models and save the result into a given model file path.""" src_model = load(input_src) @@ -150,12 +150,12 @@ def join_subgraphs( dst_subgraph.outputs = list(set(src_subgraph.outputs).union(dst_subgraph.outputs)) -def append_relabel(src: list, dst: list, operator_map: dict | None = None) -> dict: - """Return a map over relabeled tensors in a subgraph.""" - if not operator_map: - operator_map = {} +def append_relabel(src: list, dst: list, operator_map: dict) -> None: + """Update the operator map over relabeled tensors in a subgraph.""" + if operator_map is None: + raise ValueError("The input operator map cannot be None!") + for i, x in enumerate(src): # pylint: disable=invalid-name if i not in operator_map: operator_map[i] = len(dst) dst.append(x) - return operator_map |