aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_graph_edit_join.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_rewrite_core_graph_edit_join.py')
-rw-r--r--tests/test_nn_rewrite_core_graph_edit_join.py25
1 files changed, 25 insertions, 0 deletions
diff --git a/tests/test_nn_rewrite_core_graph_edit_join.py b/tests/test_nn_rewrite_core_graph_edit_join.py
index cbbbeba..cb3e4e2 100644
--- a/tests/test_nn_rewrite_core_graph_edit_join.py
+++ b/tests/test_nn_rewrite_core_graph_edit_join.py
@@ -1,9 +1,14 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module mlia.nn.rewrite.graph_edit.join."""
+from contextlib import ExitStack as does_not_raise
from pathlib import Path
+from typing import Any
+
+import pytest
from mlia.nn.rewrite.core.graph_edit.cut import cut_model
+from mlia.nn.rewrite.core.graph_edit.join import append_relabel
from mlia.nn.rewrite.core.graph_edit.join import join_models
from mlia.nn.rewrite.core.utils.utils import load
from tests.utils.rewrite import models_are_equal
@@ -48,3 +53,23 @@ def test_join_model(test_tflite_model: Path, tmp_path: Path) -> None:
joined_model = load(str(joined_file))
assert models_are_equal(orig_model, joined_model)
+
+
+@pytest.mark.parametrize(
+ "src, dst, op_map, expected_error",
+ [
+ ([1, 2, 3], [4, 5, 6], {}, does_not_raise()),
+ (
+ [1, 2, 3],
+ [4, 5, 6],
+ None,
+ pytest.raises(Exception, match="The input operator map cannot be None!"),
+ ),
+ ],
+)
+def test_append_relabel(
+ src: list, dst: list, op_map: dict, expected_error: Any
+) -> None:
+ """Test passing by reference of the object in function append_relabel."""
+ with expected_error:
+ append_relabel(src, dst, op_map)