aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_tensorflow_optimizations_clustering.py
blob: 11036ad0bf4bcaaf9e897029325fa294fe7f0177 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Test for module optimizations/clustering."""
from __future__ import annotations

import math
from pathlib import Path

import pytest
from flaky import flaky
from keras.api._v2 import keras  # Temporary workaround for now: MLIA-1107

from mlia.nn.tensorflow.optimizations.clustering import Clusterer
from mlia.nn.tensorflow.optimizations.clustering import ClusteringConfiguration
from mlia.nn.tensorflow.optimizations.pruning import Pruner
from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration
from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
from mlia.nn.tensorflow.tflite_metrics import ReportClusterMode
from mlia.nn.tensorflow.tflite_metrics import TFLiteMetrics
from tests.utils.common import get_dataset
from tests.utils.common import train_model


def _prune_model(
    model: keras.Model, target_sparsity: float, layers_to_prune: list[str] | None
) -> keras.Model:
    """Run the ``Pruner`` optimization on ``model`` and return its result.

    The pruning configuration is built from the given target sparsity and
    layer selection, the two values returned by ``get_dataset()``, and a
    batch size and epoch count of one each.
    """
    train_x, train_y = get_dataset()

    pruning_config = PruningConfiguration(
        target_sparsity,
        layers_to_prune,
        train_x,
        train_y,
        1,  # batch size
        1,  # number of epochs
    )

    model_pruner = Pruner(model, pruning_config)
    model_pruner.apply_optimization()

    return model_pruner.get_model()


def _test_num_unique_weights(
    metrics: TFLiteMetrics,
    target_num_clusters: int,
    layers_to_cluster: list[str] | None,
) -> None:
    """Assert that the expected number of layers is within the cluster target.

    A layer counts as clustered when the first entry of its value reported by
    ``metrics.num_unique_weights(ReportClusterMode.NUM_CLUSTERS_PER_AXIS)`` is
    at most ``target_num_clusters``. The expected count is the length of
    ``layers_to_cluster`` when that list is non-empty, otherwise the number of
    layers reported by the metrics.
    """
    unique_weights_by_layer = metrics.num_unique_weights(
        ReportClusterMode.NUM_CLUSTERS_PER_AXIS
    )

    layers_within_target = sum(
        1
        for layer_counts in unique_weights_by_layer.values()
        if layer_counts[0] <= target_num_clusters
    )

    # Without an explicit (non-empty) selection every reported layer counts.
    if layers_to_cluster:
        expected_layers = len(layers_to_cluster)
    else:
        expected_layers = len(unique_weights_by_layer)

    assert layers_within_target == expected_layers


def _test_sparsity(
    metrics: TFLiteMetrics,
    target_sparsity: float,
    layers_to_cluster: list[str] | None,
) -> None:
    """Assert that the expected number of layers reached the target sparsity.

    A layer counts as sparse when its value from ``metrics.sparsity_per_layer()``
    lies within an absolute tolerance of 0.03 of ``target_sparsity``. The
    expected count is the length of ``layers_to_cluster`` when that list is
    non-empty, otherwise the number of layers reported by the metrics.
    """
    tolerance = 0.03
    sparsity_by_layer = metrics.sparsity_per_layer()

    layers_at_target = sum(
        1
        for sparsity in sparsity_by_layer.values()
        if math.isclose(sparsity, target_sparsity, abs_tol=tolerance)
    )

    # The match must be exact: neither fewer nor more sparse layers than asked.
    if layers_to_cluster:
        expected_layers = len(layers_to_cluster)
    else:
        expected_layers = len(sparsity_by_layer)

    assert layers_at_target == expected_layers


# Known issue: this test occasionally fails for stochastic reasons when a
# threshold is not met, and simply re-running it helps. Until a more
# deterministic approach is found, it is retried via ``flaky``.
# The flaky package is (as of 2023) untyped, hence the mypy suppression on
# the decorator line.
@flaky(max_runs=4, min_passes=1)  # type: ignore
@pytest.mark.parametrize("target_num_clusters", (32, 4))
@pytest.mark.parametrize("sparsity_aware", (False, True))
@pytest.mark.parametrize("layers_to_cluster", (["conv1"], ["conv1", "conv2"], None))
def test_cluster_simple_model_fully(
    target_num_clusters: int,
    sparsity_aware: bool,
    layers_to_cluster: list[str] | None,
    tmp_path: Path,
    test_keras_model: Path,
) -> None:
    """Cluster a trained Keras model and check the metrics of its TFLite export."""
    target_sparsity = 0.5

    model = keras.models.load_model(str(test_keras_model))
    train_model(model)

    # Sparsity-aware runs prune the model before it is clustered.
    if sparsity_aware:
        model = _prune_model(model, target_sparsity, layers_to_cluster)

    clustering_config = ClusteringConfiguration(target_num_clusters, layers_to_cluster)
    clusterer = Clusterer(model, clustering_config)
    clusterer.apply_optimization()

    # The metrics are computed from the converted TFLite file.
    tflite_path = tmp_path / "test_cluster_simple_model_fully_after.tflite"
    convert_to_tflite(clusterer.get_model(), output_path=tflite_path)
    metrics = TFLiteMetrics(str(tflite_path))

    _test_num_unique_weights(metrics, target_num_clusters, layers_to_cluster)

    if sparsity_aware:
        _test_sparsity(metrics, target_sparsity, layers_to_cluster)