diff options
Diffstat (limited to 'tests/test_nn_rewrite_core_graph_edit_join.py')
-rw-r--r-- | tests/test_nn_rewrite_core_graph_edit_join.py | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/tests/test_nn_rewrite_core_graph_edit_join.py b/tests/test_nn_rewrite_core_graph_edit_join.py index cbbbeba..cb3e4e2 100644 --- a/tests/test_nn_rewrite_core_graph_edit_join.py +++ b/tests/test_nn_rewrite_core_graph_edit_join.py @@ -1,9 +1,14 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module mlia.nn.rewrite.graph_edit.join.""" +from contextlib import ExitStack as does_not_raise from pathlib import Path +from typing import Any + +import pytest from mlia.nn.rewrite.core.graph_edit.cut import cut_model +from mlia.nn.rewrite.core.graph_edit.join import append_relabel from mlia.nn.rewrite.core.graph_edit.join import join_models from mlia.nn.rewrite.core.utils.utils import load from tests.utils.rewrite import models_are_equal @@ -48,3 +53,23 @@ def test_join_model(test_tflite_model: Path, tmp_path: Path) -> None: joined_model = load(str(joined_file)) assert models_are_equal(orig_model, joined_model) + + +@pytest.mark.parametrize( + "src, dst, op_map, expected_error", + [ + ([1, 2, 3], [4, 5, 6], {}, does_not_raise()), + ( + [1, 2, 3], + [4, 5, 6], + None, + pytest.raises(Exception, match="The input operator map cannot be None!"), + ), + ], +) +def test_append_relabel( + src: list, dst: list, op_map: dict, expected_error: Any +) -> None: + """Test passing by reference of the object in function append_relabel.""" + with expected_error: + append_relabel(src, dst, op_map) |