aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/conftest.py42
-rw-r--r--tests/test_cli_commands.py33
-rw-r--r--tests/test_cli_helpers.py2
-rw-r--r--tests/test_common_optimization.py91
-rw-r--r--tests/test_nn_rewrite_core_rewrite.py221
-rw-r--r--tests/test_nn_rewrite_core_train.py44
-rw-r--r--tests/test_nn_rewrite_library_helper_functions.py95
-rw-r--r--tests/test_nn_select.py57
-rw-r--r--tests/test_target_cortex_a_advisor.py21
-rw-r--r--tests/test_target_tosa_advisor.py21
10 files changed, 498 insertions, 129 deletions
diff --git a/tests/conftest.py b/tests/conftest.py
index 3d0b832..a64f320 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -126,9 +126,28 @@ def get_test_keras_model() -> keras.Model:
return model
+def get_test_keras_model_no_activation() -> keras.Model:
+ """Return test Keras model."""
+ model = keras.Sequential(
+ [
+ keras.Input(shape=(28, 28, 1), batch_size=1, name="input"),
+ keras.layers.Reshape((28, 28, 1)),
+ keras.layers.Conv2D(filters=12, kernel_size=(3, 3), name="conv1"),
+ keras.layers.Conv2D(filters=12, kernel_size=(3, 3), name="conv2"),
+ keras.layers.MaxPool2D(2, 2),
+ keras.layers.Flatten(),
+ keras.layers.Dense(10, name="output"),
+ ]
+ )
+
+ model.compile(optimizer="sgd", loss="mean_squared_error")
+ return model
+
+
TEST_MODEL_KERAS_FILE = "test_model.h5"
TEST_MODEL_TFLITE_FP32_FILE = "test_model_fp32.tflite"
TEST_MODEL_TFLITE_INT8_FILE = "test_model_int8.tflite"
+TEST_MODEL_TFLITE_NO_ACT_FILE = "test_model_no_act.tflite"
TEST_MODEL_TFLITE_VELA_FILE = "test_model_vela.tflite"
TEST_MODEL_TF_SAVED_MODEL_FILE = "tf_model_test_model"
TEST_MODEL_INVALID_FILE = "invalid.tflite"
@@ -153,6 +172,13 @@ def fixture_test_models_path(
keras_model, quantized=False, output_path=tmp_path / TEST_MODEL_TFLITE_FP32_FILE
)
+ # Un-quantized TensorFlow Lite model with ReLU activation (fp32)
+ convert_to_tflite(
+ get_test_keras_model_no_activation(),
+ quantized=False,
+ output_path=tmp_path / TEST_MODEL_TFLITE_NO_ACT_FILE,
+ )
+
# Quantized TensorFlow Lite model (int8)
tflite_model_path = tmp_path / TEST_MODEL_TFLITE_INT8_FILE
convert_to_tflite(keras_model, quantized=True, output_path=tflite_model_path)
@@ -195,6 +221,12 @@ def fixture_test_tflite_vela_model(test_models_path: Path) -> Path:
return test_models_path / TEST_MODEL_TFLITE_VELA_FILE
+@pytest.fixture(scope="session", name="test_tflite_no_act_model")
+def fixture_test_tflite_no_act_model(test_models_path: Path) -> Path:
+ """Return test TensorFlow Lite model with relu activation."""
+ return test_models_path / TEST_MODEL_TFLITE_NO_ACT_FILE
+
+
@pytest.fixture(scope="session", name="test_tf_model")
def fixture_test_tf_model(test_models_path: Path) -> Path:
"""Return test TensorFlow Lite model."""
@@ -257,17 +289,17 @@ def fixture_test_tfrecord_fp32(
yield from create_tfrecord(tmp_path_factory, random_data)
-@pytest.fixture(scope="session", autouse=True)
+@pytest.fixture(scope="function", autouse=True)
def set_training_steps(
request: _pytest.fixtures.SubRequest,
) -> Generator[None, None, None]:
"""Speed up tests by using MockTrainingParameters."""
- if "set_training_steps" == request.fixturename:
- yield
- else:
+ if "skip_set_training_steps" not in request.keywords:
with pytest.MonkeyPatch.context() as monkeypatch:
monkeypatch.setattr(
"mlia.nn.select._get_rewrite_params",
- MagicMock(return_value=[MockTrainingParameters(), None, None]),
+ MagicMock(return_value=MockTrainingParameters()),
)
yield
+ else:
+ yield
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py
index 93a05bd..6c54a73 100644
--- a/tests/test_cli_commands.py
+++ b/tests/test_cli_commands.py
@@ -90,7 +90,7 @@ def test_performance_unknown_target(
None,
None,
True,
- "fully-connected-sparsity24",
+ "fully-connected-sparsity",
"sequential/flatten/Reshape",
"StatefulPartitionedCall:0",
does_not_raise(),
@@ -139,8 +139,9 @@ def test_performance_unknown_target(
Exception,
match=re.escape(
"Invalid rewrite target: 'random'. "
- "Supported rewrites: ['fully-connected',"
- " 'fully-connected-clustering', 'fully-connected-sparsity24']"
+ "Supported rewrites: ['conv2d-clustering', 'conv2d-sparsity', "
+ "'fully-connected', 'fully-connected-clustering', "
+ "'fully-connected-sparsity']"
),
),
],
@@ -195,6 +196,32 @@ def test_performance_unknown_target(
"StatefulPartitionedCall:0",
does_not_raise(),
],
+ [
+ "ethos-u55-256",
+ False,
+ False,
+ None,
+ None,
+ None,
+ True,
+ "conv2d-sparsity",
+ "sequential/conv1/Relu;sequential/conv1/Conv2D",
+ "sequential/conv2/Relu;sequential/conv2/Conv2D",
+ does_not_raise(),
+ ],
+ [
+ "ethos-u55-256",
+ False,
+ False,
+ None,
+ None,
+ None,
+ True,
+ "conv2d-clustering",
+ "sequential/conv1/Relu;sequential/conv1/Conv2D",
+ "sequential/conv2/Relu;sequential/conv2/Conv2D",
+ does_not_raise(),
+ ],
],
)
def test_opt_valid_optimization_target( # pylint: disable=too-many-locals,too-many-arguments
diff --git a/tests/test_cli_helpers.py b/tests/test_cli_helpers.py
index 0e9f0d6..69e6ffe 100644
--- a/tests/test_cli_helpers.py
+++ b/tests/test_cli_helpers.py
@@ -156,7 +156,7 @@ def test_copy_profile_file_to_output_dir(tmp_path: Path) -> None:
def test_copy_optimization_file_to_output_dir(tmp_path: Path) -> None:
- """Test if the optimization profile file is copied into the output directory."""
+ """Test if the profile file is copied into the output directory."""
test_target_profile_name = "optimization"
test_file_path = Path(f"{tmp_path}/{test_target_profile_name}.toml")
diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py
index 58ea8af..bdcf034 100644
--- a/tests/test_common_optimization.py
+++ b/tests/test_common_optimization.py
@@ -60,7 +60,10 @@ def test_optimizing_data_collector(
config_parameters={
"common_optimizations": {
"optimizations": optimizations,
- "training_parameters": training_parameters,
+ "rewrite_parameters": {
+ "train_params": training_parameters,
+ "rewrite_specific_params": None,
+ },
}
}
)
@@ -97,12 +100,15 @@ def test_optimizing_data_collector(
collector.set_context(context)
collector.collect_data()
assert optimize_model_mock.call_args.args[0] == opt_settings[0]
- assert optimize_model_mock.call_args.args[1] == training_parameters
+ assert optimize_model_mock.call_args.args[1] == {
+ "train_params": training_parameters,
+ "rewrite_specific_params": None,
+ }
assert fake_optimizer.invocation_count == 1
@pytest.mark.parametrize(
- "extra_args, error_to_raise",
+ "extra_args, error_to_raise, rewrite_parameter_type",
[
(
{
@@ -115,14 +121,39 @@ def test_optimizing_data_collector(
],
},
does_not_raises(),
+ type(None),
),
(
{
+ "optimization_targets": [
+ {
+ "optimization_type": "rewrite",
+ "optimization_target": "fully-connected-clustering",
+ }
+ ],
"optimization_profile": load_profile(
- "src/mlia/resources/optimization_profiles/optimization.toml"
- )
+ "src/mlia/resources/optimization_profiles/"
+ "optimization-fully-connected-clustering.toml"
+ ),
+ },
+ does_not_raises(),
+ dict,
+ ),
+ (
+ {
+ "optimization_targets": [
+ {
+ "optimization_type": "rewrite",
+ "optimization_target": "fully-connected-sparsity",
+ }
+ ],
+ "optimization_profile": load_profile(
+ "src/mlia/resources/optimization_profiles/"
+ "optimization-fully-connected-pruning.toml"
+ ),
},
does_not_raises(),
+ dict,
),
(
{
@@ -135,16 +166,22 @@ def test_optimizing_data_collector(
pytest.raises(
TypeError, match="Optimization targets value has wrong format."
),
+ type(None),
),
(
{"optimization_profile": [32, 1e-3, True, 48000, "cosine", 1, 0]},
pytest.raises(
- TypeError, match="Training Parameter values has wrong format."
+ TypeError, match="Optimization Parameter values has wrong format."
),
+ type(None),
),
],
)
-def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) -> None:
+def test_add_common_optimization_params(
+ extra_args: dict,
+ error_to_raise: Any,
+ rewrite_parameter_type: dict | None,
+) -> None:
"""Test to check that optimization_targets and optimization_profiles are
correctly parsed."""
advisor_parameters: dict = {}
@@ -161,14 +198,40 @@ def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) -
]
if not extra_args.get("optimization_profile"):
- assert (
- advisor_parameters["common_optimizations"]["training_parameters"]
- is None
- )
+ assert advisor_parameters["common_optimizations"]["rewrite_parameters"] == {
+ "train_params": None,
+ "rewrite_specific_params": None,
+ }
else:
- assert (
- advisor_parameters["common_optimizations"]["training_parameters"]
- == extra_args["optimization_profile"]["training"]
+ if not extra_args["optimization_profile"].get("rewrite"):
+ assert isinstance(
+ advisor_parameters["common_optimizations"]["rewrite_parameters"][
+ "train_params"
+ ],
+ type(None),
+ )
+ elif not extra_args["optimization_profile"]["rewrite"].get(
+ "training_parameters"
+ ):
+ assert isinstance(
+ advisor_parameters["common_optimizations"]["rewrite_parameters"][
+ "train_params"
+ ],
+ type(None),
+ )
+ else:
+ assert isinstance(
+ advisor_parameters["common_optimizations"]["rewrite_parameters"][
+ "train_params"
+ ],
+ dict,
+ )
+
+ assert isinstance(
+ advisor_parameters["common_optimizations"]["rewrite_parameters"][
+ "rewrite_specific_params"
+ ],
+ rewrite_parameter_type, # type: ignore
)
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index e502842..c874017 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -3,18 +3,23 @@
"""Tests for module mlia.nn.rewrite.core.rewrite."""
from __future__ import annotations
+import re
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
from typing import cast
from unittest.mock import MagicMock
+import numpy as np
import pytest
import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from tensorflow_model_optimization.python.core.clustering.keras.cluster_wrapper import ( # pylint: disable=no-name-in-module
ClusterWeights,
)
+from tensorflow_model_optimization.python.core.sparsity.keras.pruning_wrapper import ( # pylint: disable=no-name-in-module
+ PruneLowMagnitude,
+)
from mlia.nn.rewrite.core.rewrite import ClusteringRewrite
from mlia.nn.rewrite.core.rewrite import GenericRewrite
@@ -23,40 +28,16 @@ from mlia.nn.rewrite.core.rewrite import RewriteCallable
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import RewriteRegistry
from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
-from mlia.nn.rewrite.core.rewrite import Sparsity24Rewrite
+from mlia.nn.rewrite.core.rewrite import SparsityRewrite
from mlia.nn.rewrite.core.rewrite import TrainingParameters
from mlia.nn.rewrite.core.train import train_in_dir
-from mlia.nn.rewrite.library.fc_clustering_layer import (
- get_keras_model_clus as fc_clustering_rewrite,
-)
+from mlia.nn.rewrite.library.clustering import fc_clustering_rewrite
+from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_rewrite
+from mlia.nn.rewrite.library.sparsity import fc_sparsity_rewrite
from mlia.nn.tensorflow.config import TFLiteModel
from tests.utils.rewrite import MockTrainingParameters
-class TestRewrite(Rewrite):
- """Test rewrite class."""
-
- def quantize(self, model: keras.Model) -> keras.Model:
- """Return a quantized model if required."""
- return tfmot.quantization.keras.quantize_model(model)
-
- def preserved_quantize(self, model: keras.Model) -> keras.Model:
- """Not needed."""
- return model
-
- def training_callbacks(self) -> list:
- """Return default rewrite callbacks."""
- return []
-
- def post_process(self, model: keras.Model) -> keras.Model:
- """Return default post-processing rewrite options."""
- return model
-
- def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool:
- """Not needed here."""
- return True
-
-
def mock_rewrite_function(*_: Any) -> Any:
"""Mock function to test autoloading of rewrite functions."""
@@ -67,10 +48,10 @@ def test_rewrite() -> None:
def bad_rewrite_func() -> Any:
raise NotImplementedError()
- rewrite = TestRewrite(
+ rewrite = GenericRewrite(
"BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func)
)
- with pytest.raises(RuntimeError):
+ with pytest.raises(KeyError):
rewrite((1, 2), (1, 2))
@@ -79,7 +60,9 @@ def test_rewrite() -> None:
[
("fully-connected", 0, GenericRewrite),
("fully-connected-clustering", 0, ClusteringRewrite),
- ("fully-connected-sparsity24", 1, Sparsity24Rewrite),
+ ("fully-connected-sparsity", 1, SparsityRewrite),
+ ("conv2d-clustering", 0, ClusteringRewrite),
+ ("conv2d-sparsity", 1, SparsityRewrite),
],
)
def test_rewrite_selection(
@@ -96,8 +79,10 @@ def test_rewrite_selection(
"rewrite_name, expected_error",
[
("fully-connected", does_not_raise()),
- ("fully-connected-sparsity24", does_not_raise()),
+ ("fully-connected-sparsity", does_not_raise()),
("fully-connected-clustering", does_not_raise()),
+ ("conv2d-clustering", does_not_raise()),
+ ("conv2d-sparsity", does_not_raise()),
("random", does_not_raise()),
],
)
@@ -105,7 +90,8 @@ def test_rewrite_configuration(
test_tflite_model_fp32: Path, rewrite_name: str, expected_error: Any
) -> None:
"""Test get_rewrite function only supports rewrite type fully-connected,
- fully-connected-clustering and fully-connected-sparsity24."""
+ fully-connected-clustering, fully-connected-sparsity, conv2d-clustering
+ and conv2d-sparsity."""
with expected_error:
config_obj = RewriteConfiguration(
rewrite_name,
@@ -120,29 +106,114 @@ def test_rewrite_configuration(
assert isinstance(rewriter_obj, RewritingOptimizer)
+def train_rewrite_model(
+ input_shape: tuple | np.ndarray,
+ output_shape: int | np.ndarray,
+ rewrite_model: keras.Model,
+) -> keras.Model:
+ """Helper function to quickly train a rewrite model."""
+ rewrite_model.compile(
+ optimizer=keras.optimizers.Nadam(learning_rate=0.01),
+ loss=keras.losses.MeanSquaredError(),
+ metrics=["mae"],
+ )
+ if isinstance(output_shape, int):
+ output_shape_list = [output_shape]
+ else:
+ output_shape_list = output_shape.tolist()
+ rewrite_model.fit(
+ x=np.random.rand(16, *input_shape),
+ y=np.random.rand(16, *output_shape_list),
+ batch_size=1,
+ epochs=1,
+ callbacks=[tfmot.sparsity.keras.UpdatePruningStep()],
+ )
+ return rewrite_model
+
+
def test_rewrite_fully_connected_clustering() -> None:
- """Check that model has the set number of clusters"""
+ """Check that fully connected clustering rewrite model
+ has the set number of clusters."""
+
+ rewrite = ClusteringRewrite(
+ "fully-connected-clustering",
+ fc_clustering_rewrite,
+ )
+
+ model = rewrite(input_shape=(28, 28), output_shape=10, num_clusters=2)
+ model = rewrite.post_process(model)
+ assert rewrite.check_optimization(
+ model,
+ num_clusters=2,
+ )
+
- rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite)
- model = rewrite(input_shape=(28, 28), output_shape=10)
+def test_rewrite_fully_connected_sparsity(caplog: pytest.LogCaptureFixture) -> None:
+ """
+ Check that sparse fully connected
+ rewrite model is correctly sparse.
+ """
+
+ rewrite = SparsityRewrite("fully-connected-sparsity", fc_sparsity_rewrite)
+ input_shape = (28, 28)
+ output_shape = 10
+ model = rewrite(
+ input_shape=tuple(input_shape),
+ output_shape=output_shape,
+ sparsity_m=2,
+ sparsity_n=4,
+ )
+ model = rewrite.post_process(model)
+ assert not rewrite.check_optimization(model)
+ log_records = caplog.records
+ warning_messages = [x.message for x in log_records if x.levelno == 30]
+ assert (
+ re.search(
+ r"\nWARNING: Could not find \(2, 4\) sparsity, in "
+ r"layer dense_?\d? for weight dense_?\d?\/kernel:0 \n",
+ warning_messages[0],
+ )
+ is not None
+ )
+ model = rewrite(
+ input_shape=input_shape, output_shape=output_shape, sparsity_m=2, sparsity_n=4
+ )
+ train_rewrite_model(
+ input_shape=input_shape, output_shape=output_shape, rewrite_model=model
+ )
model = rewrite.post_process(model)
- assert rewrite.check_optimization(model, number_of_clusters=32)
+ assert rewrite.check_optimization(model)
-def test_rewrite_fully_connected_clustering_error_handling() -> None:
- """Check that model has the set number of clusters
- and that when quantized the number of clusters
- remain."""
+def test_rewrite_conv2d_sparsity(caplog: pytest.LogCaptureFixture) -> None:
+ """Check that sparse conv2d rewrite model is correctly sparse."""
- rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite)
- model = rewrite(input_shape=(28, 28), output_shape=10)
- with pytest.raises(
- ValueError,
- match=(
- r"Expected check_preserved_quantize to have argument number_of_clusters"
- ),
- ):
- rewrite.check_optimization(model, bad_arg_name=25)
+ rewrite = SparsityRewrite("conv2d-sparsity", conv2d_sparsity_rewrite)
+ input_shape = np.array([28, 28, 3])
+ output_shape = np.array([14, 14, 3])
+ model = rewrite(
+ input_shape=input_shape, output_shape=output_shape, sparsity_m=2, sparsity_n=4
+ )
+ model = rewrite.post_process(model)
+ assert not rewrite.check_optimization(model)
+ log_records = caplog.records
+ warning_messages = [x.message for x in log_records if x.levelno == 30]
+ assert (
+ re.search(
+ r"\nWARNING: Could not find \(2, 4\) sparsity, in "
+ r"layer conv2d_?\d? for weight conv2d_?\d?\/kernel:0 \n",
+ warning_messages[0],
+ )
+ is not None
+ )
+ model = rewrite(
+ input_shape=input_shape, output_shape=output_shape, sparsity_m=2, sparsity_n=4
+ )
+ train_rewrite_model(
+ input_shape=input_shape, output_shape=output_shape, rewrite_model=model
+ )
+ model = rewrite.post_process(model)
+ assert rewrite.check_optimization(model)
@pytest.mark.parametrize(
@@ -151,6 +222,20 @@ def test_rewrite_fully_connected_clustering_error_handling() -> None:
["fully-connected", [keras.layers.Reshape, keras.layers.Dense], False],
["fully-connected-clustering", [ClusterWeights, ClusterWeights], False],
["fully-connected-clustering", [ClusterWeights, ClusterWeights], True],
+ ["fully-connected-sparsity", [PruneLowMagnitude, PruneLowMagnitude], False],
+ ["fully-connected-sparsity", [PruneLowMagnitude, PruneLowMagnitude], True],
+ ["conv2d-clustering", [ClusterWeights, ClusterWeights, ClusterWeights], False],
+ ["conv2d-clustering", [ClusterWeights, ClusterWeights, ClusterWeights], True],
+ [
+ "conv2d-sparsity",
+ [PruneLowMagnitude, PruneLowMagnitude, PruneLowMagnitude],
+ False,
+ ],
+ [
+ "conv2d-sparsity",
+ [PruneLowMagnitude, PruneLowMagnitude, PruneLowMagnitude],
+ True,
+ ],
],
)
def test_rewriting_optimizer( # pylint: disable=too-many-locals
@@ -162,24 +247,32 @@ def test_rewriting_optimizer( # pylint: disable=too-many-locals
expected_layers: list[object],
quant: bool,
) -> None:
- """Test fc_layer rewrite process with rewrite type fully-connected."""
+ """Test the rewrite process with all rewrite types."""
tfrecord = test_tfrecord if quant else test_tfrecord_fp32
tflite_model = test_tflite_model if quant else test_tflite_model_fp32
+ rewrite_function = RewritingOptimizer.registry.items[rewrite_type]
config_obj = RewriteConfiguration(
rewrite_type,
- ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"],
+ ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"]
+ if "fully-connected" in rewrite_type
+ else [
+ "sequential/conv1/Relu;sequential/conv1/Conv2D",
+ "sequential/conv2/Relu;sequential/conv2/Conv2D",
+ ],
tfrecord,
train_params=MockTrainingParameters(),
)
-
test_obj = RewritingOptimizer(tflite_model, config_obj)
- rewrite_function = RewritingOptimizer.registry.items[
- test_obj.optimizer_configuration.optimization_target
- ]
# Input, output shape does not matter, just need the test the layers are as expected
- rewrite_model = rewrite_function(input_shape=(28, 28, 1), output_shape=12)
+ rewrite_model = (
+ rewrite_function(input_shape=(28, 28, 1), output_shape=12)
+ if "fully-connected" in rewrite_type
+ else rewrite_function(
+ input_shape=np.array([28, 28, 3]), output_shape=np.array([14, 14, 3])
+ )
+ )
for idx, layer in enumerate(rewrite_model.layers):
assert isinstance(layer, expected_layers[idx]) # type: ignore
@@ -197,8 +290,14 @@ def test_register_rewrite_function() -> None:
"""Test adding rewrite functions and verify they are reported via the registry."""
registry = RewriteRegistry()
- rewrite1 = TestRewrite("r1", cast(RewriteCallable, lambda: 1))
- rewrite2 = TestRewrite("r2", cast(RewriteCallable, lambda: 2))
+ rewrite1 = GenericRewrite(
+ "r1",
+ cast(RewriteCallable, lambda: 1),
+ )
+ rewrite2 = GenericRewrite(
+ "r2",
+ cast(RewriteCallable, lambda: 2),
+ )
registry.register_rewrite(rewrite1)
registry.register_rewrite(rewrite2)
@@ -208,9 +307,11 @@ def test_register_rewrite_function() -> None:
def test_builtin_rewrite_names() -> None:
"""Test if all builtin rewrites are properly registered and returned."""
assert RewritingOptimizer.builtin_rewrite_names() == [
+ "conv2d-clustering",
+ "conv2d-sparsity",
"fully-connected",
"fully-connected-clustering",
- "fully-connected-sparsity24",
+ "fully-connected-sparsity",
]
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index 94c99ff..03b230f 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -16,11 +16,12 @@ from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from mlia.nn.rewrite.core.train import augment_fn_twins
from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS
+from mlia.nn.rewrite.core.train import detect_activation_from_rewrite_function
from mlia.nn.rewrite.core.train import LearningRateSchedule
from mlia.nn.rewrite.core.train import mixup
from mlia.nn.rewrite.core.train import train
from mlia.nn.rewrite.core.train import TrainingParameters
-from tests.test_nn_rewrite_core_rewrite import TestRewrite
+from tests.test_nn_rewrite_core_rewrite import GenericRewrite
from tests.utils.rewrite import MockTrainingParameters
@@ -54,7 +55,7 @@ def check_train(
"""Test the train() function."""
with TemporaryDirectory() as tmp_dir:
output_file = Path(tmp_dir, "out.tflite")
- mock_rewrite = TestRewrite("replace", replace_fully_connected_with_conv)
+ mock_rewrite = GenericRewrite("replace", replace_fully_connected_with_conv)
result = train(
source_model=str(tflite_model),
unmodified_model=str(tflite_model) if use_unmodified_model else None,
@@ -65,6 +66,7 @@ def check_train(
input_tensors=["sequential/flatten/Reshape"],
output_tensors=["StatefulPartitionedCall:0"],
train_params=train_params,
+ rewrite_specific_params={},
)
assert len(result[0][0]) == 2
@@ -249,3 +251,41 @@ def test_train_checkpoint(
use_unmodified_model=False,
quantized=True,
)
+
+
+def test_detect_activation_from_rewrite_function_no_activation(
+ caplog: pytest.LogCaptureFixture, test_tflite_no_act_model: Path
+) -> None:
+ """
+ Test function detect_activation_from_rewrite_function()
+ with a model with no activation functions.
+ """
+ caplog.set_level(level=20)
+ activation = detect_activation_from_rewrite_function(
+ test_tflite_no_act_model.as_posix()
+ )
+ log_records = caplog.get_records(when="call")
+ logging_messages = [x.message for x in log_records if x.levelno == 20]
+ assert activation == "relu"
+ assert (
+ "No activation function specified, setting activation function to ReLU"
+ in logging_messages
+ )
+
+
+def test_detect_activation_from_rewrite_function_relu_activation(
+ caplog: pytest.LogCaptureFixture, test_tflite_model: Path
+) -> None:
+ """
+ Test function detect_activation_from_rewrite_function()
+ with a model with ReLU activation functions.
+ """
+ caplog.set_level(level=20)
+ activation = detect_activation_from_rewrite_function(test_tflite_model.as_posix())
+ log_records = caplog.get_records(when="call")
+ logging_messages = [x.message for x in log_records if x.levelno == 20]
+ assert activation == "relu"
+ assert (
+ "No activation function specified, setting activation function "
+ "to most common activation detected in rewrite graph: relu" in logging_messages
+ )
diff --git a/tests/test_nn_rewrite_library_helper_functions.py b/tests/test_nn_rewrite_library_helper_functions.py
new file mode 100644
index 0000000..9e880aa
--- /dev/null
+++ b/tests/test_nn_rewrite_library_helper_functions.py
@@ -0,0 +1,95 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module mlia.nn.rewrite.library.helper_functions."""
+from __future__ import annotations
+
+from contextlib import ExitStack as does_not_raise
+from typing import Any
+
+import numpy as np
+import pytest
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
+
+from mlia.nn.rewrite.library.helper_functions import ACTIVATION_FUNCTION_LIST
+from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters
+from mlia.nn.rewrite.library.helper_functions import get_activation_function
+
+
+def compute_conv_output(
+ input_data: np.ndarray, input_shape: np.ndarray, conv_parameters: dict[str, Any]
+) -> np.ndarray:
+ """Compute the output of a conv layer for testing."""
+ test_model = keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Conv2D(**conv_parameters),
+ ]
+ )
+ output = test_model(input_data)
+ return np.array(output.shape[1:])
+
+
+@pytest.mark.parametrize(
+ "input_shape, output_shape",
+ [
+ (np.array([32, 32, 3]), np.array([16, 16, 3])),
+ (np.array([32, 32, 3]), np.array([8, 8, 3])),
+ (np.array([32, 32, 3]), np.array([8, 16, 3])),
+ (np.array([25, 10, 3]), np.array([13, 5, 3])),
+ (np.array([25, 10, 3]), np.array([7, 5, 3])),
+ (np.array([25, 10, 3]), np.array([6, 4, 3])),
+ (np.array([25, 10, 3]), np.array([5, 5, 3])),
+ ],
+)
+def test_compute_conv2d_parameters(
+ input_shape: np.ndarray, output_shape: np.ndarray
+) -> None:
+ """Test to check compute_conv2d_parameters works as expected."""
+ conv_parameters = compute_conv2d_parameters(
+ input_shape=input_shape, output_shape=output_shape
+ )
+ computed_output_shape = compute_conv_output(
+ np.random.rand(1, *input_shape), input_shape, conv_parameters
+ )
+ assert np.equal(computed_output_shape, output_shape).all()
+
+
+@pytest.mark.parametrize(
+ "activation, expected_function_type, expected_extra_args, expected_error",
+ [
+ ("relu", keras.layers.ReLU, {}, does_not_raise()),
+ ("relu6", keras.layers.ReLU, {"max_value": 6}, does_not_raise()),
+ ("none", keras.layers.Identity, {}, does_not_raise()),
+ (
+ "wrong_key",
+ keras.layers.Identity,
+ {},
+ pytest.raises(
+ KeyError,
+ match=(
+ "Expected activation function to be "
+ rf"in \{ACTIVATION_FUNCTION_LIST}\, found wrong_key"
+ ),
+ ),
+ ),
+ ],
+)
+def test_get_activation_functions(
+ activation: str,
+ expected_function_type: type[keras.layers.Layer],
+ expected_extra_args: dict,
+ expected_error: Any,
+) -> None:
+ """
+ Check the get_activation_function returns
+ the expected layer and extra arguments.
+ """
+ with expected_error:
+ activation_function, activation_function_extra_args = get_activation_function(
+ activation
+ )
+ assert isinstance(
+ activation_function(**activation_function_extra_args),
+ expected_function_type,
+ )
+ assert expected_extra_args == activation_function_extra_args
diff --git a/tests/test_nn_select.py b/tests/test_nn_select.py
index 4095076..08752bd 100644
--- a/tests/test_nn_select.py
+++ b/tests/test_nn_select.py
@@ -4,12 +4,12 @@
from __future__ import annotations
from contextlib import ExitStack as does_not_raise
-from dataclasses import asdict
from pathlib import Path
from typing import Any
from typing import cast
import pytest
+import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from mlia.core.errors import ConfigurationError
@@ -176,23 +176,50 @@ def test_get_optimizer(
model = test_tflite_model
else:
model = keras.models.load_model(str(test_keras_model))
- optimizer = get_optimizer(model, config)
+ optimizer = get_optimizer(
+ model, config, {"train_params": None, "rewrite_specific_params": None}
+ )
assert isinstance(optimizer, expected_type)
assert optimizer.optimization_config() == expected_config
+# pylint: disable=line-too-long
@pytest.mark.parametrize(
- "rewrite_parameters",
- [None, {"batch_size": 64, "learning_rate": 0.003}],
+ "rewrite_parameters, optimization_target",
+ [
+ [
+ {"train_params": None, "rewrite_specific_params": None},
+ "fully-connected-clustering",
+ ],
+ [
+ {
+ "train_params": None,
+ "rewrite_specific_params": {
+ "num_clusters": 5,
+ "cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization(
+ "CentroidInitialization.LINEAR"
+ ),
+ },
+ },
+ "fully-connected-clustering",
+ ],
+ [
+ {"train_params": None, "rewrite_specific_params": None},
+ "fully-connected",
+ ],
+ ],
)
+# pylint: enable=line-too-long
@pytest.mark.skip_set_training_steps
def test_get_optimizer_training_parameters(
- rewrite_parameters: dict | None, test_tflite_model: Path
+ rewrite_parameters: dict,
+ optimization_target: str,
+ test_tflite_model: Path,
) -> None:
"""Test function get_optimzer with various combinations of parameters."""
config = OptimizationSettings(
optimization_type="rewrite",
- optimization_target="fully-connected", # type: ignore
+ optimization_target=optimization_target, # type: ignore
layers_to_optimize=None,
dataset=None,
)
@@ -200,18 +227,20 @@ def test_get_optimizer_training_parameters(
RewritingOptimizer,
get_optimizer(test_tflite_model, config, rewrite_parameters),
)
+ assert len(list(rewrite_parameters.items())) == 2
+ if rewrite_parameters.get("rewrite_specific_params"):
+ assert isinstance(
+ rewrite_parameters["rewrite_specific_params"],
+ type(optimizer.optimizer_configuration.rewrite_specific_params),
+ )
+ assert (
+ optimizer.optimizer_configuration.rewrite_specific_params
+ == rewrite_parameters["rewrite_specific_params"]
+ )
assert isinstance(
optimizer.optimizer_configuration.train_params, TrainingParameters
)
- if not rewrite_parameters:
- assert asdict(TrainingParameters()) == asdict(
- optimizer.optimizer_configuration.train_params
- )
- else:
- assert asdict(TrainingParameters()) | rewrite_parameters == asdict(
- optimizer.optimizer_configuration.train_params
- )
@pytest.mark.parametrize(
diff --git a/tests/test_target_cortex_a_advisor.py b/tests/test_target_cortex_a_advisor.py
index 7bb57c3..2f06f54 100644
--- a/tests/test_target_cortex_a_advisor.py
+++ b/tests/test_target_cortex_a_advisor.py
@@ -8,6 +8,7 @@ import pytest
from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext
from mlia.core.workflow import DefaultWorkflowExecutor
+from mlia.target.common.optimization import _DEFAULT_OPTIMIZATION_TARGETS
from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor
from mlia.target.cortex_a.advisor import CortexAInferenceAdvisor
@@ -33,21 +34,11 @@ def test_configure_and_get_cortex_a_advisor(test_tflite_model: Path) -> None:
"target_profile": "cortex-a",
},
"common_optimizations": {
- "optimizations": [
- [
- {
- "layers_to_optimize": None,
- "optimization_target": 0.5,
- "optimization_type": "pruning",
- },
- {
- "layers_to_optimize": None,
- "optimization_target": 32,
- "optimization_type": "clustering",
- },
- ]
- ],
- "training_parameters": None,
+ "optimizations": [_DEFAULT_OPTIMIZATION_TARGETS],
+ "rewrite_parameters": {
+ "train_params": None,
+ "rewrite_specific_params": None,
+ },
},
}
diff --git a/tests/test_target_tosa_advisor.py b/tests/test_target_tosa_advisor.py
index 020acc5..d0b42b9 100644
--- a/tests/test_target_tosa_advisor.py
+++ b/tests/test_target_tosa_advisor.py
@@ -9,6 +9,7 @@ import pytest
from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext
from mlia.core.workflow import DefaultWorkflowExecutor
+from mlia.target.common.optimization import _DEFAULT_OPTIMIZATION_TARGETS
from mlia.target.tosa.advisor import configure_and_get_tosa_advisor
from mlia.target.tosa.advisor import TOSAInferenceAdvisor
@@ -33,21 +34,11 @@ def test_configure_and_get_tosa_advisor(
assert ctx.event_handlers is not None
assert ctx.config_parameters == {
"common_optimizations": {
- "optimizations": [
- [
- {
- "layers_to_optimize": None,
- "optimization_target": 0.5,
- "optimization_type": "pruning",
- },
- {
- "layers_to_optimize": None,
- "optimization_target": 32,
- "optimization_type": "clustering",
- },
- ]
- ],
- "training_parameters": None,
+ "optimizations": [_DEFAULT_OPTIMIZATION_TARGETS],
+ "rewrite_parameters": {
+ "train_params": None,
+ "rewrite_specific_params": None,
+ },
},
"tosa_inference_advisor": {
"model": str(test_tflite_model),