aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_rewrite.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_rewrite_core_rewrite.py')
-rw-r--r--tests/test_nn_rewrite_core_rewrite.py166
1 files changed, 123 insertions, 43 deletions
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index b32fafd..e502842 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -10,45 +10,102 @@ from typing import cast
from unittest.mock import MagicMock
import pytest
+import tensorflow_model_optimization as tfmot
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
+from tensorflow_model_optimization.python.core.clustering.keras.cluster_wrapper import ( # pylint: disable=no-name-in-module
+ ClusterWeights,
+)
-from mlia.nn.rewrite.core.rewrite import DynamicallyLoadedRewrite
+from mlia.nn.rewrite.core.rewrite import ClusteringRewrite
+from mlia.nn.rewrite.core.rewrite import GenericRewrite
from mlia.nn.rewrite.core.rewrite import Rewrite
from mlia.nn.rewrite.core.rewrite import RewriteCallable
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import RewriteRegistry
from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
+from mlia.nn.rewrite.core.rewrite import Sparsity24Rewrite
from mlia.nn.rewrite.core.rewrite import TrainingParameters
from mlia.nn.rewrite.core.train import train_in_dir
+from mlia.nn.rewrite.library.fc_clustering_layer import (
+ get_keras_model_clus as fc_clustering_rewrite,
+)
from mlia.nn.tensorflow.config import TFLiteModel
from tests.utils.rewrite import MockTrainingParameters
+class TestRewrite(Rewrite):
+ """Test rewrite class."""
+
+ def quantize(self, model: keras.Model) -> keras.Model:
+ """Return a quantized model if required."""
+ return tfmot.quantization.keras.quantize_model(model)
+
+ def preserved_quantize(self, model: keras.Model) -> keras.Model:
+ """Not needed."""
+ return model
+
+ def training_callbacks(self) -> list:
+ """Return default rewrite callbacks."""
+ return []
+
+ def post_process(self, model: keras.Model) -> keras.Model:
+ """Return default post-processing rewrite options."""
+ return model
+
+ def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool:
+ """Not needed here."""
+ return True
+
+
def mock_rewrite_function(*_: Any) -> Any:
"""Mock function to test autoloading of rewrite functions."""
def test_rewrite() -> None:
- """Test the Rewrite class."""
+ """Test a derived Rewrite class."""
def bad_rewrite_func() -> Any:
raise NotImplementedError()
- rewrite = Rewrite("BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func))
+ rewrite = TestRewrite(
+ "BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func)
+ )
with pytest.raises(RuntimeError):
rewrite((1, 2), (1, 2))
@pytest.mark.parametrize(
+ "rewrite_name, callbacks_length, instance",
+ [
+ ("fully-connected", 0, GenericRewrite),
+ ("fully-connected-clustering", 0, ClusteringRewrite),
+ ("fully-connected-sparsity24", 1, Sparsity24Rewrite),
+ ],
+)
+def test_rewrite_selection(
+ rewrite_name: str, callbacks_length: int, instance: Rewrite
+) -> None:
+ """Test that the correct rewrite class is instantiated."""
+ rewrite = RewritingOptimizer.registry.items[rewrite_name]
+ assert rewrite.name == rewrite_name
+ assert isinstance(rewrite, instance) # type: ignore
+ assert len(rewrite.training_callbacks()) == callbacks_length
+
+
+@pytest.mark.parametrize(
"rewrite_name, expected_error",
[
("fully-connected", does_not_raise()),
+ ("fully-connected-sparsity24", does_not_raise()),
+ ("fully-connected-clustering", does_not_raise()),
("random", does_not_raise()),
],
)
def test_rewrite_configuration(
test_tflite_model_fp32: Path, rewrite_name: str, expected_error: Any
) -> None:
- """Test get_rewrite function only supports rewrite type fully-connected."""
+ """Test get_rewrite function only supports rewrite type fully-connected,
+ fully-connected-clustering and fully-connected-sparsity24."""
with expected_error:
config_obj = RewriteConfiguration(
rewrite_name,
@@ -63,19 +120,69 @@ def test_rewrite_configuration(
assert isinstance(rewriter_obj, RewritingOptimizer)
-def test_rewriting_optimizer(
+def test_rewrite_fully_connected_clustering() -> None:
+ """Check that model has the set number of clusters"""
+
+ rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite)
+ model = rewrite(input_shape=(28, 28), output_shape=10)
+ model = rewrite.post_process(model)
+ assert rewrite.check_optimization(model, number_of_clusters=32)
+
+
+def test_rewrite_fully_connected_clustering_error_handling() -> None:
+ """Check that model has the set number of clusters
+ and that when quantized the number of clusters
+ remain."""
+
+ rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite)
+ model = rewrite(input_shape=(28, 28), output_shape=10)
+ with pytest.raises(
+ ValueError,
+ match=(
+ r"Expected check_preserved_quantize to have argument number_of_clusters"
+ ),
+ ):
+ rewrite.check_optimization(model, bad_arg_name=25)
+
+
+@pytest.mark.parametrize(
+ "rewrite_type, expected_layers, quant",
+ [
+ ["fully-connected", [keras.layers.Reshape, keras.layers.Dense], False],
+ ["fully-connected-clustering", [ClusterWeights, ClusterWeights], False],
+ ["fully-connected-clustering", [ClusterWeights, ClusterWeights], True],
+ ],
+)
+def test_rewriting_optimizer( # pylint: disable=too-many-locals
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,
+ test_tflite_model: Path,
+ test_tfrecord: Path,
+ rewrite_type: str,
+ expected_layers: list[object],
+ quant: bool,
) -> None:
"""Test fc_layer rewrite process with rewrite type fully-connected."""
+
+ tfrecord = test_tfrecord if quant else test_tfrecord_fp32
+ tflite_model = test_tflite_model if quant else test_tflite_model_fp32
+
config_obj = RewriteConfiguration(
- "fully-connected",
+ rewrite_type,
["sequential/flatten/Reshape", "StatefulPartitionedCall:0"],
- test_tfrecord_fp32,
+ tfrecord,
train_params=MockTrainingParameters(),
)
- test_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj)
+ test_obj = RewritingOptimizer(tflite_model, config_obj)
+ rewrite_function = RewritingOptimizer.registry.items[
+ test_obj.optimizer_configuration.optimization_target
+ ]
+ # Input, output shape does not matter, just need the test the layers are as expected
+ rewrite_model = rewrite_function(input_shape=(28, 28, 1), output_shape=12)
+ for idx, layer in enumerate(rewrite_model.layers):
+ assert isinstance(layer, expected_layers[idx]) # type: ignore
+
test_obj.apply_optimization()
trained_model = test_obj.get_model()
@@ -87,11 +194,11 @@ def test_rewriting_optimizer(
def test_register_rewrite_function() -> None:
- """Test adding rewrite functions and verify the are reported via the registry."""
+ """Test adding rewrite functions and verify they are reported via the registry."""
registry = RewriteRegistry()
- rewrite1 = Rewrite("r1", cast(RewriteCallable, lambda: 1))
- rewrite2 = Rewrite("r2", cast(RewriteCallable, lambda: 2))
+ rewrite1 = TestRewrite("r1", cast(RewriteCallable, lambda: 1))
+ rewrite2 = TestRewrite("r2", cast(RewriteCallable, lambda: 2))
registry.register_rewrite(rewrite1)
registry.register_rewrite(rewrite2)
@@ -100,38 +207,11 @@ def test_register_rewrite_function() -> None:
def test_builtin_rewrite_names() -> None:
"""Test if all builtin rewrites are properly registered and returned."""
- assert RewritingOptimizer.builtin_rewrite_names() == ["fully-connected"]
-
-
-def test_rewrite_function_autoload() -> None:
- """Test rewrite function loading."""
- function_name = "tests.test_nn_rewrite_core_rewrite.mock_rewrite_function"
- rewrite = DynamicallyLoadedRewrite(name="mock_rewrite", function_name=function_name)
- assert rewrite.name == "mock_rewrite"
-
- assert rewrite.function is not mock_rewrite_function
- assert rewrite.load_function(function_name) is mock_rewrite_function
- assert rewrite.function is mock_rewrite_function
-
-
-def test_rewrite_function_autoload_fail() -> None:
- """Test rewrite function loading failure."""
- function_name = "invalid_module.invalid_function"
- rewrite = DynamicallyLoadedRewrite(
- name="mock_rewrite",
- function_name="invalid_module.invalid_function",
- )
- assert rewrite.name == "mock_rewrite"
-
- with pytest.raises(Exception) as exc_info:
- rewrite.load_function(function_name)
-
- message = exc_info.value.args[0]
-
- assert message == (
- "Unable to load rewrite function 'invalid_module.invalid_function'"
- " for 'mock_rewrite'."
- )
+ assert RewritingOptimizer.builtin_rewrite_names() == [
+ "fully-connected",
+ "fully-connected-clustering",
+ "fully-connected-sparsity24",
+ ]
def test_rewrite_configuration_train_params(