diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-06-13 11:46:29 +0100 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-06-20 14:14:33 +0100 |
commit | 170376730d966a42e0622b5576a7db8fa2fa020e (patch) | |
tree | 635e00ba1f79c4e30245aad954718e59bf7e1462 /tests | |
parent | 09b5122bab771161377321e3f17e05465171ad06 (diff) | |
download | mlia-170376730d966a42e0622b5576a7db8fa2fa020e.tar.gz |
feat: Enable Depthwise Separable conv2d rewrites
Enables rewrites to be replaced with sparse or clustered depthwise-separable-conv2d layers.
Resolves: MLIA-1169
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: I83b65142346d468c390c694010cc1bf2218f3be1
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_cli_commands.py | 64 | ||||
-rw-r--r-- | tests/test_nn_rewrite_core_rewrite.py | 247 |
2 files changed, 209 insertions, 102 deletions
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py index 5a91cd7..3dbac0a 100644 --- a/tests/test_cli_commands.py +++ b/tests/test_cli_commands.py @@ -139,9 +139,15 @@ def test_performance_unknown_target( Exception, match=re.escape( "Invalid rewrite target: 'random'. " - "Supported rewrites: ['conv2d-clustering', 'conv2d-sparsity', " - "'conv2d-unstructured-sparsity', 'fully-connected', " - "'fully-connected-clustering', 'fully-connected-sparsity', " + "Supported rewrites: ['conv2d', " + "'conv2d-clustering', 'conv2d-sparsity', " + "'conv2d-unstructured-sparsity', " + "'depthwise-separable-conv2d', " + "'depthwise-separable-conv2d-clustering', " + "'depthwise-separable-conv2d-sparsity', " + "'depthwise-separable-conv2d-unstructured-sparsity', " + "'fully-connected', 'fully-connected-clustering', " + "'fully-connected-sparsity', " "'fully-connected-unstructured-sparsity']" ), ), @@ -249,6 +255,58 @@ def test_performance_unknown_target( "sequential/conv2/Relu;sequential/conv2/Conv2D", does_not_raise(), ], + [ + "ethos-u55-256", + False, + False, + None, + None, + None, + True, + "conv2d", + "sequential/conv1/Relu;sequential/conv1/Conv2D", + "sequential/conv2/Relu;sequential/conv2/Conv2D", + does_not_raise(), + ], + [ + "ethos-u55-256", + False, + False, + None, + None, + None, + True, + "depthwise-separable-conv2d-sparsity", + "sequential/conv1/Relu;sequential/conv1/Conv2D", + "sequential/conv2/Relu;sequential/conv2/Conv2D", + does_not_raise(), + ], + [ + "ethos-u55-256", + False, + False, + None, + None, + None, + True, + "depthwise-separable-conv2d-clustering", + "sequential/conv1/Relu;sequential/conv1/Conv2D", + "sequential/conv2/Relu;sequential/conv2/Conv2D", + does_not_raise(), + ], + [ + "ethos-u55-256", + False, + False, + None, + None, + None, + True, + "depthwise-separable-conv2d", + "sequential/conv1/Relu;sequential/conv1/Conv2D", + "sequential/conv2/Relu;sequential/conv2/Conv2D", + does_not_raise(), + ], ], ) def test_opt_valid_optimization_target( # pylint: disable=too-many-locals,too-many-arguments diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py index 9e3287e..a608b08 100644 --- a/tests/test_nn_rewrite_core_rewrite.py +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -32,6 +32,7 @@ from mlia.nn.rewrite.core.rewrite import StructuredSparsityRewrite from mlia.nn.rewrite.core.rewrite import TrainingParameters from mlia.nn.rewrite.core.rewrite import UnstructuredSparsityRewrite from mlia.nn.rewrite.core.train import train_in_dir +from mlia.nn.rewrite.library.clustering import conv2d_clustering_rewrite from mlia.nn.rewrite.library.clustering import fc_clustering_rewrite from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_rewrite from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_unstructured_rewrite @@ -61,11 +62,15 @@ def test_rewrite() -> None: @pytest.mark.parametrize( "rewrite_name, callbacks_length, instance", [ + ("conv2d", 0, GenericRewrite), ("fully-connected", 0, GenericRewrite), + ("depthwise-separable-conv2d", 0, GenericRewrite), ("fully-connected-clustering", 0, ClusteringRewrite), ("fully-connected-sparsity", 1, StructuredSparsityRewrite), ("conv2d-clustering", 0, ClusteringRewrite), ("conv2d-sparsity", 1, StructuredSparsityRewrite), + ("depthwise-separable-conv2d-clustering", 0, ClusteringRewrite), + ("depthwise-separable-conv2d-sparsity", 1, StructuredSparsityRewrite), ], ) def test_rewrite_selection( @@ -135,16 +140,39 @@ def train_rewrite_model( return rewrite_model -def test_rewrite_fully_connected_clustering() -> None: +@pytest.mark.parametrize( + "rewrite_name, input_shape, output_shape, layer_type", + [ + ("conv2d-clustering", np.array([28, 28, 3]), np.array([14, 14, 3]), None), + ( + "depthwise-separable-conv2d-clustering", + np.array([28, 28, 3]), + np.array([14, 14, 3]), + keras.layers.SeparableConv2D, + ), + ("fully-connected-clustering", (28, 28), 10, None), + ], +) +def test_rewrite_clustering( + rewrite_name: str, + input_shape: np.ndarray | tuple, + output_shape: np.ndarray | int, + layer_type: keras.layers.Layer | None, +) -> None: """Check that fully connected clustering rewrite model has the set number of clusters.""" + rewrite_instance = ( + fc_clustering_rewrite + if "fully-connected" in rewrite_name + else conv2d_clustering_rewrite + ) + layer_type = [{"layer_type": layer_type}] if layer_type else [] rewrite = ClusteringRewrite( - "fully-connected-clustering", - fc_clustering_rewrite, + rewrite_name, cast(RewriteCallable, rewrite_instance), *layer_type ) - model = rewrite(input_shape=(28, 28), output_shape=10, num_clusters=2) + model = rewrite(input_shape=input_shape, output_shape=output_shape, num_clusters=2) model = rewrite.post_process(model) assert rewrite.check_optimization( model, @@ -152,49 +180,36 @@ def test_rewrite_fully_connected_clustering() -> None: ) -def test_rewrite_fully_connected_sparsity(caplog: pytest.LogCaptureFixture) -> None: - """ - Check that sparse fully connected - rewrite model is correctly sparse. - """ - - rewrite = StructuredSparsityRewrite("fully-connected-sparsity", fc_sparsity_rewrite) - input_shape = (28, 28) - output_shape = 10 - model = rewrite( - input_shape=tuple(input_shape), - output_shape=output_shape, - sparsity_m=2, - sparsity_n=4, - ) - model = rewrite.post_process(model) - assert not rewrite.check_optimization(model) - log_records = caplog.records - warning_messages = [x.message for x in log_records if x.levelno == 30] - assert ( - re.search( - r"\nWARNING: Could not find \(2, 4\) sparsity, in " - r"layer dense_?\d? for weight dense_?\d?\/kernel:0 \n", - warning_messages[0], - ) - is not None - ) - model = rewrite( - input_shape=input_shape, output_shape=output_shape, sparsity_m=2, sparsity_n=4 +@pytest.mark.parametrize( + "rewrite_name, input_shape, output_shape, layer_type", + [ + ("conv2d-sparsity", np.array([28, 28, 3]), np.array([14, 14, 3]), None), + ( + "depthwise-separable-conv2d-sparsity", + np.array([28, 28, 3]), + np.array([14, 14, 3]), + keras.layers.SeparableConv2D, + ), + ("fully-connected-sparsity", (28, 28), 10, None), + ], +) +def test_rewrite_sparsity( + rewrite_name: str, + input_shape: np.ndarray | tuple, + output_shape: np.ndarray | int, + layer_type: keras.layers.Layer | None, + caplog: pytest.LogCaptureFixture, +) -> None: + """Check that sparse conv2d rewrite model is correctly sparse.""" + rewrite_instance = ( + fc_sparsity_rewrite + if "fully-connected" in rewrite_name + else conv2d_sparsity_rewrite ) - train_rewrite_model( - input_shape=input_shape, output_shape=output_shape, rewrite_model=model + layer_type = [{"layer_type": layer_type}] if layer_type else [] + rewrite = StructuredSparsityRewrite( + rewrite_name, cast(RewriteCallable, rewrite_instance), *layer_type ) - model = rewrite.post_process(model) - assert rewrite.check_optimization(model) - - -def test_rewrite_conv2d_sparsity(caplog: pytest.LogCaptureFixture) -> None: - """Check that sparse conv2d rewrite model is correctly sparse.""" - - rewrite = StructuredSparsityRewrite("conv2d-sparsity", conv2d_sparsity_rewrite) - input_shape = np.array([28, 28, 3]) - output_shape = np.array([14, 14, 3]) model = rewrite( input_shape=input_shape, output_shape=output_shape, sparsity_m=2, sparsity_n=4 ) @@ -202,14 +217,24 @@ def test_rewrite_conv2d_sparsity(caplog: pytest.LogCaptureFixture) -> None: assert not rewrite.check_optimization(model) log_records = caplog.records warning_messages = [x.message for x in log_records if x.levelno == 30] - assert ( - re.search( - r"\nWARNING: Could not find \(2, 4\) sparsity, in " - r"layer conv2d_?\d? for weight conv2d_?\d?\/kernel:0 \n", - warning_messages[0], + if "fully-connected" in rewrite_name: + assert ( + re.search( + r"\nWARNING: Could not find \(2, 4\) sparsity, in " + r"layer .*dense_?\d? for weight .*dense_?\d?\/.*kernel:0 \n", + warning_messages[0], + ) + is not None + ) + else: + assert ( + re.search( + r"\nWARNING: Could not find \(2, 4\) sparsity, in " + r"layer .*conv2d_?\d? for weight .*conv2d_?\d?\/.*kernel:0 \n", + warning_messages[0], + ) + is not None ) - is not None - ) model = rewrite( input_shape=input_shape, output_shape=output_shape, sparsity_m=2, sparsity_n=4 ) @@ -220,57 +245,43 @@ def test_rewrite_conv2d_sparsity(caplog: pytest.LogCaptureFixture) -> None: assert rewrite.check_optimization(model) -def test_rewrite_conv2d_unstructured_sparsity(caplog: pytest.LogCaptureFixture) -> None: - """Check that an unstructured sparse conv2d rewrite is correctly sparse.""" - - rewrite = UnstructuredSparsityRewrite( - "conv2d-unstructured-sparsity", conv2d_sparsity_unstructured_rewrite - ) - input_shape = np.array([28, 28, 3]) - output_shape = np.array([14, 14, 3]) - model = rewrite( - input_shape=input_shape, output_shape=output_shape, final_sparsity=0.50 - ) - model = rewrite.post_process(model) - assert not rewrite.check_optimization(model) - log_records = caplog.records - warning_messages = [x.message for x in log_records if x.levelno == 30] - assert ( - re.search( - r"\nWARNING: Found total sparsity of rewrite model: \d.\d\d " - r"expected total sparsity to be: 0.50\n", - warning_messages[0], - ) - is not None - ) - model = rewrite( - input_shape=input_shape, - output_shape=output_shape, - final_sparsity=0.5, - end_step=120, - ) - train_rewrite_model( - input_shape=input_shape, - output_shape=output_shape, - rewrite_model=model, - epochs=10, - ) - model = rewrite.post_process(model) - assert rewrite.check_optimization(model) - - -def test_rewrite_fully_connected_unstructured_sparsity( +@pytest.mark.parametrize( + "rewrite_name, input_shape, output_shape, layer_type", + [ + ( + "conv2d-unstructured-sparsity", + np.array([28, 28, 3]), + np.array([14, 14, 3]), + None, + ), + ( + "depthwise-separable-conv2d-unstructured-sparsity", + np.array([28, 28, 3]), + np.array([14, 14, 3]), + keras.layers.SeparableConv2D, + ), + ("fully-connected-unstructured-sparsity", (28, 28), 10, None), + ], +) +def test_rewrite_unstructured_sparsity( + rewrite_name: str, + input_shape: np.ndarray | tuple, + output_shape: np.ndarray | int, + layer_type: keras.layers.Layer | None, caplog: pytest.LogCaptureFixture, ) -> None: - """Check that an unstructured sparse FC rewrite is correctly sparse.""" - + """Check that an unstructured sparse conv2d rewrite is correctly sparse.""" + rewrite_instance = ( + fc_sparsity_unstructured_rewrite + if "fully-connected" in rewrite_name + else conv2d_sparsity_unstructured_rewrite + ) + layer_type = [{"layer_type": layer_type}] if layer_type else [] rewrite = UnstructuredSparsityRewrite( - "fully-connected-unstructured-sparsity", fc_sparsity_unstructured_rewrite + rewrite_name, cast(RewriteCallable, rewrite_instance) ) - input_shape = (28, 28) - output_shape = 10 model = rewrite( - input_shape=tuple(input_shape), output_shape=output_shape, final_sparsity=0.5 + input_shape=input_shape, output_shape=output_shape, final_sparsity=0.50 ) model = rewrite.post_process(model) assert not rewrite.check_optimization(model) @@ -304,10 +315,43 @@ def test_rewrite_fully_connected_unstructured_sparsity( "rewrite_type, expected_layers, quant", [ ["fully-connected", [keras.layers.Reshape, keras.layers.Dense], False], + [ + "conv2d", + [keras.layers.Conv2D, keras.layers.BatchNormalization, keras.layers.ReLU], + False, + ], ["fully-connected-clustering", [ClusterWeights, ClusterWeights], False], ["fully-connected-clustering", [ClusterWeights, ClusterWeights], True], ["fully-connected-sparsity", [PruneLowMagnitude, PruneLowMagnitude], False], - ["fully-connected-sparsity", [PruneLowMagnitude, PruneLowMagnitude], True], + [ + "depthwise-separable-conv2d", + [ + keras.layers.SeparableConv2D, + keras.layers.BatchNormalization, + keras.layers.ReLU, + ], + False, + ], + [ + "depthwise-separable-conv2d-clustering", + [ClusterWeights, ClusterWeights, ClusterWeights], + False, + ], + [ + "depthwise-separable-conv2d-clustering", + [ClusterWeights, ClusterWeights, ClusterWeights], + True, + ], + [ + "depthwise-separable-conv2d-sparsity", + [PruneLowMagnitude, PruneLowMagnitude, PruneLowMagnitude], + False, + ], + [ + "depthwise-separable-conv2d-sparsity", + [PruneLowMagnitude, PruneLowMagnitude, PruneLowMagnitude], + True, + ], ["conv2d-clustering", [ClusterWeights, ClusterWeights, ClusterWeights], False], ["conv2d-clustering", [ClusterWeights, ClusterWeights, ClusterWeights], True], [ @@ -411,9 +455,14 @@ def test_register_rewrite_function() -> None: def test_builtin_rewrite_names() -> None: """Test if all builtin rewrites are properly registered and returned.""" assert set(RewritingOptimizer.builtin_rewrite_names()) == { + "conv2d", "conv2d-clustering", "conv2d-sparsity", "conv2d-unstructured-sparsity", + "depthwise-separable-conv2d", + "depthwise-separable-conv2d-clustering", + "depthwise-separable-conv2d-sparsity", + "depthwise-separable-conv2d-unstructured-sparsity", "fully-connected", "fully-connected-clustering", "fully-connected-sparsity", |