aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_rewrite.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_rewrite_core_rewrite.py')
-rw-r--r--tests/test_nn_rewrite_core_rewrite.py71
1 files changed, 23 insertions, 48 deletions
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index e614cad..8ef5bd2 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -11,45 +11,51 @@ from unittest.mock import MagicMock
import pytest
-from mlia.nn.rewrite.core.rewrite import DynamicallyLoadedRewrite
-from mlia.nn.rewrite.core.rewrite import Rewrite
+from mlia.nn.rewrite.core.rewrite import FullyConnectedRewrite
from mlia.nn.rewrite.core.rewrite import RewriteCallable
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import RewriteRegistry
from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
+from mlia.nn.rewrite.core.rewrite import Sparsity24Rewrite
from mlia.nn.rewrite.core.rewrite import TrainingParameters
from mlia.nn.rewrite.core.train import train_in_dir
from mlia.nn.tensorflow.config import TFLiteModel
from tests.utils.rewrite import MockTrainingParameters
-def mock_rewrite_function(*_: Any) -> Any:
- """Mock function to test autoloading of rewrite functions."""
-
-
def test_rewrite() -> None:
- """Test the Rewrite class."""
+ """Test a derived Rewrite class."""
def bad_rewrite_func() -> Any:
raise NotImplementedError()
- rewrite = Rewrite("BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func))
+ rewrite = Sparsity24Rewrite(
+ "BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func)
+ )
with pytest.raises(RuntimeError):
rewrite((1, 2), (1, 2))
@pytest.mark.parametrize(
- "rewrite_name, callbacks_length",
+ "rewrite_name, rewrite_class",
[
- ("fully-connected", 0),
- ("fully-connected-sparsity24", 1),
+ ("fully-connected", FullyConnectedRewrite),
+ ("fully-connected-sparsity24", Sparsity24Rewrite),
],
)
-def test_rewrite_selection(rewrite_name: str, callbacks_length: int) -> None:
- """Test that the correct rewrite class is instantiated."""
- rewrite = RewritingOptimizer.registry.items[rewrite_name]
+def test_rewrite_selection(
+ rewrite_name: str,
+ rewrite_class: Any,
+) -> None:
+ """Check that the correct rewrite class is instantiated through the registry"""
+ config_obj = RewriteConfiguration(
+ rewrite_name,
+ ["sample_node_start", "sample_node_end"],
+ )
+
+ rewrite = RewritingOptimizer.registry.items[config_obj.optimization_target]
assert rewrite.name == rewrite_name
- assert len(rewrite.training_callbacks()) == callbacks_length
+ assert isinstance(rewrite, rewrite_class)
@pytest.mark.parametrize(
@@ -106,8 +112,8 @@ def test_register_rewrite_function() -> None:
"""Test adding rewrite functions and verify they are reported via the registry."""
registry = RewriteRegistry()
- rewrite1 = Rewrite("r1", cast(RewriteCallable, lambda: 1))
- rewrite2 = Rewrite("r2", cast(RewriteCallable, lambda: 2))
+ rewrite1 = FullyConnectedRewrite("r1", cast(RewriteCallable, lambda: 1))
+ rewrite2 = Sparsity24Rewrite("r2", cast(RewriteCallable, lambda: 2))
registry.register_rewrite(rewrite1)
registry.register_rewrite(rewrite2)
@@ -122,37 +128,6 @@ def test_builtin_rewrite_names() -> None:
]
-def test_rewrite_function_autoload() -> None:
- """Test rewrite function loading."""
- function_name = "tests.test_nn_rewrite_core_rewrite.mock_rewrite_function"
- rewrite = DynamicallyLoadedRewrite(name="mock_rewrite", function_name=function_name)
- assert rewrite.name == "mock_rewrite"
-
- assert rewrite.function is not mock_rewrite_function
- assert rewrite.load_function(function_name) is mock_rewrite_function
- assert rewrite.function is mock_rewrite_function
-
-
-def test_rewrite_function_autoload_fail() -> None:
- """Test rewrite function loading failure."""
- function_name = "invalid_module.invalid_function"
- rewrite = DynamicallyLoadedRewrite(
- name="mock_rewrite",
- function_name="invalid_module.invalid_function",
- )
- assert rewrite.name == "mock_rewrite"
-
- with pytest.raises(Exception) as exc_info:
- rewrite.load_function(function_name)
-
- message = exc_info.value.args[0]
-
- assert message == (
- "Unable to load rewrite function 'invalid_module.invalid_function'"
- " for 'mock_rewrite'."
- )
-
-
def test_rewrite_configuration_train_params(
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,