aboutsummaryrefslogtreecommitdiff
path: root/tests/mlia/test_nn_tensorflow_tflite_metrics.py
blob: 805f7d1c75728a6a6a2093ef2c46b3f783ccb8e3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module nn/tensorflow/tflite_metrics."""
import os
import tempfile
from math import isclose
from pathlib import Path
from typing import Generator
from typing import List

import numpy as np
import pytest
import tensorflow as tf

from mlia.nn.tensorflow.tflite_metrics import ReportClusterMode
from mlia.nn.tensorflow.tflite_metrics import TFLiteMetrics


def _dummy_keras_model() -> tf.keras.Model:
    """Build a small, untrained Keras model to be used as a test fixture.

    The model takes an 8x8x3 input and chains a Conv2D, a DepthwiseConv2D,
    a Flatten and a Dense layer.
    """
    layers = tf.keras.layers
    return tf.keras.Sequential(
        [
            tf.keras.Input(shape=(8, 8, 3)),
            layers.Conv2D(4, 3),
            layers.DepthwiseConv2D(3),
            layers.Flatten(),
            layers.Dense(8),
        ]
    )


def _sparse_binary_keras_model() -> tf.keras.Model:
    """Return the dummy model with binary, half-zero weights.

    For every layer that owns weights, its first weight variable is
    overwritten with an alternating 1.0 / 0.0 pattern (in flattened C order).
    A variable with an even number of elements therefore contains exactly
    two unique values and has a sparsity of 0.5.
    """

    def get_sparse_weights(shape: List[int]) -> np.ndarray:
        # Start from all zeros and set every other element (flat index
        # 0, 2, 4, ...) to 1.0.
        weights = np.zeros(shape)
        weights.flat[::2] = 1.0
        return weights

    keras_model = _dummy_keras_model()
    for layer in keras_model.layers:
        # Layers without weights (e.g. Flatten) have nothing to sparsify.
        if not layer.weights:
            continue
        # Only the first weight variable (presumably the kernel) is modified.
        weight = layer.weights[0]
        weight.assign(get_sparse_weights(weight.shape))
    return keras_model


@pytest.fixture(scope="class", name="tflite_file")
def fixture_tflite_file() -> Generator:
    """Yield the path of a temporary TFLite file holding the sparse model.

    The file is written into a temporary directory that is removed when the
    fixture is torn down.
    """
    model_bytes = tf.lite.TFLiteConverter.from_keras_model(
        _sparse_binary_keras_model()
    ).convert()
    with tempfile.TemporaryDirectory() as tmp_dir:
        tflite_path = os.path.join(tmp_dir, "test.tflite")
        Path(tflite_path).write_bytes(model_bytes)
        yield tflite_path


@pytest.fixture(scope="function", name="metrics")
def fixture_metrics(tflite_file: str) -> TFLiteMetrics:
    """Return a fresh TFLiteMetrics instance for the given TFLite model file."""
    metrics = TFLiteMetrics(tflite_file)
    return metrics


class TestTFLiteMetrics:
    """Tests for the TFLiteMetrics class."""

    @staticmethod
    def test_sparsity(metrics: TFLiteMetrics) -> None:
        """Test that every layer, and the model overall, is 50% sparse."""
        sparsity_per_layer = metrics.sparsity_per_layer()
        # Guard against the loop below passing vacuously on an empty result.
        assert sparsity_per_layer, "No layers reported."
        for name, sparsity in sparsity_per_layer.items():
            assert isclose(sparsity, 0.5), f"Layer '{name}' has incorrect sparsity."
        assert isclose(metrics.sparsity_overall(), 0.5)

    @staticmethod
    def test_clusters(metrics: TFLiteMetrics) -> None:
        """Test the number of clusters reported in every cluster mode."""
        # NUM_CLUSTERS_PER_AXIS and NUM_CLUSTERS_MIN_MAX can be handled together
        for mode in [
            ReportClusterMode.NUM_CLUSTERS_PER_AXIS,
            ReportClusterMode.NUM_CLUSTERS_MIN_MAX,
        ]:
            num_unique_weights = metrics.num_unique_weights(mode)
            # Guard against the loop below passing vacuously on an empty result.
            assert num_unique_weights, f"No layers reported for mode {mode}."
            for name, num_unique_per_axis in num_unique_weights.items():
                for num_unique in num_unique_per_axis:
                    assert (
                        num_unique == 2
                    ), f"Layer '{name}' has incorrect number of clusters."

        # NUM_CLUSTERS_HISTOGRAM
        hists = metrics.num_unique_weights(ReportClusterMode.NUM_CLUSTERS_HISTOGRAM)
        assert hists
        for name, hist in hists.items():
            assert hist
            for idx, num_axes in enumerate(hist):
                # The histogram starts with the bin for num_clusters == 1
                num_clusters = idx + 1
                msg = (
                    f"Histogram of layer '{name}': There are {num_axes} axes "
                    f"with {num_clusters} clusters"
                )
                if num_clusters == 2:
                    assert num_axes > 0, f"{msg}, but there should be at least one."
                else:
                    assert num_axes == 0, f"{msg}, but there should be none."

    @staticmethod
    @pytest.mark.parametrize("report_sparsity", (False, True))
    @pytest.mark.parametrize("report_cluster_mode", ReportClusterMode)
    @pytest.mark.parametrize("max_num_clusters", (-1, 8))
    @pytest.mark.parametrize("verbose", (False, True))
    def test_summary(
        tflite_file: str,
        report_sparsity: bool,
        report_cluster_mode: ReportClusterMode,
        max_num_clusters: int,
        verbose: bool,
    ) -> None:
        """Test that summary() runs for every combination of report options.

        It is exercised both with the default constructor arguments and with
        an explicit empty list as the second constructor argument.
        """
        for metrics in [TFLiteMetrics(tflite_file), TFLiteMetrics(tflite_file, [])]:
            metrics.summary(
                report_sparsity=report_sparsity,
                report_cluster_mode=report_cluster_mode,
                max_num_clusters=max_num_clusters,
                verbose=verbose,
            )