From 0efca3cadbad5517a59884576ddb90cfe7ac30f8 Mon Sep 17 00:00:00 2001 From: Diego Russo Date: Mon, 30 May 2022 13:34:14 +0100 Subject: Add MLIA codebase Add MLIA codebase including sources and tests. Change-Id: Id41707559bd721edd114793618d12ccd188d8dbd --- src/mlia/nn/tensorflow/__init__.py | 3 + src/mlia/nn/tensorflow/config.py | 134 ++++++++++ src/mlia/nn/tensorflow/optimizations/__init__.py | 3 + src/mlia/nn/tensorflow/optimizations/clustering.py | 109 ++++++++ src/mlia/nn/tensorflow/optimizations/common.py | 29 ++ src/mlia/nn/tensorflow/optimizations/pruning.py | 168 ++++++++++++ src/mlia/nn/tensorflow/optimizations/select.py | 179 +++++++++++++ src/mlia/nn/tensorflow/tflite_metrics.py | 296 +++++++++++++++++++++ src/mlia/nn/tensorflow/utils.py | 149 +++++++++++ 9 files changed, 1070 insertions(+) create mode 100644 src/mlia/nn/tensorflow/__init__.py create mode 100644 src/mlia/nn/tensorflow/config.py create mode 100644 src/mlia/nn/tensorflow/optimizations/__init__.py create mode 100644 src/mlia/nn/tensorflow/optimizations/clustering.py create mode 100644 src/mlia/nn/tensorflow/optimizations/common.py create mode 100644 src/mlia/nn/tensorflow/optimizations/pruning.py create mode 100644 src/mlia/nn/tensorflow/optimizations/select.py create mode 100644 src/mlia/nn/tensorflow/tflite_metrics.py create mode 100644 src/mlia/nn/tensorflow/utils.py (limited to 'src/mlia/nn/tensorflow') diff --git a/src/mlia/nn/tensorflow/__init__.py b/src/mlia/nn/tensorflow/__init__.py new file mode 100644 index 0000000..ff061c1 --- /dev/null +++ b/src/mlia/nn/tensorflow/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""TensorFlow related module.""" diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py new file mode 100644 index 0000000..d3235d7 --- /dev/null +++ b/src/mlia/nn/tensorflow/config.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Model configuration.""" +import logging +from pathlib import Path +from typing import cast +from typing import Dict +from typing import List +from typing import Union + +import tensorflow as tf + +from mlia.core.context import Context +from mlia.nn.tensorflow.utils import convert_tf_to_tflite +from mlia.nn.tensorflow.utils import convert_to_tflite +from mlia.nn.tensorflow.utils import is_keras_model +from mlia.nn.tensorflow.utils import is_tf_model +from mlia.nn.tensorflow.utils import is_tflite_model +from mlia.nn.tensorflow.utils import save_tflite_model + +logger = logging.getLogger(__name__) + + +class ModelConfiguration: + """Base class for model configuration.""" + + def __init__(self, model_path: Union[str, Path]) -> None: + """Init model configuration instance.""" + self.model_path = str(model_path) + + def convert_to_tflite( + self, tflite_model_path: Union[str, Path], quantized: bool = False + ) -> "TFLiteModel": + """Convert model to TFLite format.""" + raise NotImplementedError() + + def convert_to_keras(self, keras_model_path: Union[str, Path]) -> "KerasModel": + """Convert model to Keras format.""" + raise NotImplementedError() + + +class KerasModel(ModelConfiguration): + """Keras model configuration. + + Supports all models supported by Keras API: saved model, H5, HDF5 + """ + + def get_keras_model(self) -> tf.keras.Model: + """Return associated Keras model.""" + return tf.keras.models.load_model(self.model_path) + + def convert_to_tflite( + self, tflite_model_path: Union[str, Path], quantized: bool = False + ) -> "TFLiteModel": + """Convert model to TFLite format.""" + logger.info("Converting Keras to TFLite ...") + + converted_model = convert_to_tflite(self.get_keras_model(), quantized) + logger.info("Done\n") + + save_tflite_model(converted_model, tflite_model_path) + logger.debug( + "Model %s converted and saved to %s", self.model_path, tflite_model_path + ) + + return TFLiteModel(tflite_model_path) + + def convert_to_keras(self, keras_model_path: Union[str, Path]) -> "KerasModel": + """Convert model to Keras format.""" + return self + + +class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method + """TFLite model configuration.""" + + def input_details(self) -> List[Dict]: + """Get model's input details.""" + interpreter = tf.lite.Interpreter(model_path=self.model_path) + return cast(List[Dict], interpreter.get_input_details()) + + def convert_to_tflite( + self, tflite_model_path: Union[str, Path], quantized: bool = False + ) -> "TFLiteModel": + """Convert model to TFLite format.""" + return self + + +class TfModel(ModelConfiguration): # pylint: disable=abstract-method + """TensorFlow model configuration. + + Supports models supported by TensorFlow API (not Keras) + """ + + def convert_to_tflite( + self, tflite_model_path: Union[str, Path], quantized: bool = False + ) -> "TFLiteModel": + """Convert model to TFLite format.""" + converted_model = convert_tf_to_tflite(self.model_path, quantized) + save_tflite_model(converted_model, tflite_model_path) + + return TFLiteModel(tflite_model_path) + + +def get_model(model: Union[Path, str]) -> "ModelConfiguration": + """Return the model object.""" + if is_tflite_model(model): + return TFLiteModel(model) + + if is_keras_model(model): + return KerasModel(model) + + if is_tf_model(model): + return TfModel(model) + + raise Exception( + "The input model format is not supported" + "(supported formats: TFLite, Keras, TensorFlow saved model)!" + ) + + +def get_tflite_model(model: Union[str, Path], ctx: Context) -> "TFLiteModel": + """Convert input model to TFLite and returns TFLiteModel object.""" + tflite_model_path = ctx.get_model_path("converted_model.tflite") + converted_model = get_model(model) + + return converted_model.convert_to_tflite(tflite_model_path, True) + + +def get_keras_model(model: Union[str, Path], ctx: Context) -> "KerasModel": + """Convert input model to Keras and returns KerasModel object.""" + keras_model_path = ctx.get_model_path("converted_model.h5") + converted_model = get_model(model) + + return converted_model.convert_to_keras(keras_model_path) diff --git a/src/mlia/nn/tensorflow/optimizations/__init__.py b/src/mlia/nn/tensorflow/optimizations/__init__.py new file mode 100644 index 0000000..201c130 --- /dev/null +++ b/src/mlia/nn/tensorflow/optimizations/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Optimizations module.""" diff --git a/src/mlia/nn/tensorflow/optimizations/clustering.py b/src/mlia/nn/tensorflow/optimizations/clustering.py new file mode 100644 index 0000000..16d9e4b --- /dev/null +++ b/src/mlia/nn/tensorflow/optimizations/clustering.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +""" +Contains class Clusterer that clusters unique weights per layer to a specified number. + +In order to do this, we need to have a base model and corresponding training data. +We also have to specify a subset of layers we want to cluster. For more details, +please refer to the documentation for TensorFlow Model Optimization Toolkit. +""" +from dataclasses import dataclass +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +import tensorflow as tf +import tensorflow_model_optimization as tfmot +from tensorflow_model_optimization.python.core.clustering.keras.experimental import ( # pylint: disable=no-name-in-module + cluster as experimental_cluster, +) + +from mlia.nn.tensorflow.optimizations.common import Optimizer +from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration + + +@dataclass +class ClusteringConfiguration(OptimizerConfiguration): + """Clustering configuration.""" + + optimization_target: int + layers_to_optimize: Optional[List[str]] = None + + def __str__(self) -> str: + """Return string representation of the configuration.""" + return f"clustering: {self.optimization_target}" + + +class Clusterer(Optimizer): + """ + Clusterer class. + + Used to cluster a model to a specified number of unique weights per layer. + + Sample usage: + clusterer = Clusterer( + base_model, + optimizer_configuration) + + clusterer.apply_clustering() + clustered_model = clusterer.get_model() + """ + + def __init__( + self, model: tf.keras.Model, optimizer_configuration: ClusteringConfiguration + ): + """Init Clusterer instance.""" + self.model = model + self.optimizer_configuration = optimizer_configuration + + def optimization_config(self) -> str: + """Return string representation of the optimization config.""" + return str(self.optimizer_configuration) + + def _setup_clustering_params(self) -> Dict[str, Any]: + CentroidInitialization = tfmot.clustering.keras.CentroidInitialization + return { + "number_of_clusters": self.optimizer_configuration.optimization_target, + "cluster_centroids_init": CentroidInitialization.LINEAR, + "preserve_sparsity": True, + } + + def _apply_clustering_to_layer( + self, layer: tf.keras.layers.Layer + ) -> tf.keras.layers.Layer: + layers_to_optimize = self.optimizer_configuration.layers_to_optimize + assert layers_to_optimize, "List of the layers to optimize is empty" + + if layer.name not in layers_to_optimize: + return layer + + clustering_params = self._setup_clustering_params() + return experimental_cluster.cluster_weights(layer, **clustering_params) + + def _init_for_clustering(self) -> None: + # Use `tf.keras.models.clone_model` to apply `apply_clustering_to_layer` + # to the layers of the model + if not self.optimizer_configuration.layers_to_optimize: + clustering_params = self._setup_clustering_params() + clustered_model = experimental_cluster.cluster_weights( + self.model, **clustering_params + ) + else: + clustered_model = tf.keras.models.clone_model( + self.model, clone_function=self._apply_clustering_to_layer + ) + + self.model = clustered_model + + def _strip_clustering(self) -> None: + self.model = tfmot.clustering.keras.strip_clustering(self.model) + + def apply_optimization(self) -> None: + """Apply all steps of clustering at once.""" + self._init_for_clustering() + self._strip_clustering() + + def get_model(self) -> tf.keras.Model: + """Get model.""" + return self.model diff --git a/src/mlia/nn/tensorflow/optimizations/common.py b/src/mlia/nn/tensorflow/optimizations/common.py new file mode 100644 index 0000000..1dce0b2 --- /dev/null +++ b/src/mlia/nn/tensorflow/optimizations/common.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Common items for the optimizations module.""" +from abc import ABC +from abc import abstractmethod +from dataclasses import dataclass + +import tensorflow as tf + + +@dataclass +class OptimizerConfiguration: + """Abstract optimizer configuration.""" + + +class Optimizer(ABC): + """Abstract class for the optimizer.""" + + @abstractmethod + def get_model(self) -> tf.keras.Model: + """Abstract method to return the model instance from the optimizer.""" + + @abstractmethod + def apply_optimization(self) -> None: + """Abstract method to apply optimization to the model.""" + + @abstractmethod + def optimization_config(self) -> str: + """Return string representation of the optimization config.""" diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py new file mode 100644 index 0000000..f629ba1 --- /dev/null +++ b/src/mlia/nn/tensorflow/optimizations/pruning.py @@ -0,0 +1,168 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +""" +Contains class Pruner to prune a model to a specified sparsity. + +In order to do this, we need to have a base model and corresponding training data. +We also have to specify a subset of layers we want to prune. For more details, +please refer to the documentation for TensorFlow Model Optimization Toolkit. +""" +from dataclasses import dataclass +from typing import List +from typing import Optional +from typing import Tuple + +import numpy as np +import tensorflow as tf +import tensorflow_model_optimization as tfmot +from tensorflow_model_optimization.python.core.sparsity.keras import ( # pylint: disable=no-name-in-module + pruning_wrapper, +) + +from mlia.nn.tensorflow.optimizations.common import Optimizer +from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration + + +@dataclass +class PruningConfiguration(OptimizerConfiguration): + """Pruning configuration.""" + + optimization_target: float + layers_to_optimize: Optional[List[str]] = None + x_train: Optional[np.array] = None + y_train: Optional[np.array] = None + batch_size: int = 1 + num_epochs: int = 1 + + def __str__(self) -> str: + """Return string representation of the configuration.""" + return f"pruning: {self.optimization_target}" + + def has_training_data(self) -> bool: + """Return True if training data provided.""" + return self.x_train is not None and self.y_train is not None + + +class Pruner(Optimizer): + """ + Pruner class. Used to prune a model to a specified sparsity. + + Sample usage: + pruner = Pruner( + base_model, + optimizer_configuration) + + pruner.apply_pruning() + pruned_model = pruner.get_model() + """ + + def __init__( + self, model: tf.keras.Model, optimizer_configuration: PruningConfiguration + ): + """Init Pruner instance.""" + self.model = model + self.optimizer_configuration = optimizer_configuration + + if not optimizer_configuration.has_training_data(): + mock_x_train, mock_y_train = self._mock_train_data() + + self.optimizer_configuration.x_train = mock_x_train + self.optimizer_configuration.y_train = mock_y_train + + def optimization_config(self) -> str: + """Return string representation of the optimization config.""" + return str(self.optimizer_configuration) + + def _mock_train_data(self) -> Tuple[np.array, np.array]: + # get rid of the batch_size dimension in input and output shape + input_shape = tuple(x for x in self.model.input_shape if x is not None) + output_shape = tuple(x for x in self.model.output_shape if x is not None) + + return ( + np.random.rand(*input_shape), + np.random.randint(0, output_shape[-1], (output_shape[:-1])), + ) + + def _setup_pruning_params(self) -> dict: + return { + "pruning_schedule": tfmot.sparsity.keras.PolynomialDecay( + initial_sparsity=0, + final_sparsity=self.optimizer_configuration.optimization_target, + begin_step=0, + end_step=self.optimizer_configuration.num_epochs, + frequency=1, + ), + } + + def _apply_pruning_to_layer( + self, layer: tf.keras.layers.Layer + ) -> tf.keras.layers.Layer: + layers_to_optimize = self.optimizer_configuration.layers_to_optimize + assert layers_to_optimize, "List of the layers to optimize is empty" + + if layer.name not in layers_to_optimize: + return layer + + pruning_params = self._setup_pruning_params() + return tfmot.sparsity.keras.prune_low_magnitude(layer, **pruning_params) + + def _init_for_pruning(self) -> None: + # Use `tf.keras.models.clone_model` to apply `apply_pruning_to_layer` + # to the layers of the model + if not self.optimizer_configuration.layers_to_optimize: + pruning_params = self._setup_pruning_params() + prunable_model = tfmot.sparsity.keras.prune_low_magnitude( + self.model, **pruning_params + ) + else: + prunable_model = tf.keras.models.clone_model( + self.model, clone_function=self._apply_pruning_to_layer + ) + + self.model = prunable_model + + def _train_pruning(self) -> None: + loss_fn = tf.keras.losses.SparseCategoricalCrossentropy() + self.model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"]) + + # Model callbacks + callbacks = [tfmot.sparsity.keras.UpdatePruningStep()] + + # Fitting data + self.model.fit( + self.optimizer_configuration.x_train, + self.optimizer_configuration.y_train, + batch_size=self.optimizer_configuration.batch_size, + epochs=self.optimizer_configuration.num_epochs, + callbacks=callbacks, + verbose=0, + ) + + def _assert_sparsity_reached(self) -> None: + for layer in self.model.layers: + if not isinstance(layer, pruning_wrapper.PruneLowMagnitude): + continue + + for weight in layer.layer.get_prunable_weights(): + nonzero_weights = np.count_nonzero(tf.keras.backend.get_value(weight)) + all_weights = tf.keras.backend.get_value(weight).size + + np.testing.assert_approx_equal( + self.optimizer_configuration.optimization_target, + 1 - nonzero_weights / all_weights, + significant=2, + ) + + def _strip_pruning(self) -> None: + self.model = tfmot.sparsity.keras.strip_pruning(self.model) + + def apply_optimization(self) -> None: + """Apply all steps of pruning sequentially.""" + self._init_for_pruning() + self._train_pruning() + self._assert_sparsity_reached() + self._strip_pruning() + + def get_model(self) -> tf.keras.Model: + """Get model.""" + return self.model diff --git a/src/mlia/nn/tensorflow/optimizations/select.py b/src/mlia/nn/tensorflow/optimizations/select.py new file mode 100644 index 0000000..1b0c755 --- /dev/null +++ b/src/mlia/nn/tensorflow/optimizations/select.py @@ -0,0 +1,179 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for optimization selection.""" +import math +from typing import List +from typing import NamedTuple +from typing import Optional +from typing import Tuple +from typing import Union + +import tensorflow as tf + +from mlia.core.errors import ConfigurationError +from mlia.nn.tensorflow.config import KerasModel +from mlia.nn.tensorflow.optimizations.clustering import Clusterer +from mlia.nn.tensorflow.optimizations.clustering import ClusteringConfiguration +from mlia.nn.tensorflow.optimizations.common import Optimizer +from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration +from mlia.nn.tensorflow.optimizations.pruning import Pruner +from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration +from mlia.utils.types import is_list_of + + +class OptimizationSettings(NamedTuple): + """Optimization settings.""" + + optimization_type: str + optimization_target: Union[int, float] + layers_to_optimize: Optional[List[str]] + + @staticmethod + def create_from( + optimizer_params: List[Tuple[str, float]], + layers_to_optimize: Optional[List[str]] = None, + ) -> List["OptimizationSettings"]: + """Create optimization settings from the provided parameters.""" + return [ + OptimizationSettings( + optimization_type=opt_type, + optimization_target=opt_target, + layers_to_optimize=layers_to_optimize, + ) + for opt_type, opt_target in optimizer_params + ] + + def __str__(self) -> str: + """Return string representation.""" + return f"{self.optimization_type}: {self.optimization_target}" + + def next_target(self) -> "OptimizationSettings": + """Return next optimization target.""" + if self.optimization_type == "pruning": + next_target = round(min(self.optimization_target + 0.1, 0.9), 2) + return OptimizationSettings( + self.optimization_type, next_target, self.layers_to_optimize + ) + + if self.optimization_type == "clustering": + # return next lowest power of two for clustering + next_target = math.log(self.optimization_target, 2) + if next_target.is_integer(): + next_target -= 1 + + next_target = max(int(2 ** int(next_target)), 4) + return OptimizationSettings( + self.optimization_type, next_target, self.layers_to_optimize + ) + + raise Exception(f"Unknown optimization type {self.optimization_type}") + + +class MultiStageOptimizer(Optimizer): + """Optimizer with multiply stages.""" + + def __init__( + self, + model: tf.keras.Model, + optimizations: List[OptimizerConfiguration], + ) -> None: + """Init MultiStageOptimizer instance.""" + self.model = model + self.optimizations = optimizations + + def optimization_config(self) -> str: + """Return string representation of the optimization config.""" + return " - ".join(str(opt) for opt in self.optimizations) + + def get_model(self) -> tf.keras.Model: + """Return optimized model.""" + return self.model + + def apply_optimization(self) -> None: + """Apply optimization to the model.""" + for config in self.optimizations: + optimizer = get_optimizer(self.model, config) + optimizer.apply_optimization() + self.model = optimizer.get_model() + + +def get_optimizer( + model: Union[tf.keras.Model, KerasModel], + config: Union[ + OptimizerConfiguration, OptimizationSettings, List[OptimizationSettings] + ], +) -> Optimizer: + """Get optimizer for provided configuration.""" + if isinstance(model, KerasModel): + model = model.get_keras_model() + + if isinstance(config, PruningConfiguration): + return Pruner(model, config) + + if isinstance(config, ClusteringConfiguration): + return Clusterer(model, config) + + if isinstance(config, OptimizationSettings) or is_list_of( + config, OptimizationSettings + ): + return _get_optimizer(model, config) # type: ignore + + raise ConfigurationError(f"Unknown optimization configuration {config}") + + +def _get_optimizer( + model: tf.keras.Model, + optimization_settings: Union[OptimizationSettings, List[OptimizationSettings]], +) -> Optimizer: + if isinstance(optimization_settings, OptimizationSettings): + optimization_settings = [optimization_settings] + + optimizer_configs = [] + for opt_type, opt_target, layers_to_optimize in optimization_settings: + _check_optimizer_params(opt_type, opt_target) + + opt_config = _get_optimizer_configuration( + opt_type, opt_target, layers_to_optimize + ) + optimizer_configs.append(opt_config) + + if len(optimizer_configs) == 1: + return get_optimizer(model, optimizer_configs[0]) + + return MultiStageOptimizer(model, optimizer_configs) + + +def _get_optimizer_configuration( + optimization_type: str, + optimization_target: Union[int, float], + layers_to_optimize: Optional[List[str]] = None, +) -> OptimizerConfiguration: + """Get optimizer configuration for provided parameters.""" + _check_optimizer_params(optimization_type, optimization_target) + + opt_type = optimization_type.lower() + if opt_type == "pruning": + return PruningConfiguration(optimization_target, layers_to_optimize) + + if opt_type == "clustering": + # make sure an integer is given as clustering target + if optimization_target == int(optimization_target): + return ClusteringConfiguration(int(optimization_target), layers_to_optimize) + + raise ConfigurationError( + "Optimization target should be a positive integer. " + f"Optimization target provided: {optimization_target}" + ) + + raise ConfigurationError(f"Unsupported optimization type: {optimization_type}") + + +def _check_optimizer_params( + optimization_type: str, optimization_target: Union[int, float] +) -> None: + """Check optimizer params.""" + if not optimization_target: + raise ConfigurationError("Optimization target is not provided") + + if not optimization_type: + raise ConfigurationError("Optimization type is not provided") diff --git a/src/mlia/nn/tensorflow/tflite_metrics.py b/src/mlia/nn/tensorflow/tflite_metrics.py new file mode 100644 index 0000000..b29fab3 --- /dev/null +++ b/src/mlia/nn/tensorflow/tflite_metrics.py @@ -0,0 +1,296 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +""" +Contains class TFLiteMetrics to calculate metrics from a TFLite file. + +These metrics include: +* Sparsity (per layer and overall) +* Unique weights (clusters) (per layer) +* gzip compression ratio +""" +import os +from enum import Enum +from pprint import pprint +from typing import Any +from typing import List +from typing import Optional + +import numpy as np +import tensorflow as tf + +DEFAULT_IGNORE_LIST = [ + "relu", + "pooling", + "reshape", + "identity", + "input", + "add", + "flatten", + "StatefulPartitionedCall", + "bias", +] + + +def calculate_num_unique_weights(weights: np.array) -> int: + """Calculate the number of unique weights in the given weights.""" + num_unique_weights = len(np.unique(weights)) + return num_unique_weights + + +def calculate_num_unique_weights_per_axis(weights: np.array, axis: int) -> List[int]: + """Calculate unique weights per quantization axis.""" + # Make quantized dimension the first dimension + weights_trans = np.swapaxes(weights, 0, axis) + num_uniques_weights = [ + calculate_num_unique_weights(weights_trans[i]) + for i in range(weights_trans.shape[0]) + ] + assert num_uniques_weights + return num_uniques_weights + + +class SparsityAccumulator: + """Helper class to accumulate sparsity over several layers.""" + + def __init__(self) -> None: + """Create an empty accumulator.""" + self.total_non_zero_weights: int = 0 + self.total_weights: int = 0 + + def __call__(self, weights: np.array) -> None: + """Update the accumulator with the given weights.""" + non_zero_weights = np.count_nonzero(weights) + self.total_non_zero_weights += non_zero_weights + self.total_weights += weights.size + + def sparsity(self) -> float: + """Calculate the sparsity for all added weights.""" + return 1.0 - self.total_non_zero_weights / float(self.total_weights) + + +def calculate_sparsity( + weights: np.array, accumulator: Optional[SparsityAccumulator] = None +) -> float: + """ + Calculate the sparsity for the given weights. + + If the accumulator is passed, it is updated as well. + """ + non_zero_weights = np.count_nonzero(weights) + sparsity = 1.0 - float(non_zero_weights) / float(weights.size) + if accumulator is not None: + accumulator(weights) + return sparsity + + +class ReportClusterMode(Enum): + """Specifies the way cluster values are aggregated and reported.""" + + NUM_CLUSTERS_HISTOGRAM = ( + "A histogram of the number of clusters per axis. " + "I.e. the number of clusters is the index of the list (the bin) and " + "the value is the number of axes that have this number of clusters. " + "The first bin is 1." + ) + NUM_CLUSTERS_PER_AXIS = "Number of clusters (unique weights) per axis." + NUM_CLUSTERS_MIN_MAX = "Min/max number of clusters over all axes." + + +class TFLiteMetrics: + """Helper class to calculate metrics from a TFLite file. + + Metrics include: + * sparsity (per-layer and overall) + * number of unique weights (clusters) per layer + * File compression via gzip + """ + + def __init__( + self, tflite_file: str, ignore_list: Optional[List[str]] = None + ) -> None: + """Load the TFLite file and filter layers.""" + self.tflite_file = tflite_file + if ignore_list is None: + ignore_list = DEFAULT_IGNORE_LIST + self.ignore_list = [ignore.casefold() for ignore in ignore_list] + # Initialize the TFLite interpreter with the model file + self.interpreter = tf.lite.Interpreter(model_path=tflite_file) + self.interpreter.allocate_tensors() + self.details: dict = {} + + def ignore(details: dict) -> bool: + name = details["name"].casefold() + if not name: + return True + for to_ignore in self.ignore_list: + if to_ignore in name: + return True + return False + + self.filtered_details = { + details["name"]: details + for details in self.interpreter.get_tensor_details() + if not ignore(details) + } + + def get_tensor(self, details: dict) -> Any: + """Return the weights/tensor specified in the given details map.""" + return self.interpreter.tensor(details["index"])() + + def sparsity_per_layer(self) -> dict: + """Return a dict of layer name and sparsity value.""" + sparsity = { + name: calculate_sparsity(self.get_tensor(details)) + for name, details in self.filtered_details.items() + } + return sparsity + + def sparsity_overall(self) -> float: + """Return an instance of SparsityAccumulator for the filtered layers.""" + acc = SparsityAccumulator() + for details in self.filtered_details.values(): + acc(self.get_tensor(details)) + return acc.sparsity() + + def calc_num_clusters_per_axis(self, details: dict) -> List[int]: + """Calculate number of clusters per axis.""" + quant_params = details["quantization_parameters"] + per_axis = len(quant_params["zero_points"]) > 1 + if per_axis: + # Calculate unique weights along quantization axis + axis = quant_params["quantized_dimension"] + return calculate_num_unique_weights_per_axis(self.get_tensor(details), axis) + + # Calculate unique weights over all axes/dimensions + return [calculate_num_unique_weights(self.get_tensor(details))] + + def num_unique_weights(self, mode: ReportClusterMode) -> dict: + """Return a dict of layer name and number of unique weights.""" + aggregation_func = None + if mode == ReportClusterMode.NUM_CLUSTERS_PER_AXIS: + aggregation_func = self.calc_num_clusters_per_axis + elif mode == ReportClusterMode.NUM_CLUSTERS_MIN_MAX: + + def cluster_min_max(details: dict) -> List[int]: + num_clusters = self.calc_num_clusters_per_axis(details) + return [min(num_clusters), max(num_clusters)] + + aggregation_func = cluster_min_max + elif mode == ReportClusterMode.NUM_CLUSTERS_HISTOGRAM: + + def cluster_hist(details: dict) -> List[int]: + num_clusters = self.calc_num_clusters_per_axis(details) + max_num = max(num_clusters) + hist = [0] * (max_num) + for num in num_clusters: + idx = num - 1 + hist[idx] += 1 + return hist + + aggregation_func = cluster_hist + else: + raise NotImplementedError( + "ReportClusterMode '{}' not implemented.".format(mode) + ) + uniques = { + name: aggregation_func(details) + for name, details in self.filtered_details.items() + } + return uniques + + @staticmethod + def _prettify_name(name: str) -> str: + if name.startswith("model"): + return name.split("/", 1)[1] + return name + + def summary( + self, + report_sparsity: bool, + report_cluster_mode: ReportClusterMode = None, + max_num_clusters: int = 32, + verbose: bool = False, + ) -> None: + """Print a summary of all the model information.""" + print("Model file: {}".format(self.tflite_file)) + print("#" * 80) + print(" " * 28 + "### TFLITE SUMMARY ###") + print("File: {}".format(os.path.abspath(self.tflite_file))) + print("Input(s):") + self._print_in_outs(self.interpreter.get_input_details(), verbose) + print("Output(s):") + self._print_in_outs(self.interpreter.get_output_details(), verbose) + print() + header = ["Layer", "Index", "Type", "Num weights"] + if report_sparsity: + header.append("Sparsity") + rows = [] + sparsity_accumulator = SparsityAccumulator() + for details in self.filtered_details.values(): + name = details["name"] + weights = self.get_tensor(details) + row = [ + self._prettify_name(name), + details["index"], + weights.dtype, + weights.size, + ] + if report_sparsity: + sparsity = calculate_sparsity(weights, sparsity_accumulator) + row.append("{:.2f}".format(sparsity)) + rows.append(row) + if verbose: + # Print cluster centroids + print("{} cluster centroids:".format(name)) + pprint(np.unique(weights)) + # Add summary/overall values + empty_row = ["" for _ in range(len(header))] + summary_row = empty_row + summary_row[header.index("Layer")] = "=> OVERALL" + summary_row[header.index("Num weights")] = str( + sparsity_accumulator.total_weights + ) + if report_sparsity: + summary_row[header.index("Sparsity")] = "{:.2f}".format( + sparsity_accumulator.sparsity() + ) + rows.append(summary_row) + # Report detailed cluster info + if report_cluster_mode is not None: + print() + self._print_cluster_details(report_cluster_mode, max_num_clusters) + print("#" * 80) + + def _print_cluster_details( + self, report_cluster_mode: ReportClusterMode, max_num_clusters: int + ) -> None: + print("{}:\n{}".format(report_cluster_mode.name, report_cluster_mode.value)) + num_clusters = self.num_unique_weights(report_cluster_mode) + if ( + report_cluster_mode == ReportClusterMode.NUM_CLUSTERS_HISTOGRAM + and max_num_clusters > 0 + ): + # Only show cluster histogram if there are not more than + # max_num_clusters. This is a workaround for not showing a huge + # histogram for unclustered layers. + for name, value in num_clusters.items(): + if len(value) > max_num_clusters: + num_clusters[name] = "More than {} unique values.".format( + max_num_clusters + ) + for name, nums in num_clusters.items(): + print("- {}: {}".format(self._prettify_name(name), nums)) + + @staticmethod + def _print_in_outs(ios: List[dict], verbose: bool = False) -> None: + for item in ios: + if verbose: + pprint(item) + else: + print( + "- {} ({}): {}".format( + item["name"], + np.dtype(item["dtype"]).name, + item["shape"], + ) + ) diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py new file mode 100644 index 0000000..4abf6cd --- /dev/null +++ b/src/mlia/nn/tensorflow/utils.py @@ -0,0 +1,149 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright The TensorFlow Authors. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Collection of useful functions for optimizations.""" +import logging +from pathlib import Path +from typing import Callable +from typing import Iterable +from typing import Union + +import numpy as np +import tensorflow as tf +from tensorflow.lite.python.interpreter import Interpreter + +from mlia.utils.logging import redirect_output + + +def representative_dataset(model: tf.keras.Model) -> Callable: + """Sample dataset used for quantization.""" + input_shape = model.input_shape + + def dataset() -> Iterable: + for _ in range(100): + if input_shape[0] != 1: + raise Exception("Only the input batch_size=1 is supported!") + data = np.random.rand(*input_shape) + yield [data.astype(np.float32)] + + return dataset + + +def get_tf_tensor_shape(model: str) -> list: + """Get input shape for the TensorFlow tensor model.""" + # Loading the model + loaded = tf.saved_model.load(model) + # The model signature must have 'serving_default' as a key + if "serving_default" not in loaded.signatures.keys(): + raise Exception( + "Unsupported TensorFlow model signature, must have 'serving_default'" + ) + # Get the signature inputs + inputs_tensor_info = loaded.signatures["serving_default"].inputs + dims = [] + # Build a list of all inputs shape sizes + for input_key in inputs_tensor_info: + if input_key.get_shape(): + dims.extend(list(input_key.get_shape())) + return dims + + +def representative_tf_dataset(model: str) -> Callable: + """Sample dataset used for quantization.""" + if not (input_shape := get_tf_tensor_shape(model)): + raise Exception("Unable to get input shape") + + def dataset() -> Iterable: + for _ in range(100): + data = np.random.rand(*input_shape) + yield [data.astype(np.float32)] + + return dataset + + +def convert_to_tflite(model: tf.keras.Model, quantized: bool = False) -> Interpreter: + """Convert Keras model to TFLite.""" + if not isinstance(model, tf.keras.Model): + raise Exception("Invalid model type") + + converter = tf.lite.TFLiteConverter.from_keras_model(model) + + if quantized: + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset(model) + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + + with redirect_output(logging.getLogger("tensorflow")): + tflite_model = converter.convert() + + return tflite_model + + +def convert_tf_to_tflite(model: str, quantized: bool = False) -> Interpreter: + """Convert TensorFlow model to TFLite.""" + if not isinstance(model, str): + raise Exception("Invalid model type") + + converter = tf.lite.TFLiteConverter.from_saved_model(model) + + if quantized: + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_tf_dataset(model) + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + + with redirect_output(logging.getLogger("tensorflow")): + tflite_model = converter.convert() + + return tflite_model + + +def save_keras_model(model: tf.keras.Model, save_path: Union[str, Path]) -> None: + """Save Keras model at provided path.""" + # Checkpoint: saving the optimizer is necessary. + model.save(save_path, include_optimizer=True) + + +def save_tflite_model( + model: tf.lite.TFLiteConverter, save_path: Union[str, Path] +) -> None: + """Save TFLite model at provided path.""" + with open(save_path, "wb") as file: + file.write(model) + + +def is_tflite_model(model: Union[Path, str]) -> bool: + """Check if model type is supported by TFLite API. + + TFLite model is indicated by the model file extension .tflite + """ + model_path = Path(model) + return model_path.suffix == ".tflite" + + +def is_keras_model(model: Union[Path, str]) -> bool: + """Check if model type is supported by Keras API. + + Keras model is indicated by: + 1. if it's a directory (meaning saved model), + it should contain keras_metadata.pb file + 2. or if the model file extension is .h5/.hdf5 + """ + model_path = Path(model) + + if model_path.is_dir(): + return (model_path / "keras_metadata.pb").exists() + return model_path.suffix in (".h5", ".hdf5") + + +def is_tf_model(model: Union[Path, str]) -> bool: + """Check if model type is supported by TensorFlow API. + + TensorFlow model is indicated if its directory (meaning saved model) + doesn't contain keras_metadata.pb file + """ + model_path = Path(model) + return model_path.is_dir() and not is_keras_model(model) -- cgit v1.2.1