aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorNathan Bailey <nathan.bailey@arm.com>2024-06-13 11:46:29 +0100
committerNathan Bailey <nathan.bailey@arm.com>2024-06-20 14:14:33 +0100
commit170376730d966a42e0622b5576a7db8fa2fa020e (patch)
tree635e00ba1f79c4e30245aad954718e59bf7e1462 /tests
parent09b5122bab771161377321e3f17e05465171ad06 (diff)
downloadmlia-170376730d966a42e0622b5576a7db8fa2fa020e.tar.gz
feat: Enable Depthwise Separable conv2d rewrites
Enables rewrites to be replaced with sparse or clustered depthwise-separable-conv2d layers. Resolves: MLIA-1169 Signed-off-by: Nathan Bailey <nathan.bailey@arm.com> Change-Id: I83b65142346d468c390c694010cc1bf2218f3be1
Diffstat (limited to 'tests')
-rw-r--r--tests/test_cli_commands.py64
-rw-r--r--tests/test_nn_rewrite_core_rewrite.py247
2 files changed, 209 insertions, 102 deletions
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py
index 5a91cd7..3dbac0a 100644
--- a/tests/test_cli_commands.py
+++ b/tests/test_cli_commands.py
@@ -139,9 +139,15 @@ def test_performance_unknown_target(
Exception,
match=re.escape(
"Invalid rewrite target: 'random'. "
- "Supported rewrites: ['conv2d-clustering', 'conv2d-sparsity', "
- "'conv2d-unstructured-sparsity', 'fully-connected', "
- "'fully-connected-clustering', 'fully-connected-sparsity', "
+ "Supported rewrites: ['conv2d', "
+ "'conv2d-clustering', 'conv2d-sparsity', "
+ "'conv2d-unstructured-sparsity', "
+ "'depthwise-separable-conv2d', "
+ "'depthwise-separable-conv2d-clustering', "
+ "'depthwise-separable-conv2d-sparsity', "
+ "'depthwise-separable-conv2d-unstructured-sparsity', "
+ "'fully-connected', 'fully-connected-clustering', "
+ "'fully-connected-sparsity', "
"'fully-connected-unstructured-sparsity']"
),
),
@@ -249,6 +255,58 @@ def test_performance_unknown_target(
"sequential/conv2/Relu;sequential/conv2/Conv2D",
does_not_raise(),
],
+ [
+ "ethos-u55-256",
+ False,
+ False,
+ None,
+ None,
+ None,
+ True,
+ "conv2d",
+ "sequential/conv1/Relu;sequential/conv1/Conv2D",
+ "sequential/conv2/Relu;sequential/conv2/Conv2D",
+ does_not_raise(),
+ ],
+ [
+ "ethos-u55-256",
+ False,
+ False,
+ None,
+ None,
+ None,
+ True,
+ "depthwise-separable-conv2d-sparsity",
+ "sequential/conv1/Relu;sequential/conv1/Conv2D",
+ "sequential/conv2/Relu;sequential/conv2/Conv2D",
+ does_not_raise(),
+ ],
+ [
+ "ethos-u55-256",
+ False,
+ False,
+ None,
+ None,
+ None,
+ True,
+ "depthwise-separable-conv2d-clustering",
+ "sequential/conv1/Relu;sequential/conv1/Conv2D",
+ "sequential/conv2/Relu;sequential/conv2/Conv2D",
+ does_not_raise(),
+ ],
+ [
+ "ethos-u55-256",
+ False,
+ False,
+ None,
+ None,
+ None,
+ True,
+ "depthwise-separable-conv2d",
+ "sequential/conv1/Relu;sequential/conv1/Conv2D",
+ "sequential/conv2/Relu;sequential/conv2/Conv2D",
+ does_not_raise(),
+ ],
],
)
def test_opt_valid_optimization_target( # pylint: disable=too-many-locals,too-many-arguments
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index 9e3287e..a608b08 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -32,6 +32,7 @@ from mlia.nn.rewrite.core.rewrite import StructuredSparsityRewrite
from mlia.nn.rewrite.core.rewrite import TrainingParameters
from mlia.nn.rewrite.core.rewrite import UnstructuredSparsityRewrite
from mlia.nn.rewrite.core.train import train_in_dir
+from mlia.nn.rewrite.library.clustering import conv2d_clustering_rewrite
from mlia.nn.rewrite.library.clustering import fc_clustering_rewrite
from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_rewrite
from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_unstructured_rewrite
@@ -61,11 +62,15 @@ def test_rewrite() -> None:
@pytest.mark.parametrize(
"rewrite_name, callbacks_length, instance",
[
+ ("conv2d", 0, GenericRewrite),
("fully-connected", 0, GenericRewrite),
+ ("depthwise-separable-conv2d", 0, GenericRewrite),
("fully-connected-clustering", 0, ClusteringRewrite),
("fully-connected-sparsity", 1, StructuredSparsityRewrite),
("conv2d-clustering", 0, ClusteringRewrite),
("conv2d-sparsity", 1, StructuredSparsityRewrite),
+ ("depthwise-separable-conv2d-clustering", 0, ClusteringRewrite),
+ ("depthwise-separable-conv2d-sparsity", 1, StructuredSparsityRewrite),
],
)
def test_rewrite_selection(
@@ -135,16 +140,39 @@ def train_rewrite_model(
return rewrite_model
-def test_rewrite_fully_connected_clustering() -> None:
+@pytest.mark.parametrize(
+ "rewrite_name, input_shape, output_shape, layer_type",
+ [
+ ("conv2d-clustering", np.array([28, 28, 3]), np.array([14, 14, 3]), None),
+ (
+ "depthwise-separable-conv2d-clustering",
+ np.array([28, 28, 3]),
+ np.array([14, 14, 3]),
+ keras.layers.SeparableConv2D,
+ ),
+ ("fully-connected-clustering", (28, 28), 10, None),
+ ],
+)
+def test_rewrite_clustering(
+ rewrite_name: str,
+ input_shape: np.ndarray | tuple,
+ output_shape: np.ndarray | int,
+ layer_type: keras.layers.Layer | None,
+) -> None:
"""Check that fully connected clustering rewrite model
has the set number of clusters."""
+ rewrite_instance = (
+ fc_clustering_rewrite
+ if "fully-connected" in rewrite_name
+ else conv2d_clustering_rewrite
+ )
+ layer_type = [{"layer_type": layer_type}] if layer_type else []
rewrite = ClusteringRewrite(
- "fully-connected-clustering",
- fc_clustering_rewrite,
+ rewrite_name, cast(RewriteCallable, rewrite_instance), *layer_type
)
- model = rewrite(input_shape=(28, 28), output_shape=10, num_clusters=2)
+ model = rewrite(input_shape=input_shape, output_shape=output_shape, num_clusters=2)
model = rewrite.post_process(model)
assert rewrite.check_optimization(
model,
@@ -152,49 +180,36 @@ def test_rewrite_fully_connected_clustering() -> None:
)
-def test_rewrite_fully_connected_sparsity(caplog: pytest.LogCaptureFixture) -> None:
- """
- Check that sparse fully connected
- rewrite model is correctly sparse.
- """
-
- rewrite = StructuredSparsityRewrite("fully-connected-sparsity", fc_sparsity_rewrite)
- input_shape = (28, 28)
- output_shape = 10
- model = rewrite(
- input_shape=tuple(input_shape),
- output_shape=output_shape,
- sparsity_m=2,
- sparsity_n=4,
- )
- model = rewrite.post_process(model)
- assert not rewrite.check_optimization(model)
- log_records = caplog.records
- warning_messages = [x.message for x in log_records if x.levelno == 30]
- assert (
- re.search(
- r"\nWARNING: Could not find \(2, 4\) sparsity, in "
- r"layer dense_?\d? for weight dense_?\d?\/kernel:0 \n",
- warning_messages[0],
- )
- is not None
- )
- model = rewrite(
- input_shape=input_shape, output_shape=output_shape, sparsity_m=2, sparsity_n=4
+@pytest.mark.parametrize(
+ "rewrite_name, input_shape, output_shape, layer_type",
+ [
+ ("conv2d-sparsity", np.array([28, 28, 3]), np.array([14, 14, 3]), None),
+ (
+ "depthwise-separable-conv2d-sparsity",
+ np.array([28, 28, 3]),
+ np.array([14, 14, 3]),
+ keras.layers.SeparableConv2D,
+ ),
+ ("fully-connected-sparsity", (28, 28), 10, None),
+ ],
+)
+def test_rewrite_sparsity(
+ rewrite_name: str,
+ input_shape: np.ndarray | tuple,
+ output_shape: np.ndarray | int,
+ layer_type: keras.layers.Layer | None,
+ caplog: pytest.LogCaptureFixture,
+) -> None:
+ """Check that sparse conv2d rewrite model is correctly sparse."""
+ rewrite_instance = (
+ fc_sparsity_rewrite
+ if "fully-connected" in rewrite_name
+ else conv2d_sparsity_rewrite
)
- train_rewrite_model(
- input_shape=input_shape, output_shape=output_shape, rewrite_model=model
+ layer_type = [{"layer_type": layer_type}] if layer_type else []
+ rewrite = StructuredSparsityRewrite(
+ rewrite_name, cast(RewriteCallable, rewrite_instance), *layer_type
)
- model = rewrite.post_process(model)
- assert rewrite.check_optimization(model)
-
-
-def test_rewrite_conv2d_sparsity(caplog: pytest.LogCaptureFixture) -> None:
- """Check that sparse conv2d rewrite model is correctly sparse."""
-
- rewrite = StructuredSparsityRewrite("conv2d-sparsity", conv2d_sparsity_rewrite)
- input_shape = np.array([28, 28, 3])
- output_shape = np.array([14, 14, 3])
model = rewrite(
input_shape=input_shape, output_shape=output_shape, sparsity_m=2, sparsity_n=4
)
@@ -202,14 +217,24 @@ def test_rewrite_conv2d_sparsity(caplog: pytest.LogCaptureFixture) -> None:
assert not rewrite.check_optimization(model)
log_records = caplog.records
warning_messages = [x.message for x in log_records if x.levelno == 30]
- assert (
- re.search(
- r"\nWARNING: Could not find \(2, 4\) sparsity, in "
- r"layer conv2d_?\d? for weight conv2d_?\d?\/kernel:0 \n",
- warning_messages[0],
+ if "fully-connected" in rewrite_name:
+ assert (
+ re.search(
+ r"\nWARNING: Could not find \(2, 4\) sparsity, in "
+ r"layer .*dense_?\d? for weight .*dense_?\d?\/.*kernel:0 \n",
+ warning_messages[0],
+ )
+ is not None
+ )
+ else:
+ assert (
+ re.search(
+ r"\nWARNING: Could not find \(2, 4\) sparsity, in "
+ r"layer .*conv2d_?\d? for weight .*conv2d_?\d?\/.*kernel:0 \n",
+ warning_messages[0],
+ )
+ is not None
)
- is not None
- )
model = rewrite(
input_shape=input_shape, output_shape=output_shape, sparsity_m=2, sparsity_n=4
)
@@ -220,57 +245,43 @@ def test_rewrite_conv2d_sparsity(caplog: pytest.LogCaptureFixture) -> None:
assert rewrite.check_optimization(model)
-def test_rewrite_conv2d_unstructured_sparsity(caplog: pytest.LogCaptureFixture) -> None:
- """Check that an unstructured sparse conv2d rewrite is correctly sparse."""
-
- rewrite = UnstructuredSparsityRewrite(
- "conv2d-unstructured-sparsity", conv2d_sparsity_unstructured_rewrite
- )
- input_shape = np.array([28, 28, 3])
- output_shape = np.array([14, 14, 3])
- model = rewrite(
- input_shape=input_shape, output_shape=output_shape, final_sparsity=0.50
- )
- model = rewrite.post_process(model)
- assert not rewrite.check_optimization(model)
- log_records = caplog.records
- warning_messages = [x.message for x in log_records if x.levelno == 30]
- assert (
- re.search(
- r"\nWARNING: Found total sparsity of rewrite model: \d.\d\d "
- r"expected total sparsity to be: 0.50\n",
- warning_messages[0],
- )
- is not None
- )
- model = rewrite(
- input_shape=input_shape,
- output_shape=output_shape,
- final_sparsity=0.5,
- end_step=120,
- )
- train_rewrite_model(
- input_shape=input_shape,
- output_shape=output_shape,
- rewrite_model=model,
- epochs=10,
- )
- model = rewrite.post_process(model)
- assert rewrite.check_optimization(model)
-
-
-def test_rewrite_fully_connected_unstructured_sparsity(
+@pytest.mark.parametrize(
+ "rewrite_name, input_shape, output_shape, layer_type",
+ [
+ (
+ "conv2d-unstructured-sparsity",
+ np.array([28, 28, 3]),
+ np.array([14, 14, 3]),
+ None,
+ ),
+ (
+ "depthwise-separable-conv2d-unstructured-sparsity",
+ np.array([28, 28, 3]),
+ np.array([14, 14, 3]),
+ keras.layers.SeparableConv2D,
+ ),
+ ("fully-connected-unstructured-sparsity", (28, 28), 10, None),
+ ],
+)
+def test_rewrite_unstructured_sparsity(
+ rewrite_name: str,
+ input_shape: np.ndarray | tuple,
+ output_shape: np.ndarray | int,
+ layer_type: keras.layers.Layer | None,
caplog: pytest.LogCaptureFixture,
) -> None:
- """Check that an unstructured sparse FC rewrite is correctly sparse."""
-
+ """Check that an unstructured sparse conv2d rewrite is correctly sparse."""
+ rewrite_instance = (
+ fc_sparsity_unstructured_rewrite
+ if "fully-connected" in rewrite_name
+ else conv2d_sparsity_unstructured_rewrite
+ )
+ layer_type = [{"layer_type": layer_type}] if layer_type else []
rewrite = UnstructuredSparsityRewrite(
- "fully-connected-unstructured-sparsity", fc_sparsity_unstructured_rewrite
+ rewrite_name, cast(RewriteCallable, rewrite_instance)
)
- input_shape = (28, 28)
- output_shape = 10
model = rewrite(
- input_shape=tuple(input_shape), output_shape=output_shape, final_sparsity=0.5
+ input_shape=input_shape, output_shape=output_shape, final_sparsity=0.50
)
model = rewrite.post_process(model)
assert not rewrite.check_optimization(model)
@@ -304,10 +315,43 @@ def test_rewrite_fully_connected_unstructured_sparsity(
"rewrite_type, expected_layers, quant",
[
["fully-connected", [keras.layers.Reshape, keras.layers.Dense], False],
+ [
+ "conv2d",
+ [keras.layers.Conv2D, keras.layers.BatchNormalization, keras.layers.ReLU],
+ False,
+ ],
["fully-connected-clustering", [ClusterWeights, ClusterWeights], False],
["fully-connected-clustering", [ClusterWeights, ClusterWeights], True],
["fully-connected-sparsity", [PruneLowMagnitude, PruneLowMagnitude], False],
- ["fully-connected-sparsity", [PruneLowMagnitude, PruneLowMagnitude], True],
+ [
+ "depthwise-separable-conv2d",
+ [
+ keras.layers.SeparableConv2D,
+ keras.layers.BatchNormalization,
+ keras.layers.ReLU,
+ ],
+ False,
+ ],
+ [
+ "depthwise-separable-conv2d-clustering",
+ [ClusterWeights, ClusterWeights, ClusterWeights],
+ False,
+ ],
+ [
+ "depthwise-separable-conv2d-clustering",
+ [ClusterWeights, ClusterWeights, ClusterWeights],
+ True,
+ ],
+ [
+ "depthwise-separable-conv2d-sparsity",
+ [PruneLowMagnitude, PruneLowMagnitude, PruneLowMagnitude],
+ False,
+ ],
+ [
+ "depthwise-separable-conv2d-sparsity",
+ [PruneLowMagnitude, PruneLowMagnitude, PruneLowMagnitude],
+ True,
+ ],
["conv2d-clustering", [ClusterWeights, ClusterWeights, ClusterWeights], False],
["conv2d-clustering", [ClusterWeights, ClusterWeights, ClusterWeights], True],
[
@@ -411,9 +455,14 @@ def test_register_rewrite_function() -> None:
def test_builtin_rewrite_names() -> None:
"""Test if all builtin rewrites are properly registered and returned."""
assert set(RewritingOptimizer.builtin_rewrite_names()) == {
+ "conv2d",
"conv2d-clustering",
"conv2d-sparsity",
"conv2d-unstructured-sparsity",
+ "depthwise-separable-conv2d",
+ "depthwise-separable-conv2d-clustering",
+ "depthwise-separable-conv2d-sparsity",
+ "depthwise-separable-conv2d-unstructured-sparsity",
"fully-connected",
"fully-connected-clustering",
"fully-connected-sparsity",