diff options
author | Ruomei Yan <ruomei.yan@arm.com> | 2023-04-20 09:51:20 +0100 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-10-11 15:44:51 +0100 |
commit | f3e6597dd50ec70f043d692b773f2d9fd31519ae (patch) | |
tree | 322ccb75e0cc594c57308288cae333a72401979e /tests/test_nn_rewrite_core_graph_edit_join.py | |
parent | 867f37d643e66c0223457c28f5345f2f21db97f2 (diff) | |
download | mlia-f3e6597dd50ec70f043d692b773f2d9fd31519ae.tar.gz |
Implement first rewrite (proof of concept)
* Define replacement function fully_connected layer
* Define RewriteConfiguration and Rewriter to integrate
rewrite module into mlia optimize command
* Fix a bug in the ethos_u/data_collection.py file
* Fix a bug in join.py
* Remove diff_stats and use diff instead, added related
changes around this to ensure e2e tests passing
* Add unit tests for all changes
* Fix bug in diff_stats function
* The bug was caused by a dividing by numpy array
of all zeros. The previous way of handling it
did not consider the all zeros case but only
dealt with partially zeros
* unit tests added.
* Fix the bug in rewrite/core/graph_edit/join.py
* Remove the possibility of passing None to append_relabel
function because it is immutable
* The bug happened when empty dictionary was passed in the
append_relabel function and the function overwrites the
reference of operator_map which caused the dictionary
was not updated after the function call
Resolves: MLIA-749, MLIA-864, MLIA-866
Change-Id: I1ab426996232f182345e6e98033d5dcb32aea08c
Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Diffstat (limited to 'tests/test_nn_rewrite_core_graph_edit_join.py')
-rw-r--r-- | tests/test_nn_rewrite_core_graph_edit_join.py | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/tests/test_nn_rewrite_core_graph_edit_join.py b/tests/test_nn_rewrite_core_graph_edit_join.py index cbbbeba..cb3e4e2 100644 --- a/tests/test_nn_rewrite_core_graph_edit_join.py +++ b/tests/test_nn_rewrite_core_graph_edit_join.py @@ -1,9 +1,14 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module mlia.nn.rewrite.graph_edit.join.""" +from contextlib import ExitStack as does_not_raise from pathlib import Path +from typing import Any + +import pytest from mlia.nn.rewrite.core.graph_edit.cut import cut_model +from mlia.nn.rewrite.core.graph_edit.join import append_relabel from mlia.nn.rewrite.core.graph_edit.join import join_models from mlia.nn.rewrite.core.utils.utils import load from tests.utils.rewrite import models_are_equal @@ -48,3 +53,23 @@ def test_join_model(test_tflite_model: Path, tmp_path: Path) -> None: joined_model = load(str(joined_file)) assert models_are_equal(orig_model, joined_model) + + +@pytest.mark.parametrize( + "src, dst, op_map, expected_error", + [ + ([1, 2, 3], [4, 5, 6], {}, does_not_raise()), + ( + [1, 2, 3], + [4, 5, 6], + None, + pytest.raises(Exception, match="The input operator map cannot be None!"), + ), + ], +) +def test_append_relabel( + src: list, dst: list, op_map: dict, expected_error: Any +) -> None: + """Test passing by reference of the object in function append_relabel.""" + with expected_error: + append_relabel(src, dst, op_map) |