Diffstat (limited to 'tests')
-rw-r--r--  tests/conftest.py                                   |  10
-rw-r--r--  tests/test_cli_commands.py                          |  56
-rw-r--r--  tests/test_common_optimization.py                   |  76
-rw-r--r--  tests/test_nn_rewrite_core_rewrite.py               | 323
-rw-r--r--  tests/test_nn_rewrite_core_train.py                 |  26
-rw-r--r--  tests/test_nn_rewrite_library_helper_functions.py   |  51
-rw-r--r--  tests/test_nn_select.py                             |  12
-rw-r--r--  tests/test_target_cortex_a_advisor.py               |   2
-rw-r--r--  tests/test_target_tosa_advisor.py                   |   2
9 files changed, 486 insertions, 72 deletions
diff --git a/tests/conftest.py b/tests/conftest.py
index 3d0b832..981bf3b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -257,17 +257,15 @@ def fixture_test_tfrecord_fp32(
yield from create_tfrecord(tmp_path_factory, random_data)
-@pytest.fixture(scope="session", autouse=True)
+@pytest.fixture(scope="function", autouse=True)
def set_training_steps(
request: _pytest.fixtures.SubRequest,
) -> Generator[None, None, None]:
"""Speed up tests by using MockTrainingParameters."""
- if "set_training_steps" == request.fixturename:
- yield
- else:
+ if "skip_set_training_steps" not in request.keywords:
with pytest.MonkeyPatch.context() as monkeypatch:
monkeypatch.setattr(
"mlia.nn.select._get_rewrite_params",
- MagicMock(return_value=[MockTrainingParameters(), None, None]),
+ MagicMock(return_value=MockTrainingParameters()),
)
- yield
+ yield
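The conftest.py hunk above replaces the old fixture-name check with a marker-based opt-out: any test carrying the skip_set_training_steps marker keeps the real training parameters, everything else gets the mock. A minimal sketch of that pytest pattern, assuming only what the hunk shows (the patched target and marker name come from the diff; the rest is illustrative):

    import pytest
    from unittest.mock import MagicMock

    @pytest.fixture(scope="function", autouse=True)
    def set_training_steps_sketch(request: pytest.FixtureRequest):
        """Patch the rewrite params unless the test opted out via the marker."""
        if "skip_set_training_steps" in request.keywords:
            yield  # test asked for the real training parameters
            return
        with pytest.MonkeyPatch.context() as monkeypatch:
            monkeypatch.setattr(
                "mlia.nn.select._get_rewrite_params",
                MagicMock(),  # stand-in; the real fixture returns MockTrainingParameters()
            )
            yield

The marker itself still has to be registered (for example in pytest_configure or pytest.ini), otherwise pytest emits an unknown-marker warning for @pytest.mark.skip_set_training_steps.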
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py
index 9cda27c..ce9e144 100644
--- a/tests/test_cli_commands.py
+++ b/tests/test_cli_commands.py
@@ -84,6 +84,19 @@ def test_performance_unknown_target(
],
[
"ethos-u55-256",
+ False,
+ False,
+ None,
+ None,
+ None,
+ True,
+ "fully-connected-sparsity24",
+ "sequential/flatten/Reshape",
+ "StatefulPartitionedCall:0",
+ does_not_raise(),
+ ],
+ [
+ "ethos-u55-256",
True,
False,
None,
@@ -126,7 +139,9 @@ def test_performance_unknown_target(
Exception,
match=re.escape(
"Invalid rewrite target: 'random'. "
- "Supported rewrites: ['fully-connected']"
+ "Supported rewrites: ['conv2d-clustering', 'conv2d-sparsity24', "
+ "'fully-connected', 'fully-connected-clustering', "
+ "'fully-connected-sparsity24']"
),
),
],
@@ -168,6 +183,45 @@ def test_performance_unknown_target(
),
),
],
+ [
+ "ethos-u55-256",
+ False,
+ False,
+ None,
+ None,
+ None,
+ True,
+ "fully-connected-clustering",
+ "sequential/flatten/Reshape",
+ "StatefulPartitionedCall:0",
+ does_not_raise(),
+ ],
+ [
+ "ethos-u55-256",
+ False,
+ False,
+ None,
+ None,
+ None,
+ True,
+ "conv2d-sparsity24",
+ "sequential/conv1/Relu;sequential/conv1/Conv2D",
+ "sequential/conv2/Relu;sequential/conv2/Conv2D",
+ does_not_raise(),
+ ],
+ [
+ "ethos-u55-256",
+ False,
+ False,
+ None,
+ None,
+ None,
+ True,
+ "conv2d-clustering",
+ "sequential/conv1/Relu;sequential/conv1/Conv2D",
+ "sequential/conv2/Relu;sequential/conv2/Conv2D",
+ does_not_raise(),
+ ],
],
)
def test_opt_valid_optimization_target( # pylint: disable=too-many-locals,too-many-arguments
diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py
index 05a5b55..58ea8af 100644
--- a/tests/test_common_optimization.py
+++ b/tests/test_common_optimization.py
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the common optimization module."""
+from __future__ import annotations
+
from contextlib import ExitStack as does_not_raises
from pathlib import Path
from typing import Any
@@ -15,6 +17,7 @@ from mlia.nn.tensorflow.config import TFLiteModel
from mlia.target.common.optimization import _DEFAULT_OPTIMIZATION_TARGETS
from mlia.target.common.optimization import add_common_optimization_params
from mlia.target.common.optimization import OptimizingDataCollector
+from mlia.target.common.optimization import parse_augmentations
from mlia.target.config import load_profile
from mlia.target.config import TargetProfile
@@ -57,7 +60,7 @@ def test_optimizing_data_collector(
config_parameters={
"common_optimizations": {
"optimizations": optimizations,
- "training_parameters": [training_parameters],
+ "training_parameters": training_parameters,
}
}
)
@@ -94,7 +97,7 @@ def test_optimizing_data_collector(
collector.set_context(context)
collector.collect_data()
assert optimize_model_mock.call_args.args[0] == opt_settings[0]
- assert optimize_model_mock.call_args.args[1] == [training_parameters]
+ assert optimize_model_mock.call_args.args[1] == training_parameters
assert fake_optimizer.invocation_count == 1
@@ -158,10 +161,67 @@ def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) -
]
if not extra_args.get("optimization_profile"):
- assert advisor_parameters["common_optimizations"][
- "training_parameters"
- ] == [None]
+ assert (
+ advisor_parameters["common_optimizations"]["training_parameters"]
+ is None
+ )
else:
- assert advisor_parameters["common_optimizations"][
- "training_parameters"
- ] == list(extra_args["optimization_profile"].values())
+ assert (
+ advisor_parameters["common_optimizations"]["training_parameters"]
+ == extra_args["optimization_profile"]["training"]
+ )
+
+
+@pytest.mark.parametrize(
+ "augmentations, expected_output",
+ [
+ (
+ {"gaussian_strength": 1.0, "mixup_strength": 1.0},
+ (1.0, 1.0),
+ ),
+ (
+ {"gaussian_strength": 1.0},
+ (None, 1.0),
+ ),
+ (
+ {"Wrong param": 1.0, "mixup_strength": 1.0},
+ (1.0, None),
+ ),
+ (
+ {"Wrong param1": 1.0, "Wrong param2": 1.0},
+ (None, None),
+ ),
+ (
+ "gaussian",
+ (None, 1.0),
+ ),
+ (
+ "mix_gaussian_large",
+ (2.0, 1.0),
+ ),
+ (
+ "not in presets",
+ (None, None),
+ ),
+ (
+ {"gaussian_strength": 1.0, "mixup_strength": 1.0, "mix2": 1.0},
+ (1.0, 1.0),
+ ),
+ (
+ {"gaussian_strength": "not a float", "mixup_strength": 1.0},
+ (1.0, None),
+ ),
+ (
+ None,
+ (None, None),
+ ),
+ ],
+)
+def test_parse_augmentations(
+ augmentations: dict | str | None, expected_output: tuple
+) -> None:
+ """Check that augmentation parameters in optimization_profiles are
+ correctly parsed."""
+
+ augmentation_output = parse_augmentations(augmentations)
+ assert augmentation_output == expected_output
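The parametrization above effectively documents the contract of parse_augmentations: it accepts a dict, a preset name or None and yields a (mixup_strength, gaussian_strength) pair, ignoring unknown keys and non-float values. A rough sketch of that contract, reverse-engineered from the expected outputs (the preset table is assumed from the test cases, not MLIA's actual one):

    from __future__ import annotations
    from typing import Any

    # Illustrative presets inferred from the test cases; not the real MLIA table.
    AUGMENTATION_PRESETS = {
        "gaussian": (None, 1.0),
        "mix_gaussian_large": (2.0, 1.0),
    }

    def parse_augmentations_sketch(augmentations: Any) -> tuple:
        """Return (mixup_strength, gaussian_strength); None for anything unusable."""
        if isinstance(augmentations, str):
            return AUGMENTATION_PRESETS.get(augmentations, (None, None))
        if isinstance(augmentations, dict):
            def as_float(value: Any) -> float | None:
                return value if isinstance(value, float) else None
            return (
                as_float(augmentations.get("mixup_strength")),
                as_float(augmentations.get("gaussian_strength")),
            )
        return (None, None)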
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index b32fafd..97b0b96 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -3,52 +3,120 @@
"""Tests for module mlia.nn.rewrite.core.rewrite."""
from __future__ import annotations
+import re
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
from typing import cast
from unittest.mock import MagicMock
+import numpy as np
import pytest
+import tensorflow_model_optimization as tfmot
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
+from tensorflow_model_optimization.python.core.clustering.keras.cluster_wrapper import ( # pylint: disable=no-name-in-module
+ ClusterWeights,
+)
+from tensorflow_model_optimization.python.core.sparsity.keras.pruning_wrapper import ( # pylint: disable=no-name-in-module
+ PruneLowMagnitude,
+)
-from mlia.nn.rewrite.core.rewrite import DynamicallyLoadedRewrite
+from mlia.nn.rewrite.core.rewrite import ClusteringRewrite
+from mlia.nn.rewrite.core.rewrite import GenericRewrite
from mlia.nn.rewrite.core.rewrite import Rewrite
from mlia.nn.rewrite.core.rewrite import RewriteCallable
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import RewriteRegistry
from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
+from mlia.nn.rewrite.core.rewrite import Sparsity24Rewrite
from mlia.nn.rewrite.core.rewrite import TrainingParameters
from mlia.nn.rewrite.core.train import train_in_dir
+from mlia.nn.rewrite.library.clustering import conv2d_clustering_rewrite
+from mlia.nn.rewrite.library.clustering import fc_clustering_rewrite
+from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_rewrite
+from mlia.nn.rewrite.library.sparsity import fc_sparsity_rewrite
from mlia.nn.tensorflow.config import TFLiteModel
from tests.utils.rewrite import MockTrainingParameters
+class TestRewrite(Rewrite):
+ """Test rewrite class."""
+
+ def quantize(self, model: keras.Model) -> keras.Model:
+ """Return a quantized model if required."""
+ return tfmot.quantization.keras.quantize_model(model)
+
+ def preserved_quantize(self, model: keras.Model) -> keras.Model:
+ """Not needed."""
+ return model
+
+ def training_callbacks(self) -> list:
+ """Return default rewrite callbacks."""
+ return []
+
+ def post_process(self, model: keras.Model) -> keras.Model:
+ """Return default post-processing rewrite options."""
+ return model
+
+ def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool:
+ """Not needed here."""
+ return True
+
+
def mock_rewrite_function(*_: Any) -> Any:
"""Mock function to test autoloading of rewrite functions."""
def test_rewrite() -> None:
- """Test the Rewrite class."""
+ """Test a derived Rewrite class."""
def bad_rewrite_func() -> Any:
raise NotImplementedError()
- rewrite = Rewrite("BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func))
+ rewrite = TestRewrite(
+ "BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func)
+ )
with pytest.raises(RuntimeError):
rewrite((1, 2), (1, 2))
@pytest.mark.parametrize(
+ "rewrite_name, callbacks_length, instance",
+ [
+ ("fully-connected", 0, GenericRewrite),
+ ("fully-connected-clustering", 0, ClusteringRewrite),
+ ("fully-connected-sparsity24", 1, Sparsity24Rewrite),
+ ("conv2d-clustering", 0, ClusteringRewrite),
+ ("conv2d-sparsity24", 1, Sparsity24Rewrite),
+ ],
+)
+def test_rewrite_selection(
+ rewrite_name: str, callbacks_length: int, instance: Rewrite
+) -> None:
+ """Test that the correct rewrite class is instantiated."""
+ rewrite = RewritingOptimizer.registry.items[rewrite_name]
+ assert rewrite.name == rewrite_name
+ assert isinstance(rewrite, instance) # type: ignore
+ assert len(rewrite.training_callbacks()) == callbacks_length
+
+
+@pytest.mark.parametrize(
"rewrite_name, expected_error",
[
("fully-connected", does_not_raise()),
+ ("fully-connected-sparsity24", does_not_raise()),
+ ("fully-connected-clustering", does_not_raise()),
+ ("conv2d-clustering", does_not_raise()),
+ ("conv2d-sparsity24", does_not_raise()),
("random", does_not_raise()),
],
)
def test_rewrite_configuration(
test_tflite_model_fp32: Path, rewrite_name: str, expected_error: Any
) -> None:
- """Test get_rewrite function only supports rewrite type fully-connected."""
+ """Test get_rewrite function only supports rewrite type fully-connected,
+ fully-connected-clustering, fully-connected-sparsity24, conv2d-clustering
+ and conv2d-sparsity24."""
with expected_error:
config_obj = RewriteConfiguration(
rewrite_name,
@@ -63,19 +131,209 @@ def test_rewrite_configuration(
assert isinstance(rewriter_obj, RewritingOptimizer)
-def test_rewriting_optimizer(
+def train_rewrite_model(
+ input_shape: tuple | np.ndarray,
+ output_shape: int | np.ndarray,
+ rewrite_model: keras.Model,
+) -> keras.Model:
+ """Helper function to quickly train a rewrite model."""
+ rewrite_model.compile(
+ optimizer=keras.optimizers.Nadam(learning_rate=0.01),
+ loss=keras.losses.MeanSquaredError(),
+ metrics=["mae"],
+ )
+ if isinstance(output_shape, int):
+ output_shape_list = [output_shape]
+ else:
+ output_shape_list = output_shape.tolist()
+ rewrite_model.fit(
+ x=np.random.rand(16, *input_shape),
+ y=np.random.rand(16, *output_shape_list),
+ batch_size=1,
+ epochs=1,
+ callbacks=[tfmot.sparsity.keras.UpdatePruningStep()],
+ )
+ return rewrite_model
+
+
+def test_rewrite_fully_connected_clustering(caplog: pytest.LogCaptureFixture) -> None:
+ """Check that fully connected clustering rewrite model
+ has the set number of clusters."""
+
+ rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite)
+ model = rewrite(input_shape=(28, 28), output_shape=10)
+ model = rewrite.post_process(model)
+ assert rewrite.check_optimization(model, number_of_clusters=4)
+ rewrite.check_optimization(model, number_of_clusters=2)
+ log_records = caplog.records
+ warning_messages = [x.message for x in log_records if x.levelno == 30]
+ assert (
+ re.search(
+ r"\nWARNING: Expected 2 cluster\(s\), found \d+ cluster\(s\) "
+ r"in layer dense_?\d? for weight kernel:0 \n",
+ warning_messages[0],
+ )
+ is not None
+ )
+
+
+def test_rewrite_conv2d_clustering(caplog: pytest.LogCaptureFixture) -> None:
+ """Check that conv2d clustering rewrite model has the set number of clusters."""
+
+ rewrite = ClusteringRewrite("conv2d-clustering", conv2d_clustering_rewrite)
+ model = rewrite(
+ input_shape=np.array([28, 28, 3]), output_shape=np.array([14, 14, 3])
+ )
+ model = rewrite.post_process(model)
+ assert rewrite.check_optimization(model, number_of_clusters=4)
+ rewrite.check_optimization(model, number_of_clusters=2)
+ log_records = caplog.records
+ warning_messages = [x.message for x in log_records if x.levelno == 30]
+ assert (
+ re.search(
+ r"\nWARNING: Expected 2 cluster\(s\), found \d+ cluster\(s\) "
+ r"in layer conv2d_?\d? for weight kernel:0 \n",
+ warning_messages[0],
+ )
+ is not None
+ )
+
+
+def test_rewrite_clustering_error_handling() -> None:
+ """
+ Check that the clustering rewrite check_optimization
+    function raises the correct error.
+ """
+
+ rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite)
+ model = rewrite(input_shape=(28, 28), output_shape=10)
+ with pytest.raises(
+ ValueError,
+ match=(r"Expected check_optimization to have argument number_of_clusters"),
+ ):
+ rewrite.check_optimization(model, bad_arg_name=25)
+
+
+def test_rewrite_fully_connected_sparsity(caplog: pytest.LogCaptureFixture) -> None:
+ """
+    Check that the sparse fully-connected
+    rewrite model is correctly sparse.
+ """
+
+ rewrite = Sparsity24Rewrite("fully-connected-sparsity24", fc_sparsity_rewrite)
+ input_shape = (28, 28)
+ output_shape = 10
+ model = rewrite(input_shape=tuple(input_shape), output_shape=output_shape)
+ model = rewrite.post_process(model)
+ assert not rewrite.check_optimization(model)
+ log_records = caplog.records
+ warning_messages = [x.message for x in log_records if x.levelno == 30]
+ assert (
+ re.search(
+ r"\nWARNING: Could not find \(2,4\) sparsity, in "
+ r"layer dense_?\d? for weight dense_?\d?\/kernel:0 \n",
+ warning_messages[0],
+ )
+ is not None
+ )
+ model = rewrite(input_shape=input_shape, output_shape=output_shape)
+ train_rewrite_model(
+ input_shape=input_shape, output_shape=output_shape, rewrite_model=model
+ )
+ model = rewrite.post_process(model)
+ assert rewrite.check_optimization(model)
+
+
+def test_rewrite_conv2d_sparsity(caplog: pytest.LogCaptureFixture) -> None:
+ """Check that sparse conv2d rewrite model is correctly sparse."""
+
+ rewrite = Sparsity24Rewrite("conv2d-sparsity24", conv2d_sparsity_rewrite)
+ input_shape = np.array([28, 28, 3])
+ output_shape = np.array([14, 14, 3])
+ model = rewrite(input_shape=input_shape, output_shape=output_shape)
+ model = rewrite.post_process(model)
+ assert not rewrite.check_optimization(model)
+ log_records = caplog.records
+ warning_messages = [x.message for x in log_records if x.levelno == 30]
+ assert (
+ re.search(
+ r"\nWARNING: Could not find \(2,4\) sparsity, in "
+ r"layer conv2d_?\d? for weight conv2d_?\d?\/kernel:0 \n",
+ warning_messages[0],
+ )
+ is not None
+ )
+ model = rewrite(input_shape=input_shape, output_shape=output_shape)
+ train_rewrite_model(
+ input_shape=input_shape, output_shape=output_shape, rewrite_model=model
+ )
+ model = rewrite.post_process(model)
+ assert rewrite.check_optimization(model)
+
+
+@pytest.mark.parametrize(
+ "rewrite_type, expected_layers, quant",
+ [
+ ["fully-connected", [keras.layers.Reshape, keras.layers.Dense], False],
+ ["fully-connected-clustering", [ClusterWeights, ClusterWeights], False],
+ ["fully-connected-clustering", [ClusterWeights, ClusterWeights], True],
+ ["fully-connected-sparsity24", [PruneLowMagnitude, PruneLowMagnitude], False],
+ ["fully-connected-sparsity24", [PruneLowMagnitude, PruneLowMagnitude], True],
+ ["conv2d-clustering", [ClusterWeights, ClusterWeights, ClusterWeights], False],
+ ["conv2d-clustering", [ClusterWeights, ClusterWeights, ClusterWeights], True],
+ [
+ "conv2d-sparsity24",
+ [PruneLowMagnitude, PruneLowMagnitude, PruneLowMagnitude],
+ False,
+ ],
+ [
+ "conv2d-sparsity24",
+ [PruneLowMagnitude, PruneLowMagnitude, PruneLowMagnitude],
+ True,
+ ],
+ ],
+)
+def test_rewriting_optimizer( # pylint: disable=too-many-locals
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,
+ test_tflite_model: Path,
+ test_tfrecord: Path,
+ rewrite_type: str,
+ expected_layers: list[object],
+ quant: bool,
) -> None:
- """Test fc_layer rewrite process with rewrite type fully-connected."""
+ """Test the rewrite process with all rewrite types."""
+
+ tfrecord = test_tfrecord if quant else test_tfrecord_fp32
+ tflite_model = test_tflite_model if quant else test_tflite_model_fp32
+
config_obj = RewriteConfiguration(
- "fully-connected",
- ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"],
- test_tfrecord_fp32,
+ rewrite_type,
+ ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"]
+ if "fully-connected" in rewrite_type
+ else [
+ "sequential/conv1/Relu;sequential/conv1/Conv2D",
+ "sequential/conv2/Relu;sequential/conv2/Conv2D",
+ ],
+ tfrecord,
train_params=MockTrainingParameters(),
)
- test_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj)
+ test_obj = RewritingOptimizer(tflite_model, config_obj)
+ rewrite_function = RewritingOptimizer.registry.items[
+ test_obj.optimizer_configuration.optimization_target
+ ]
+    # Input/output shapes do not matter; we just need to check that the layers are as expected
+ rewrite_model = (
+ rewrite_function(input_shape=(28, 28, 1), output_shape=12)
+ if "fully-connected" in rewrite_type
+ else rewrite_function(
+ input_shape=np.array([28, 28, 3]), output_shape=np.array([14, 14, 3])
+ )
+ )
+ for idx, layer in enumerate(rewrite_model.layers):
+ assert isinstance(layer, expected_layers[idx]) # type: ignore
+
test_obj.apply_optimization()
trained_model = test_obj.get_model()
@@ -87,11 +345,11 @@ def test_rewriting_optimizer(
def test_register_rewrite_function() -> None:
- """Test adding rewrite functions and verify the are reported via the registry."""
+ """Test adding rewrite functions and verify they are reported via the registry."""
registry = RewriteRegistry()
- rewrite1 = Rewrite("r1", cast(RewriteCallable, lambda: 1))
- rewrite2 = Rewrite("r2", cast(RewriteCallable, lambda: 2))
+ rewrite1 = TestRewrite("r1", cast(RewriteCallable, lambda: 1))
+ rewrite2 = TestRewrite("r2", cast(RewriteCallable, lambda: 2))
registry.register_rewrite(rewrite1)
registry.register_rewrite(rewrite2)
@@ -100,38 +358,13 @@ def test_register_rewrite_function() -> None:
def test_builtin_rewrite_names() -> None:
"""Test if all builtin rewrites are properly registered and returned."""
- assert RewritingOptimizer.builtin_rewrite_names() == ["fully-connected"]
-
-
-def test_rewrite_function_autoload() -> None:
- """Test rewrite function loading."""
- function_name = "tests.test_nn_rewrite_core_rewrite.mock_rewrite_function"
- rewrite = DynamicallyLoadedRewrite(name="mock_rewrite", function_name=function_name)
- assert rewrite.name == "mock_rewrite"
-
- assert rewrite.function is not mock_rewrite_function
- assert rewrite.load_function(function_name) is mock_rewrite_function
- assert rewrite.function is mock_rewrite_function
-
-
-def test_rewrite_function_autoload_fail() -> None:
- """Test rewrite function loading failure."""
- function_name = "invalid_module.invalid_function"
- rewrite = DynamicallyLoadedRewrite(
- name="mock_rewrite",
- function_name="invalid_module.invalid_function",
- )
- assert rewrite.name == "mock_rewrite"
-
- with pytest.raises(Exception) as exc_info:
- rewrite.load_function(function_name)
-
- message = exc_info.value.args[0]
-
- assert message == (
- "Unable to load rewrite function 'invalid_module.invalid_function'"
- " for 'mock_rewrite'."
- )
+ assert RewritingOptimizer.builtin_rewrite_names() == [
+ "conv2d-clustering",
+ "conv2d-sparsity24",
+ "fully-connected",
+ "fully-connected-clustering",
+ "fully-connected-sparsity24",
+ ]
def test_rewrite_configuration_train_params(
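Both sparsity tests above hinge on check_optimization spotting a 2:4 pattern, i.e. at least two zeros in every group of four consecutive weights. A hedged numpy sketch of that property (the grouping axis and the real MLIA check may differ):

    import numpy as np

    def is_2_4_sparse(kernel: np.ndarray) -> bool:
        """True if every group of four consecutive weights holds at least two zeros."""
        groups = kernel.reshape(-1, 4)             # assumes total size divisible by 4
        zeros_per_group = np.count_nonzero(groups == 0, axis=1)
        return bool((zeros_per_group >= 2).all())

    pruned = np.zeros((8, 8))
    pruned[:, ::4] = 1.0                           # one non-zero per group of four
    assert is_2_4_sparse(pruned)
    assert not is_2_4_sparse(np.ones((8, 8)))      # a dense kernel fails the check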
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index 6d24133..94c99ff 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -20,6 +20,7 @@ from mlia.nn.rewrite.core.train import LearningRateSchedule
from mlia.nn.rewrite.core.train import mixup
from mlia.nn.rewrite.core.train import train
from mlia.nn.rewrite.core.train import TrainingParameters
+from tests.test_nn_rewrite_core_rewrite import TestRewrite
from tests.utils.rewrite import MockTrainingParameters
@@ -53,18 +54,23 @@ def check_train(
"""Test the train() function."""
with TemporaryDirectory() as tmp_dir:
output_file = Path(tmp_dir, "out.tflite")
+ mock_rewrite = TestRewrite("replace", replace_fully_connected_with_conv)
result = train(
source_model=str(tflite_model),
unmodified_model=str(tflite_model) if use_unmodified_model else None,
output_model=str(output_file),
input_tfrec=str(tfrecord),
- replace_fn=replace_fully_connected_with_conv,
+ rewrite=mock_rewrite,
+ is_qat=False,
input_tensors=["sequential/flatten/Reshape"],
output_tensors=["StatefulPartitionedCall:0"],
train_params=train_params,
)
- assert len(result) == 2
- assert all(res >= 0.0 for res in result[0]), f"Results out of bound: {result}"
+
+ assert len(result[0][0]) == 2
+ assert all(
+ res >= 0.0 for res in result[0][0]
+ ), f"Results out of bound: {result}"
assert output_file.is_file()
if quantized:
@@ -229,3 +235,17 @@ def test_augment_fn_twins(augmentations: tuple, expected_error: Any) -> None:
with expected_error:
fn_twins = augment_fn_twins(dataset, augmentations) # type: ignore
assert len(fn_twins) == 2
+
+
+def test_train_checkpoint(
+ test_tflite_model: Path,
+ test_tfrecord: Path,
+) -> None:
+ """Test the train() function with valid checkpoint parameters."""
+ check_train(
+ tflite_model=test_tflite_model,
+ tfrecord=test_tfrecord,
+ train_params=MockTrainingParameters(steps=64, checkpoint_at=[24, 32]),
+ use_unmodified_model=False,
+ quantized=True,
+ )
diff --git a/tests/test_nn_rewrite_library_helper_functions.py b/tests/test_nn_rewrite_library_helper_functions.py
new file mode 100644
index 0000000..c3117f0
--- /dev/null
+++ b/tests/test_nn_rewrite_library_helper_functions.py
@@ -0,0 +1,51 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module mlia.nn.rewrite.library.helper_functions."""
+from __future__ import annotations
+
+from typing import Any
+
+import numpy as np
+import pytest
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
+
+from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters
+
+
+def compute_conv_output(
+ input_data: np.ndarray, input_shape: np.ndarray, conv_parameters: dict[str, Any]
+) -> np.ndarray:
+ """Compute the output of a conv layer for testing."""
+ test_model = keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Conv2D(**conv_parameters),
+ ]
+ )
+ output = test_model(input_data)
+ return np.array(output.shape[1:])
+
+
+@pytest.mark.parametrize(
+ "input_shape, output_shape",
+ [
+ (np.array([32, 32, 3]), np.array([16, 16, 3])),
+ (np.array([32, 32, 3]), np.array([8, 8, 3])),
+ (np.array([32, 32, 3]), np.array([8, 16, 3])),
+ (np.array([25, 10, 3]), np.array([13, 5, 3])),
+ (np.array([25, 10, 3]), np.array([7, 5, 3])),
+ (np.array([25, 10, 3]), np.array([6, 4, 3])),
+ (np.array([25, 10, 3]), np.array([5, 5, 3])),
+ ],
+)
+def test_compute_conv2d_parameters(
+ input_shape: np.ndarray, output_shape: np.ndarray
+) -> None:
+ """Test to check compute_conv2d_parameters works as expected."""
+ conv_parameters = compute_conv2d_parameters(
+ input_shape=input_shape, output_shape=output_shape
+ )
+ computed_output_shape = compute_conv_output(
+ np.random.rand(1, *input_shape), input_shape, conv_parameters
+ )
+ assert np.equal(computed_output_shape, output_shape).all()
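compute_conv2d_parameters has to pick a kernel size and stride such that the resulting Conv2D maps input_shape onto output_shape, and the test verifies this by building the layer and comparing shapes. For reference, the standard output-size relations are out = ceil(in / stride) for "same" padding and out = floor((in - kernel) / stride) + 1 for "valid" padding (which padding the helper chooses is not shown here); a quick check of the latter with arbitrary example values:

    import numpy as np
    from keras.api._v2 import keras  # same MLIA-1107 workaround as the test above

    # Example values only; not what compute_conv2d_parameters would return.
    in_size, kernel, stride = 25, 3, 2
    expected = (in_size - kernel) // stride + 1    # 12

    layer = keras.layers.Conv2D(3, kernel, strides=stride, padding="valid")
    out = layer(np.random.rand(1, in_size, in_size, 3).astype("float32"))
    assert out.shape[1] == out.shape[2] == expected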
diff --git a/tests/test_nn_select.py b/tests/test_nn_select.py
index aac07b4..4095076 100644
--- a/tests/test_nn_select.py
+++ b/tests/test_nn_select.py
@@ -183,11 +183,11 @@ def test_get_optimizer(
@pytest.mark.parametrize(
"rewrite_parameters",
- [[None], [{"batch_size": 64, "learning_rate": 0.003}]],
+ [None, {"batch_size": 64, "learning_rate": 0.003}],
)
@pytest.mark.skip_set_training_steps
def test_get_optimizer_training_parameters(
- rewrite_parameters: list[dict], test_tflite_model: Path
+ rewrite_parameters: dict | None, test_tflite_model: Path
) -> None:
"""Test function get_optimzer with various combinations of parameters."""
config = OptimizationSettings(
@@ -198,20 +198,18 @@ def test_get_optimizer_training_parameters(
)
optimizer = cast(
RewritingOptimizer,
- get_optimizer(test_tflite_model, config, list(rewrite_parameters)),
+ get_optimizer(test_tflite_model, config, rewrite_parameters),
)
- assert len(rewrite_parameters) == 1
-
assert isinstance(
optimizer.optimizer_configuration.train_params, TrainingParameters
)
- if not rewrite_parameters[0]:
+ if not rewrite_parameters:
assert asdict(TrainingParameters()) == asdict(
optimizer.optimizer_configuration.train_params
)
else:
- assert asdict(TrainingParameters()) | rewrite_parameters[0] == asdict(
+ assert asdict(TrainingParameters()) | rewrite_parameters == asdict(
optimizer.optimizer_configuration.train_params
)
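The updated assertions rely on the dict union operator to overlay user-supplied rewrite parameters onto the TrainingParameters defaults, with right-hand keys winning. In isolation (the values below are made up, not the dataclass defaults):

    # Python 3.9+ dict union: the right-hand operand overrides matching keys.
    defaults = {"batch_size": 32, "learning_rate": 1e-3, "steps": 48000}   # illustrative values
    overrides = {"batch_size": 64, "learning_rate": 0.003}
    assert defaults | overrides == {"batch_size": 64, "learning_rate": 0.003, "steps": 48000}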
diff --git a/tests/test_target_cortex_a_advisor.py b/tests/test_target_cortex_a_advisor.py
index 59d54b5..7bb57c3 100644
--- a/tests/test_target_cortex_a_advisor.py
+++ b/tests/test_target_cortex_a_advisor.py
@@ -47,7 +47,7 @@ def test_configure_and_get_cortex_a_advisor(test_tflite_model: Path) -> None:
},
]
],
- "training_parameters": [None],
+ "training_parameters": None,
},
}
diff --git a/tests/test_target_tosa_advisor.py b/tests/test_target_tosa_advisor.py
index cc47321..020acc5 100644
--- a/tests/test_target_tosa_advisor.py
+++ b/tests/test_target_tosa_advisor.py
@@ -47,7 +47,7 @@ def test_configure_and_get_tosa_advisor(
},
]
],
- "training_parameters": [None],
+ "training_parameters": None,
},
"tosa_inference_advisor": {
"model": str(test_tflite_model),