aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/tflite_metrics.py
blob: b29fab3db2314d83ee41a0b02a946f238792ec82 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""
Contains class TFLiteMetrics to calculate metrics from a TFLite file.

These metrics include:
* Sparsity (per layer and overall)
* Unique weights (clusters) (per layer)
* gzip compression ratio (not computed by the code in this module)
"""
import os
from enum import Enum
from pprint import pprint
from typing import Any
from typing import List
from typing import Optional

import numpy as np
import tensorflow as tf

# Substrings of tensor names that TFLiteMetrics excludes when no explicit
# ignore list is given (matching there is case-insensitive).
DEFAULT_IGNORE_LIST = (
    "relu pooling reshape identity input add flatten StatefulPartitionedCall bias"
).split()


def calculate_num_unique_weights(weights: np.ndarray) -> int:
    """Return the number of distinct values in ``weights``.

    ``np.unique`` flattens its input, so the count covers every element
    regardless of the array's shape. An empty array yields 0.
    """
    return len(np.unique(weights))


def calculate_num_unique_weights_per_axis(
    weights: np.ndarray, axis: int
) -> List[int]:
    """Calculate the number of unique weights for each slice along ``axis``.

    Args:
        weights: Weight tensor.
        axis: Axis whose slices are counted separately; one count is
            returned per index along this axis.

    Returns:
        One unique-value count per slice, in axis order.

    Raises:
        ValueError: If ``weights`` has no slices along ``axis``. This used to
            be an ``assert``, which is silently skipped under ``python -O``.
    """
    # Make quantized dimension the first dimension, so iterating the
    # transposed array yields one slice per index of ``axis``.
    weights_trans = np.swapaxes(weights, 0, axis)
    num_uniques_weights = [
        calculate_num_unique_weights(weights_slice) for weights_slice in weights_trans
    ]
    if not num_uniques_weights:
        raise ValueError("Weights have no elements along axis {}.".format(axis))
    return num_uniques_weights


class SparsityAccumulator:
    """Helper class to accumulate sparsity over several layers.

    Call the instance with each weight array to add it. Then ``sparsity()``
    returns the fraction of zero-valued elements over everything added.
    """

    def __init__(self) -> None:
        """Create an empty accumulator."""
        # Running count of non-zero elements across all added arrays.
        self.total_non_zero_weights: int = 0
        # Running count of all elements across all added arrays.
        self.total_weights: int = 0

    def __call__(self, weights: np.ndarray) -> None:
        """Update the accumulator with the given weights."""
        self.total_non_zero_weights += int(np.count_nonzero(weights))
        self.total_weights += weights.size

    def sparsity(self) -> float:
        """Calculate the sparsity for all added weights.

        Returns:
            ``1 - non_zero / total`` over all added weights. Returns 0.0 when
            no elements have been added, instead of raising
            ZeroDivisionError.
        """
        if self.total_weights == 0:
            return 0.0
        return 1.0 - self.total_non_zero_weights / float(self.total_weights)


def calculate_sparsity(
    weights: np.ndarray, accumulator: Optional["SparsityAccumulator"] = None
) -> float:
    """
    Calculate the sparsity for the given weights.

    If the accumulator is passed, it is updated as well.

    Args:
        weights: Array whose fraction of zero elements is computed.
        accumulator: Optional accumulator that is also fed ``weights``.

    Returns:
        ``1 - non_zero / size``. Returns 0.0 for an empty array, instead of
        raising ZeroDivisionError.
    """
    if accumulator is not None:
        accumulator(weights)
    if weights.size == 0:
        return 0.0
    non_zero_weights = np.count_nonzero(weights)
    return 1.0 - float(non_zero_weights) / float(weights.size)


class ReportClusterMode(Enum):
    """Ways of aggregating per-axis cluster counts for reporting.

    Each member's value is a human-readable description of that mode.
    """

    NUM_CLUSTERS_HISTOGRAM = (
        "A histogram of the number of clusters per axis. I.e. the number of "
        "clusters is the index of the list (the bin) and the value is the "
        "number of axes that have this number of clusters. The first bin is 1."
    )
    NUM_CLUSTERS_PER_AXIS = "Number of clusters (unique weights) per axis."
    NUM_CLUSTERS_MIN_MAX = "Min/max number of clusters over all axes."


class TFLiteMetrics:
    """Helper class to calculate metrics from a TFLite file.

    Metrics include:
    * sparsity (per-layer and overall)
    * number of unique weights (clusters) per layer

    Only tensors that pass the ignore-list filter applied in ``__init__``
    are considered by the metric methods.
    """

    def __init__(
        self, tflite_file: str, ignore_list: Optional[List[str]] = None
    ) -> None:
        """Load the TFLite file and filter layers.

        Args:
            tflite_file: Path of the TFLite model to load.
            ignore_list: Substrings that exclude a tensor when found in its
                name (compared case-insensitively). Defaults to
                ``DEFAULT_IGNORE_LIST``.
        """
        self.tflite_file = tflite_file
        if ignore_list is None:
            ignore_list = DEFAULT_IGNORE_LIST
        # Casefold once here so the substring test below is case-insensitive.
        self.ignore_list = [ignore.casefold() for ignore in ignore_list]
        # Initialize the TFLite interpreter with the model file
        self.interpreter = tf.lite.Interpreter(model_path=tflite_file)
        self.interpreter.allocate_tensors()
        # NOTE(review): not read anywhere in this class; kept for any
        # external code that may access it.
        self.details: dict = {}

        def ignore(details: dict) -> bool:
            """Return True if the tensor should be excluded from the metrics."""
            name = details["name"].casefold()
            # Unnamed tensors are always excluded.
            if not name:
                return True
            return any(to_ignore in name for to_ignore in self.ignore_list)

        # Tensor name -> interpreter tensor details, for the kept tensors only.
        self.filtered_details = {
            details["name"]: details
            for details in self.interpreter.get_tensor_details()
            if not ignore(details)
        }

    def get_tensor(self, details: dict) -> Any:
        """Return the weights/tensor specified in the given details map."""
        return self.interpreter.tensor(details["index"])()

    def sparsity_per_layer(self) -> dict:
        """Return a dict of layer name and sparsity value."""
        return {
            name: calculate_sparsity(self.get_tensor(details))
            for name, details in self.filtered_details.items()
        }

    def sparsity_overall(self) -> float:
        """Return the sparsity accumulated over all filtered layers."""
        acc = SparsityAccumulator()
        for details in self.filtered_details.values():
            acc(self.get_tensor(details))
        return acc.sparsity()

    def calc_num_clusters_per_axis(self, details: dict) -> List[int]:
        """Calculate the number of clusters (unique weights) per axis.

        If the tensor has more than one zero point (per-axis quantization),
        unique values are counted separately for each slice along
        ``quantized_dimension``. Otherwise a single count over the whole
        tensor is returned as a one-element list.
        """
        weights = self.get_tensor(details)
        quant_params = details["quantization_parameters"]
        per_axis = len(quant_params["zero_points"]) > 1
        if per_axis:
            # Calculate unique weights along quantization axis
            axis = quant_params["quantized_dimension"]
            return calculate_num_unique_weights_per_axis(weights, axis)

        # Calculate unique weights over all axes/dimensions
        return [calculate_num_unique_weights(weights)]

    def num_unique_weights(self, mode: ReportClusterMode) -> dict:
        """Return a dict of layer name and number of unique weights.

        The per-layer value depends on ``mode``:
        * NUM_CLUSTERS_PER_AXIS: list of cluster counts, one per axis slice.
        * NUM_CLUSTERS_MIN_MAX: ``[min, max]`` of those counts.
        * NUM_CLUSTERS_HISTOGRAM: list whose entry ``i`` is the number of
          axis slices that have ``i + 1`` clusters.

        Raises:
            NotImplementedError: If ``mode`` is none of the above.
        """
        if mode == ReportClusterMode.NUM_CLUSTERS_PER_AXIS:
            aggregation_func = self.calc_num_clusters_per_axis
        elif mode == ReportClusterMode.NUM_CLUSTERS_MIN_MAX:

            def cluster_min_max(details: dict) -> List[int]:
                num_clusters = self.calc_num_clusters_per_axis(details)
                return [min(num_clusters), max(num_clusters)]

            aggregation_func = cluster_min_max
        elif mode == ReportClusterMode.NUM_CLUSTERS_HISTOGRAM:

            def cluster_hist(details: dict) -> List[int]:
                num_clusters = self.calc_num_clusters_per_axis(details)
                hist = [0] * max(num_clusters)
                for num in num_clusters:
                    # Bin i counts slices with i + 1 clusters. A slice with
                    # zero clusters (empty tensor) has no bin, so skip it
                    # rather than index hist[-1] of an empty list.
                    if num > 0:
                        hist[num - 1] += 1
                return hist

            aggregation_func = cluster_hist
        else:
            raise NotImplementedError(
                "ReportClusterMode '{}' not implemented.".format(mode)
            )
        return {
            name: aggregation_func(details)
            for name, details in self.filtered_details.items()
        }

    @staticmethod
    def _prettify_name(name: str) -> str:
        """Drop the leading scope of names starting with "model".

        For example, "model/conv/kernel" becomes "conv/kernel". Names that
        start with "model" but contain no "/" are returned unchanged.
        """
        if name.startswith("model"):
            _, separator, remainder = name.partition("/")
            if separator:
                return remainder
        return name

    def summary(
        self,
        report_sparsity: bool,
        report_cluster_mode: Optional[ReportClusterMode] = None,
        max_num_clusters: int = 32,
        verbose: bool = False,
    ) -> None:
        """Print a summary of all the model information.

        Args:
            report_sparsity: Add a per-layer and overall sparsity column.
            report_cluster_mode: If given, also print cluster details using
                this aggregation mode.
            max_num_clusters: Histogram mode only. Layers with more clusters
                than this show a short note instead of the histogram.
                A value <= 0 disables the limit.
            verbose: Print full input/output details and each layer's unique
                values.
        """
        print("Model file: {}".format(self.tflite_file))
        print("#" * 80)
        print(" " * 28 + "### TFLITE SUMMARY ###")
        print("File: {}".format(os.path.abspath(self.tflite_file)))
        print("Input(s):")
        self._print_in_outs(self.interpreter.get_input_details(), verbose)
        print("Output(s):")
        self._print_in_outs(self.interpreter.get_output_details(), verbose)
        print()
        header = ["Layer", "Index", "Type", "Num weights"]
        if report_sparsity:
            header.append("Sparsity")
        rows: List[List[Any]] = []
        # Feed every layer to the accumulator, so the overall weight count is
        # correct even when sparsity is not reported.
        sparsity_accumulator = SparsityAccumulator()
        for details in self.filtered_details.values():
            name = details["name"]
            weights = self.get_tensor(details)
            sparsity_accumulator(weights)
            row = [
                self._prettify_name(name),
                details["index"],
                weights.dtype,
                weights.size,
            ]
            if report_sparsity:
                row.append("{:.2f}".format(calculate_sparsity(weights)))
            rows.append(row)
            if verbose:
                # Print cluster centroids
                print("{} cluster centroids:".format(name))
                pprint(np.unique(weights))
        # Add summary/overall values
        summary_row: List[Any] = [""] * len(header)
        summary_row[header.index("Layer")] = "=> OVERALL"
        summary_row[header.index("Num weights")] = str(
            sparsity_accumulator.total_weights
        )
        if report_sparsity:
            summary_row[header.index("Sparsity")] = "{:.2f}".format(
                sparsity_accumulator.sparsity()
            )
        rows.append(summary_row)
        # The per-layer table was previously built but never printed.
        self._print_table(header, rows)
        # Report detailed cluster info
        if report_cluster_mode is not None:
            print()
            self._print_cluster_details(report_cluster_mode, max_num_clusters)
        print("#" * 80)

    @staticmethod
    def _print_table(header: List[str], rows: List[List[Any]]) -> None:
        """Print header and rows as left-aligned, space-separated columns."""
        cells = [[str(cell) for cell in row] for row in [header, *rows]]
        widths = [max(map(len, column)) for column in zip(*cells)]

        def format_row(row: List[str]) -> str:
            padded = (cell.ljust(width) for cell, width in zip(row, widths))
            return "  ".join(padded).rstrip()

        print(format_row(cells[0]))
        print(format_row(["-" * width for width in widths]))
        for row in cells[1:]:
            print(format_row(row))

    def _print_cluster_details(
        self, report_cluster_mode: ReportClusterMode, max_num_clusters: int
    ) -> None:
        """Print the mode description, then one cluster line per layer."""
        print("{}:\n{}".format(report_cluster_mode.name, report_cluster_mode.value))
        num_clusters = self.num_unique_weights(report_cluster_mode)
        if (
            report_cluster_mode == ReportClusterMode.NUM_CLUSTERS_HISTOGRAM
            and max_num_clusters > 0
        ):
            # Only show cluster histogram if there are not more than
            # max_num_clusters. This is a workaround for not showing a huge
            # histogram for unclustered layers.
            for name, value in num_clusters.items():
                if len(value) > max_num_clusters:
                    num_clusters[name] = "More than {} unique values.".format(
                        max_num_clusters
                    )
        for name, nums in num_clusters.items():
            print("- {}: {}".format(self._prettify_name(name), nums))

    @staticmethod
    def _print_in_outs(ios: List[dict], verbose: bool = False) -> None:
        """Print interpreter input/output details, one entry per item."""
        for item in ios:
            if verbose:
                pprint(item)
            else:
                print(
                    "- {} ({}): {}".format(
                        item["name"],
                        np.dtype(item["dtype"]).name,
                        item["shape"],
                    )
                )