diff options
Diffstat (limited to 'tests/test_nn_rewrite_core_rewrite.py')
-rw-r--r-- | tests/test_nn_rewrite_core_rewrite.py | 80 |
1 files changed, 76 insertions, 4 deletions
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py index 2542db2..d4aac56 100644 --- a/tests/test_nn_rewrite_core_rewrite.py +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -6,15 +6,35 @@ from __future__ import annotations from contextlib import ExitStack as does_not_raise from pathlib import Path from typing import Any +from typing import cast import pytest +from mlia.nn.rewrite.core.rewrite import DynamicallyLoadedRewrite +from mlia.nn.rewrite.core.rewrite import Rewrite +from mlia.nn.rewrite.core.rewrite import RewriteCallable from mlia.nn.rewrite.core.rewrite import RewriteConfiguration -from mlia.nn.rewrite.core.rewrite import Rewriter +from mlia.nn.rewrite.core.rewrite import RewriteRegistry +from mlia.nn.rewrite.core.rewrite import RewritingOptimizer from mlia.nn.tensorflow.config import TFLiteModel from tests.utils.rewrite import TestTrainingParameters +def mock_rewrite_function(*_: Any) -> Any: + """Mock function to test autoloading of rewrite functions.""" + + +def test_rewrite() -> None: + """Test the Rewrite class.""" + + def bad_rewrite_func() -> Any: + raise NotImplementedError() + + rewrite = Rewrite("BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func)) + with pytest.raises(RuntimeError): + rewrite((1, 2), (1, 2)) + + @pytest.mark.parametrize( "rewrite_name, expected_error", [ @@ -35,9 +55,9 @@ def test_rewrite_configuration( assert config_obj.optimization_target in str(config_obj) - rewriter_obj = Rewriter(test_tflite_model_fp32, config_obj) + rewriter_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj) assert rewriter_obj.optimizer_configuration.optimization_target == rewrite_name - assert isinstance(rewriter_obj, Rewriter) + assert isinstance(rewriter_obj, RewritingOptimizer) def test_rewriting_optimizer( @@ -52,8 +72,60 @@ def test_rewriting_optimizer( train_params=TestTrainingParameters(), ) - test_obj = Rewriter(test_tflite_model_fp32, config_obj) + test_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj) test_obj.apply_optimization() trained_model = test_obj.get_model() assert isinstance(trained_model, TFLiteModel) + + cfg = test_obj.optimization_config() + assert isinstance(cfg, str) + assert cfg + + +def test_register_rewrite_function() -> None: + """Test adding rewrite functions and verify the are reported via the registry.""" + registry = RewriteRegistry() + + rewrite1 = Rewrite("r1", cast(RewriteCallable, lambda: 1)) + rewrite2 = Rewrite("r2", cast(RewriteCallable, lambda: 2)) + + registry.register_rewrite(rewrite1) + registry.register_rewrite(rewrite2) + assert registry.names() == ["r1", "r2"] + + +def test_builtin_rewrite_names() -> None: + """Test if all builtin rewrites are properly registered and returned.""" + assert RewritingOptimizer.builtin_rewrite_names() == ["fully_connected"] + + +def test_rewrite_function_autoload() -> None: + """Test rewrite function loading.""" + function_name = "tests.test_nn_rewrite_core_rewrite.mock_rewrite_function" + rewrite = DynamicallyLoadedRewrite(name="mock_rewrite", function_name=function_name) + assert rewrite.name == "mock_rewrite" + + assert rewrite.function is not mock_rewrite_function + assert rewrite.load_function(function_name) is mock_rewrite_function + assert rewrite.function is mock_rewrite_function + + +def test_rewrite_function_autoload_fail() -> None: + """Test rewrite function loading failure.""" + function_name = "invalid_module.invalid_function" + rewrite = DynamicallyLoadedRewrite( + name="mock_rewrite", + function_name="invalid_module.invalid_function", + ) + assert rewrite.name == "mock_rewrite" + + with pytest.raises(Exception) as exc_info: + rewrite.load_function(function_name) + + message = exc_info.value.args[0] + + assert message == ( + "Unable to load rewrite function 'invalid_module.invalid_function'" + " for 'mock_rewrite'." + ) |