aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_rewrite.py
diff options
context:
space:
mode:
authorNathan Bailey <nathan.bailey@arm.com>2024-05-17 09:05:03 +0100
committerNathan Bailey <nathan.bailey@arm.com>2024-05-21 16:51:15 +0100
commit3002baa6b1fd226d38bcfabfe3dc15556833be6a (patch)
treea7158e696f1b61cf98cef1de24f056bf9a71c6cd /tests/test_nn_rewrite_core_rewrite.py
parent856111bcaef76c60303bdf2ae7cbf718d93d1df4 (diff)
downloadmlia-main.tar.gz
fix: Extend docstrings in the rewrite moduleHEADmain
Rework doctrings in rewrite functions based on recent changes Resolves MLIA-944 Signed-off-by: Nathan Bailey <nathan.bailey@arm.com> Change-Id: I31a37e17a296f8a16d0db408d48c6de65c05300e
Diffstat (limited to 'tests/test_nn_rewrite_core_rewrite.py')
-rw-r--r--tests/test_nn_rewrite_core_rewrite.py24
1 files changed, 13 insertions, 11 deletions
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index dc938ce..97b0b96 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -158,7 +158,7 @@ def train_rewrite_model(
def test_rewrite_fully_connected_clustering(caplog: pytest.LogCaptureFixture) -> None:
"""Check that fully connected clustering rewrite model
- has the set number of clusters"""
+ has the set number of clusters."""
rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite)
model = rewrite(input_shape=(28, 28), output_shape=10)
@@ -178,7 +178,7 @@ def test_rewrite_fully_connected_clustering(caplog: pytest.LogCaptureFixture) ->
def test_rewrite_conv2d_clustering(caplog: pytest.LogCaptureFixture) -> None:
- """Check that conv2d clustering rewrite model has the set number of clusters"""
+ """Check that conv2d clustering rewrite model has the set number of clusters."""
rewrite = ClusteringRewrite("conv2d-clustering", conv2d_clustering_rewrite)
model = rewrite(
@@ -200,23 +200,25 @@ def test_rewrite_conv2d_clustering(caplog: pytest.LogCaptureFixture) -> None:
def test_rewrite_clustering_error_handling() -> None:
- """Check that model has the set number of clusters
- and that when quantized the number of clusters
- remain."""
+ """
+ Check that the clustering rewrite check_optimization
+ function returns the current error.
+ """
rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite)
model = rewrite(input_shape=(28, 28), output_shape=10)
with pytest.raises(
ValueError,
- match=(
- r"Expected check_preserved_quantize to have argument number_of_clusters"
- ),
+ match=(r"Expected check_optimization to have argument number_of_clusters"),
):
rewrite.check_optimization(model, bad_arg_name=25)
def test_rewrite_fully_connected_sparsity(caplog: pytest.LogCaptureFixture) -> None:
- """Check that."""
+ """
+ Check that sparse fully connected
+ rewrite model is correctly sparse.
+ """
rewrite = Sparsity24Rewrite("fully-connected-sparsity24", fc_sparsity_rewrite)
input_shape = (28, 28)
@@ -243,7 +245,7 @@ def test_rewrite_fully_connected_sparsity(caplog: pytest.LogCaptureFixture) -> N
def test_rewrite_conv2d_sparsity(caplog: pytest.LogCaptureFixture) -> None:
- """Check that."""
+ """Check that sparse conv2d rewrite model is correctly sparse."""
rewrite = Sparsity24Rewrite("conv2d-sparsity24", conv2d_sparsity_rewrite)
input_shape = np.array([28, 28, 3])
@@ -300,7 +302,7 @@ def test_rewriting_optimizer( # pylint: disable=too-many-locals
expected_layers: list[object],
quant: bool,
) -> None:
- """Test fc_layer rewrite process with rewrite type fully-connected."""
+ """Test the rewrite process with all rewrite types."""
tfrecord = test_tfrecord if quant else test_tfrecord_fp32
tflite_model = test_tflite_model if quant else test_tflite_model_fp32