diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-05-17 09:05:03 +0100 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-05-21 16:51:15 +0100 |
commit | 3002baa6b1fd226d38bcfabfe3dc15556833be6a (patch) | |
tree | a7158e696f1b61cf98cef1de24f056bf9a71c6cd /tests/test_nn_rewrite_core_rewrite.py | |
parent | 856111bcaef76c60303bdf2ae7cbf718d93d1df4 (diff) | |
download | mlia-main.tar.gz |
Rework doctrings in rewrite functions based on recent changes
Resolves MLIA-944
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: I31a37e17a296f8a16d0db408d48c6de65c05300e
Diffstat (limited to 'tests/test_nn_rewrite_core_rewrite.py')
-rw-r--r-- | tests/test_nn_rewrite_core_rewrite.py | 24 |
1 files changed, 13 insertions, 11 deletions
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py index dc938ce..97b0b96 100644 --- a/tests/test_nn_rewrite_core_rewrite.py +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -158,7 +158,7 @@ def train_rewrite_model( def test_rewrite_fully_connected_clustering(caplog: pytest.LogCaptureFixture) -> None: """Check that fully connected clustering rewrite model - has the set number of clusters""" + has the set number of clusters.""" rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite) model = rewrite(input_shape=(28, 28), output_shape=10) @@ -178,7 +178,7 @@ def test_rewrite_fully_connected_clustering(caplog: pytest.LogCaptureFixture) -> def test_rewrite_conv2d_clustering(caplog: pytest.LogCaptureFixture) -> None: - """Check that conv2d clustering rewrite model has the set number of clusters""" + """Check that conv2d clustering rewrite model has the set number of clusters.""" rewrite = ClusteringRewrite("conv2d-clustering", conv2d_clustering_rewrite) model = rewrite( @@ -200,23 +200,25 @@ def test_rewrite_conv2d_clustering(caplog: pytest.LogCaptureFixture) -> None: def test_rewrite_clustering_error_handling() -> None: - """Check that model has the set number of clusters - and that when quantized the number of clusters - remain.""" + """ + Check that the clustering rewrite check_optimization + function returns the current error. + """ rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite) model = rewrite(input_shape=(28, 28), output_shape=10) with pytest.raises( ValueError, - match=( - r"Expected check_preserved_quantize to have argument number_of_clusters" - ), + match=(r"Expected check_optimization to have argument number_of_clusters"), ): rewrite.check_optimization(model, bad_arg_name=25) def test_rewrite_fully_connected_sparsity(caplog: pytest.LogCaptureFixture) -> None: - """Check that.""" + """ + Check that sparse fully connected + rewrite model is correctly sparse. + """ rewrite = Sparsity24Rewrite("fully-connected-sparsity24", fc_sparsity_rewrite) input_shape = (28, 28) @@ -243,7 +245,7 @@ def test_rewrite_fully_connected_sparsity(caplog: pytest.LogCaptureFixture) -> N def test_rewrite_conv2d_sparsity(caplog: pytest.LogCaptureFixture) -> None: - """Check that.""" + """Check that sparse conv2d rewrite model is correctly sparse.""" rewrite = Sparsity24Rewrite("conv2d-sparsity24", conv2d_sparsity_rewrite) input_shape = np.array([28, 28, 3]) @@ -300,7 +302,7 @@ def test_rewriting_optimizer( # pylint: disable=too-many-locals expected_layers: list[object], quant: bool, ) -> None: - """Test fc_layer rewrite process with rewrite type fully-connected.""" + """Test the rewrite process with all rewrite types.""" tfrecord = test_tfrecord if quant else test_tfrecord_fp32 tflite_model = test_tflite_model if quant else test_tflite_model_fp32 |