aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_tensorflow_optimizations_clustering.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_tensorflow_optimizations_clustering.py')
-rw-r--r--tests/test_nn_tensorflow_optimizations_clustering.py9
1 files changed, 8 insertions, 1 deletions
diff --git a/tests/test_nn_tensorflow_optimizations_clustering.py b/tests/test_nn_tensorflow_optimizations_clustering.py
index ba7aea3..d3c0da6 100644
--- a/tests/test_nn_tensorflow_optimizations_clustering.py
+++ b/tests/test_nn_tensorflow_optimizations_clustering.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Test for module optimizations/clustering."""
from __future__ import annotations
@@ -8,6 +8,7 @@ from pathlib import Path
import pytest
import tensorflow as tf
+from flaky import flaky
from mlia.nn.tensorflow.optimizations.clustering import Clusterer
from mlia.nn.tensorflow.optimizations.clustering import ClusteringConfiguration
@@ -81,6 +82,12 @@ def _test_sparsity(
assert num_sparse_layers == expected_num_sparse_layers
+# This test fails sporadically for stochastic reasons, due to a threshold not being met.
+# Re-running the test will help. We are yet to find a more deterministic approach
+# to run the test, and in the meantime we classify it as a known issue.
+# Additionally, flaky is (as of 2023) untyped and thus we need to silence the
+# warning from mypy.
+@flaky # type: ignore
@pytest.mark.parametrize("target_num_clusters", (32, 4))
@pytest.mark.parametrize("sparsity_aware", (False, True))
@pytest.mark.parametrize("layers_to_cluster", (["conv1"], ["conv1", "conv2"], None))