aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_graph_edit_diff.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_rewrite_core_graph_edit_diff.py')
-rw-r--r--tests/test_nn_rewrite_core_graph_edit_diff.py45
1 files changed, 45 insertions, 0 deletions
diff --git a/tests/test_nn_rewrite_core_graph_edit_diff.py b/tests/test_nn_rewrite_core_graph_edit_diff.py
new file mode 100644
index 0000000..fdda04f
--- /dev/null
+++ b/tests/test_nn_rewrite_core_graph_edit_diff.py
@@ -0,0 +1,45 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module mlia.nn.rewrite.graph_edit.join."""
+from pathlib import Path
+
+import numpy as np
+import pytest
+
+from mlia.nn.rewrite.core.graph_edit.diff import calc_nrmse
+from mlia.nn.rewrite.core.graph_edit.diff import diff_stats
+
+
+def assert_two_dictionaries_with_numpy_arrays(dict1: dict, dict2: dict) -> None:
+ """Use numpy assertions to check whether two dictionaries are equal."""
+ key1, val1 = list(dict1.keys()), list(dict1.values())
+ key2, val2 = list(dict2.keys()), list(dict2.values())
+ np.testing.assert_equal(key1, key2)
+ np.testing.assert_almost_equal(val1, val2)
+
+
+@pytest.mark.parametrize(
+ "rmse, scale, expected_result",
+ [
+ (
+ {"test1": np.ndarray((3,), buffer=np.array([1.0, 2.0, 3.3]))},
+ {"test1": np.ndarray((3,), buffer=np.array([4.0, 4.0, 0.0]))},
+ {"test1": np.ndarray((3,), buffer=np.array([0.5, 1.0, 3.3]))},
+ ),
+ (
+ {"test1": np.ndarray((3,), buffer=np.array([1.0, 2.0, 3.3]))},
+ {"test1": np.ndarray((3,), buffer=np.array([0.0, 0.0, 0.0]))},
+ {"test1": np.ndarray((3,), buffer=np.array([1.0, 2.0, 3.3]))},
+ ),
+ ],
+)
+def test_calc_nrmse(rmse: dict, scale: dict, expected_result: dict) -> None:
+ """Test calc_nrmse() function."""
+ assert_two_dictionaries_with_numpy_arrays(calc_nrmse(rmse, scale), expected_result)
+
+
+def test_diff_stats(test_tfrecord_fp32: Path) -> None:
+ """Test diff_stats() function."""
+ res1, res2 = diff_stats(test_tfrecord_fp32, test_tfrecord_fp32)
+ assert res1 == 0.0
+ assert res2 == 0.0