diff options
Diffstat (limited to 'tests/test_nn_rewrite_core_graph_edit_diff.py')
-rw-r--r-- | tests/test_nn_rewrite_core_graph_edit_diff.py | 45 |
1 files changed, 45 insertions, 0 deletions
diff --git a/tests/test_nn_rewrite_core_graph_edit_diff.py b/tests/test_nn_rewrite_core_graph_edit_diff.py new file mode 100644 index 0000000..fdda04f --- /dev/null +++ b/tests/test_nn_rewrite_core_graph_edit_diff.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module mlia.nn.rewrite.graph_edit.join.""" +from pathlib import Path + +import numpy as np +import pytest + +from mlia.nn.rewrite.core.graph_edit.diff import calc_nrmse +from mlia.nn.rewrite.core.graph_edit.diff import diff_stats + + +def assert_two_dictionaries_with_numpy_arrays(dict1: dict, dict2: dict) -> None: + """Use numpy assertions to check whether two dictionaries are equal.""" + key1, val1 = list(dict1.keys()), list(dict1.values()) + key2, val2 = list(dict2.keys()), list(dict2.values()) + np.testing.assert_equal(key1, key2) + np.testing.assert_almost_equal(val1, val2) + + +@pytest.mark.parametrize( + "rmse, scale, expected_result", + [ + ( + {"test1": np.ndarray((3,), buffer=np.array([1.0, 2.0, 3.3]))}, + {"test1": np.ndarray((3,), buffer=np.array([4.0, 4.0, 0.0]))}, + {"test1": np.ndarray((3,), buffer=np.array([0.5, 1.0, 3.3]))}, + ), + ( + {"test1": np.ndarray((3,), buffer=np.array([1.0, 2.0, 3.3]))}, + {"test1": np.ndarray((3,), buffer=np.array([0.0, 0.0, 0.0]))}, + {"test1": np.ndarray((3,), buffer=np.array([1.0, 2.0, 3.3]))}, + ), + ], +) +def test_calc_nrmse(rmse: dict, scale: dict, expected_result: dict) -> None: + """Test calc_nrmse() function.""" + assert_two_dictionaries_with_numpy_arrays(calc_nrmse(rmse, scale), expected_result) + + +def test_diff_stats(test_tfrecord_fp32: Path) -> None: + """Test diff_stats() function.""" + res1, res2 = diff_stats(test_tfrecord_fp32, test_tfrecord_fp32) + assert res1 == 0.0 + assert res2 == 0.0 |