aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_rewrite.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_rewrite_core_rewrite.py')
-rw-r--r--tests/test_nn_rewrite_core_rewrite.py80
1 files changed, 76 insertions, 4 deletions
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index 2542db2..d4aac56 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -6,15 +6,35 @@ from __future__ import annotations
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
+from typing import cast
import pytest
+from mlia.nn.rewrite.core.rewrite import DynamicallyLoadedRewrite
+from mlia.nn.rewrite.core.rewrite import Rewrite
+from mlia.nn.rewrite.core.rewrite import RewriteCallable
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
-from mlia.nn.rewrite.core.rewrite import Rewriter
+from mlia.nn.rewrite.core.rewrite import RewriteRegistry
+from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
from mlia.nn.tensorflow.config import TFLiteModel
from tests.utils.rewrite import TestTrainingParameters
+def mock_rewrite_function(*_: Any) -> Any:
+ """Mock function to test autoloading of rewrite functions."""
+
+
+def test_rewrite() -> None:
+ """Test the Rewrite class."""
+
+ def bad_rewrite_func() -> Any:
+ raise NotImplementedError()
+
+ rewrite = Rewrite("BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func))
+ with pytest.raises(RuntimeError):
+ rewrite((1, 2), (1, 2))
+
+
@pytest.mark.parametrize(
"rewrite_name, expected_error",
[
@@ -35,9 +55,9 @@ def test_rewrite_configuration(
assert config_obj.optimization_target in str(config_obj)
- rewriter_obj = Rewriter(test_tflite_model_fp32, config_obj)
+ rewriter_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj)
assert rewriter_obj.optimizer_configuration.optimization_target == rewrite_name
- assert isinstance(rewriter_obj, Rewriter)
+ assert isinstance(rewriter_obj, RewritingOptimizer)
def test_rewriting_optimizer(
@@ -52,8 +72,60 @@ def test_rewriting_optimizer(
train_params=TestTrainingParameters(),
)
- test_obj = Rewriter(test_tflite_model_fp32, config_obj)
+ test_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj)
test_obj.apply_optimization()
trained_model = test_obj.get_model()
assert isinstance(trained_model, TFLiteModel)
+
+ cfg = test_obj.optimization_config()
+ assert isinstance(cfg, str)
+ assert cfg
+
+
+def test_register_rewrite_function() -> None:
+ """Test adding rewrite functions and verify the are reported via the registry."""
+ registry = RewriteRegistry()
+
+ rewrite1 = Rewrite("r1", cast(RewriteCallable, lambda: 1))
+ rewrite2 = Rewrite("r2", cast(RewriteCallable, lambda: 2))
+
+ registry.register_rewrite(rewrite1)
+ registry.register_rewrite(rewrite2)
+ assert registry.names() == ["r1", "r2"]
+
+
+def test_builtin_rewrite_names() -> None:
+ """Test if all builtin rewrites are properly registered and returned."""
+ assert RewritingOptimizer.builtin_rewrite_names() == ["fully_connected"]
+
+
+def test_rewrite_function_autoload() -> None:
+ """Test rewrite function loading."""
+ function_name = "tests.test_nn_rewrite_core_rewrite.mock_rewrite_function"
+ rewrite = DynamicallyLoadedRewrite(name="mock_rewrite", function_name=function_name)
+ assert rewrite.name == "mock_rewrite"
+
+ assert rewrite.function is not mock_rewrite_function
+ assert rewrite.load_function(function_name) is mock_rewrite_function
+ assert rewrite.function is mock_rewrite_function
+
+
+def test_rewrite_function_autoload_fail() -> None:
+ """Test rewrite function loading failure."""
+ function_name = "invalid_module.invalid_function"
+ rewrite = DynamicallyLoadedRewrite(
+ name="mock_rewrite",
+ function_name="invalid_module.invalid_function",
+ )
+ assert rewrite.name == "mock_rewrite"
+
+ with pytest.raises(Exception) as exc_info:
+ rewrite.load_function(function_name)
+
+ message = exc_info.value.args[0]
+
+ assert message == (
+ "Unable to load rewrite function 'invalid_module.invalid_function'"
+ " for 'mock_rewrite'."
+ )