aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/test_cli_commands.py17
-rw-r--r--tests/test_nn_rewrite_core_graph_edit_diff.py45
-rw-r--r--tests/test_nn_rewrite_core_graph_edit_join.py25
-rw-r--r--tests/test_nn_rewrite_core_rewrite.py55
4 files changed, 135 insertions, 7 deletions
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py
index 6765a53..e4bbe91 100644
--- a/tests/test_cli_commands.py
+++ b/tests/test_cli_commands.py
@@ -73,8 +73,8 @@ def test_performance_unknown_target(
None,
True,
"fully_connected",
- "node_a",
- "node_b",
+ "sequential/flatten/Reshape",
+ "StatefulPartitionedCall:0",
does_not_raise(),
],
[
@@ -85,8 +85,8 @@ def test_performance_unknown_target(
None,
True,
"fully_connected",
- "node_a",
- "node_b",
+ "sequential/flatten/Reshape",
+ "StatefulPartitionedCall:0",
pytest.raises(
Exception,
match=(r"Only 'rewrite' is supported for TensorFlow Lite files."),
@@ -157,7 +157,7 @@ def test_performance_unknown_target(
],
],
)
-def test_opt_valid_optimization_target( # pylint: disable=too-many-arguments
+def test_opt_valid_optimization_target( # pylint: disable=too-many-locals,too-many-arguments
target_profile: str,
sample_context: ExecutionContext,
pruning: bool,
@@ -171,12 +171,14 @@ def test_opt_valid_optimization_target( # pylint: disable=too-many-arguments
expected_error: Any,
monkeypatch: pytest.MonkeyPatch,
test_keras_model: Path,
- test_tflite_model: Path,
+ test_tflite_model_fp32: Path,
+ test_tfrecord_fp32: Path,
) -> None:
"""Test that command should not fail with valid optimization targets."""
mock_performance_estimation(monkeypatch)
- model_type = test_tflite_model if rewrite else test_keras_model
+ model_type = test_tflite_model_fp32 if rewrite else test_keras_model
+ data = test_tfrecord_fp32 if rewrite else None
with expected_error:
optimize(
@@ -191,6 +193,7 @@ def test_opt_valid_optimization_target( # pylint: disable=too-many-arguments
rewrite_target=rewrite_target,
rewrite_start=rewrite_start,
rewrite_end=rewrite_end,
+ dataset=data,
)
diff --git a/tests/test_nn_rewrite_core_graph_edit_diff.py b/tests/test_nn_rewrite_core_graph_edit_diff.py
new file mode 100644
index 0000000..fdda04f
--- /dev/null
+++ b/tests/test_nn_rewrite_core_graph_edit_diff.py
@@ -0,0 +1,45 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module mlia.nn.rewrite.graph_edit.join."""
+from pathlib import Path
+
+import numpy as np
+import pytest
+
+from mlia.nn.rewrite.core.graph_edit.diff import calc_nrmse
+from mlia.nn.rewrite.core.graph_edit.diff import diff_stats
+
+
+def assert_two_dictionaries_with_numpy_arrays(dict1: dict, dict2: dict) -> None:
+ """Use numpy assertions to check whether two dictionaries are equal."""
+ key1, val1 = list(dict1.keys()), list(dict1.values())
+ key2, val2 = list(dict2.keys()), list(dict2.values())
+ np.testing.assert_equal(key1, key2)
+ np.testing.assert_almost_equal(val1, val2)
+
+
+@pytest.mark.parametrize(
+ "rmse, scale, expected_result",
+ [
+ (
+ {"test1": np.ndarray((3,), buffer=np.array([1.0, 2.0, 3.3]))},
+ {"test1": np.ndarray((3,), buffer=np.array([4.0, 4.0, 0.0]))},
+ {"test1": np.ndarray((3,), buffer=np.array([0.5, 1.0, 3.3]))},
+ ),
+ (
+ {"test1": np.ndarray((3,), buffer=np.array([1.0, 2.0, 3.3]))},
+ {"test1": np.ndarray((3,), buffer=np.array([0.0, 0.0, 0.0]))},
+ {"test1": np.ndarray((3,), buffer=np.array([1.0, 2.0, 3.3]))},
+ ),
+ ],
+)
+def test_calc_nrmse(rmse: dict, scale: dict, expected_result: dict) -> None:
+ """Test calc_nrmse() function."""
+ assert_two_dictionaries_with_numpy_arrays(calc_nrmse(rmse, scale), expected_result)
+
+
+def test_diff_stats(test_tfrecord_fp32: Path) -> None:
+ """Test diff_stats() function."""
+ res1, res2 = diff_stats(test_tfrecord_fp32, test_tfrecord_fp32)
+ assert res1 == 0.0
+ assert res2 == 0.0
diff --git a/tests/test_nn_rewrite_core_graph_edit_join.py b/tests/test_nn_rewrite_core_graph_edit_join.py
index cbbbeba..cb3e4e2 100644
--- a/tests/test_nn_rewrite_core_graph_edit_join.py
+++ b/tests/test_nn_rewrite_core_graph_edit_join.py
@@ -1,9 +1,14 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module mlia.nn.rewrite.graph_edit.join."""
+from contextlib import ExitStack as does_not_raise
from pathlib import Path
+from typing import Any
+
+import pytest
from mlia.nn.rewrite.core.graph_edit.cut import cut_model
+from mlia.nn.rewrite.core.graph_edit.join import append_relabel
from mlia.nn.rewrite.core.graph_edit.join import join_models
from mlia.nn.rewrite.core.utils.utils import load
from tests.utils.rewrite import models_are_equal
@@ -48,3 +53,23 @@ def test_join_model(test_tflite_model: Path, tmp_path: Path) -> None:
joined_model = load(str(joined_file))
assert models_are_equal(orig_model, joined_model)
+
+
+@pytest.mark.parametrize(
+ "src, dst, op_map, expected_error",
+ [
+ ([1, 2, 3], [4, 5, 6], {}, does_not_raise()),
+ (
+ [1, 2, 3],
+ [4, 5, 6],
+ None,
+ pytest.raises(Exception, match="The input operator map cannot be None!"),
+ ),
+ ],
+)
+def test_append_relabel(
+ src: list, dst: list, op_map: dict, expected_error: Any
+) -> None:
+ """Test passing by reference of the object in function append_relabel."""
+ with expected_error:
+ append_relabel(src, dst, op_map)
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
new file mode 100644
index 0000000..b98971e
--- /dev/null
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -0,0 +1,55 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module mlia.nn.rewrite.core.rewrite."""
+from __future__ import annotations
+
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+
+import pytest
+
+from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
+from mlia.nn.rewrite.core.rewrite import Rewriter
+from mlia.nn.tensorflow.config import TFLiteModel
+
+
+@pytest.mark.parametrize(
+ "rewrite_name, expected_error",
+ [
+ ("fully_connected", does_not_raise()),
+ ("random", does_not_raise()),
+ ],
+)
+def test_rewrite_configuration(
+ test_tflite_model_fp32: Path, rewrite_name: str, expected_error: Any
+) -> None:
+ """Test get_rewrite function only supports rewrite type fully_connected."""
+ with expected_error:
+ config_obj = RewriteConfiguration(
+ rewrite_name,
+ ["sample_node_start", "sample_node_end"],
+ None,
+ )
+
+ rewriter_obj = Rewriter(test_tflite_model_fp32, config_obj)
+ assert rewriter_obj.optimizer_configuration.optimization_target == rewrite_name
+ assert isinstance(rewriter_obj, Rewriter)
+
+
+def test_rewriter(
+ test_tflite_model_fp32: Path,
+ test_tfrecord_fp32: Path,
+) -> None:
+ """Test fc_layer rewrite process with rewrite type fully_connected."""
+ config_obj = RewriteConfiguration(
+ "fully_connected",
+ ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"],
+ test_tfrecord_fp32,
+ )
+
+ test_obj = Rewriter(test_tflite_model_fp32, config_obj)
+ test_obj.apply_optimization()
+ trained_model = test_obj.get_model()
+
+ assert isinstance(trained_model, TFLiteModel)