aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorMadeleine Dunn <madeleine.dunn@arm.com>2024-02-21 17:10:07 +0000
committerMadeleine Dunn <madeleine.dunn@arm.com>2024-04-04 15:26:36 +0100
commit1ebb335cba516bcf973b041efa6a9878d1022b93 (patch)
tree9038cc30c9f32403b715506abbd76f59cbf3d6a6 /tests
parent17813ba5be09f0e11fc0748afa4ccf2da02881b6 (diff)
downloadmlia-1ebb335cba516bcf973b041efa6a9878d1022b93.tar.gz
feat: Implement int8 sparsity 2:4 rewrite
- Implement pruning-preserving quantisation aware training - Rework the training logic to avoid duplication - Remove the DynamicallyLoadedRewrite class as it is now unused Resolves: MLIA-1003 Signed-off-by: Madeleine Dunn <madeleine.dunn@arm.com> Change-Id: Ia7a4acf5f477a27963cffa88180cca085b32ffe4
Diffstat (limited to 'tests')
-rw-r--r--tests/test_nn_rewrite_core_rewrite.py71
-rw-r--r--tests/test_nn_rewrite_core_train.py12
2 files changed, 30 insertions, 53 deletions
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index e614cad..8ef5bd2 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -11,45 +11,51 @@ from unittest.mock import MagicMock
import pytest
-from mlia.nn.rewrite.core.rewrite import DynamicallyLoadedRewrite
-from mlia.nn.rewrite.core.rewrite import Rewrite
+from mlia.nn.rewrite.core.rewrite import FullyConnectedRewrite
from mlia.nn.rewrite.core.rewrite import RewriteCallable
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import RewriteRegistry
from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
+from mlia.nn.rewrite.core.rewrite import Sparsity24Rewrite
from mlia.nn.rewrite.core.rewrite import TrainingParameters
from mlia.nn.rewrite.core.train import train_in_dir
from mlia.nn.tensorflow.config import TFLiteModel
from tests.utils.rewrite import MockTrainingParameters
-def mock_rewrite_function(*_: Any) -> Any:
- """Mock function to test autoloading of rewrite functions."""
-
-
def test_rewrite() -> None:
- """Test the Rewrite class."""
+ """Test a derived Rewrite class."""
def bad_rewrite_func() -> Any:
raise NotImplementedError()
- rewrite = Rewrite("BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func))
+ rewrite = Sparsity24Rewrite(
+ "BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func)
+ )
with pytest.raises(RuntimeError):
rewrite((1, 2), (1, 2))
@pytest.mark.parametrize(
- "rewrite_name, callbacks_length",
+ "rewrite_name, rewrite_class",
[
- ("fully-connected", 0),
- ("fully-connected-sparsity24", 1),
+ ("fully-connected", FullyConnectedRewrite),
+ ("fully-connected-sparsity24", Sparsity24Rewrite),
],
)
-def test_rewrite_selection(rewrite_name: str, callbacks_length: int) -> None:
- """Test that the correct rewrite class is instantiated."""
- rewrite = RewritingOptimizer.registry.items[rewrite_name]
+def test_rewrite_selection(
+ rewrite_name: str,
+ rewrite_class: Any,
+) -> None:
+ """Check that the correct rewrite class is instantiated through the registry"""
+ config_obj = RewriteConfiguration(
+ rewrite_name,
+ ["sample_node_start", "sample_node_end"],
+ )
+
+ rewrite = RewritingOptimizer.registry.items[config_obj.optimization_target]
assert rewrite.name == rewrite_name
- assert len(rewrite.training_callbacks()) == callbacks_length
+ assert isinstance(rewrite, rewrite_class)
@pytest.mark.parametrize(
@@ -106,8 +112,8 @@ def test_register_rewrite_function() -> None:
"""Test adding rewrite functions and verify they are reported via the registry."""
registry = RewriteRegistry()
- rewrite1 = Rewrite("r1", cast(RewriteCallable, lambda: 1))
- rewrite2 = Rewrite("r2", cast(RewriteCallable, lambda: 2))
+ rewrite1 = FullyConnectedRewrite("r1", cast(RewriteCallable, lambda: 1))
+ rewrite2 = Sparsity24Rewrite("r2", cast(RewriteCallable, lambda: 2))
registry.register_rewrite(rewrite1)
registry.register_rewrite(rewrite2)
@@ -122,37 +128,6 @@ def test_builtin_rewrite_names() -> None:
]
-def test_rewrite_function_autoload() -> None:
- """Test rewrite function loading."""
- function_name = "tests.test_nn_rewrite_core_rewrite.mock_rewrite_function"
- rewrite = DynamicallyLoadedRewrite(name="mock_rewrite", function_name=function_name)
- assert rewrite.name == "mock_rewrite"
-
- assert rewrite.function is not mock_rewrite_function
- assert rewrite.load_function(function_name) is mock_rewrite_function
- assert rewrite.function is mock_rewrite_function
-
-
-def test_rewrite_function_autoload_fail() -> None:
- """Test rewrite function loading failure."""
- function_name = "invalid_module.invalid_function"
- rewrite = DynamicallyLoadedRewrite(
- name="mock_rewrite",
- function_name="invalid_module.invalid_function",
- )
- assert rewrite.name == "mock_rewrite"
-
- with pytest.raises(Exception) as exc_info:
- rewrite.load_function(function_name)
-
- message = exc_info.value.args[0]
-
- assert message == (
- "Unable to load rewrite function 'invalid_module.invalid_function'"
- " for 'mock_rewrite'."
- )
-
-
def test_rewrite_configuration_train_params(
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index 34b9543..371c79f 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -14,13 +14,15 @@ import pytest
import tensorflow as tf
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
-from mlia.nn.rewrite.core.rewrite import DynamicallyLoadedRewrite
+from mlia.nn.rewrite.core.rewrite import FullyConnectedRewrite
+from mlia.nn.rewrite.core.rewrite import QATRewrite
from mlia.nn.rewrite.core.train import augment_fn_twins
from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS
from mlia.nn.rewrite.core.train import LearningRateSchedule
from mlia.nn.rewrite.core.train import mixup
from mlia.nn.rewrite.core.train import train
from mlia.nn.rewrite.core.train import TrainingParameters
+from mlia.nn.rewrite.library.fc_layer import get_keras_model as fc_rewrite
from tests.utils.rewrite import MockTrainingParameters
@@ -54,12 +56,11 @@ def check_train(
"""Test the train() function."""
with TemporaryDirectory() as tmp_dir:
output_file = Path(tmp_dir, "out.tflite")
- mock_rewrite = DynamicallyLoadedRewrite(
+ mock_rewrite = FullyConnectedRewrite(
name="replace",
- function_name=(
- "tests.test_nn_rewrite_core_train.replace_fully_connected_with_conv"
- ),
+ rewrite_fn=fc_rewrite,
)
+ is_qat = isinstance(mock_rewrite, QATRewrite)
result = train(
source_model=str(tflite_model),
unmodified_model=str(tflite_model) if use_unmodified_model else None,
@@ -68,6 +69,7 @@ def check_train(
rewrite=mock_rewrite,
input_tensors=["sequential/flatten/Reshape"],
output_tensors=["StatefulPartitionedCall:0"],
+ is_qat=is_qat,
train_params=train_params,
)