aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn')
-rw-r--r--src/mlia/nn/__init__.py3
-rw-r--r--src/mlia/nn/tensorflow/__init__.py3
-rw-r--r--src/mlia/nn/tensorflow/config.py134
-rw-r--r--src/mlia/nn/tensorflow/optimizations/__init__.py3
-rw-r--r--src/mlia/nn/tensorflow/optimizations/clustering.py109
-rw-r--r--src/mlia/nn/tensorflow/optimizations/common.py29
-rw-r--r--src/mlia/nn/tensorflow/optimizations/pruning.py168
-rw-r--r--src/mlia/nn/tensorflow/optimizations/select.py179
-rw-r--r--src/mlia/nn/tensorflow/tflite_metrics.py296
-rw-r--r--src/mlia/nn/tensorflow/utils.py149
10 files changed, 1073 insertions, 0 deletions
diff --git a/src/mlia/nn/__init__.py b/src/mlia/nn/__init__.py
new file mode 100644
index 0000000..aac2830
--- /dev/null
+++ b/src/mlia/nn/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""NN related module."""
diff --git a/src/mlia/nn/tensorflow/__init__.py b/src/mlia/nn/tensorflow/__init__.py
new file mode 100644
index 0000000..ff061c1
--- /dev/null
+++ b/src/mlia/nn/tensorflow/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TensorFlow related module."""
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
new file mode 100644
index 0000000..d3235d7
--- /dev/null
+++ b/src/mlia/nn/tensorflow/config.py
@@ -0,0 +1,134 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Model configuration."""
+import logging
+from pathlib import Path
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import Union
+
+import tensorflow as tf
+
+from mlia.core.context import Context
+from mlia.nn.tensorflow.utils import convert_tf_to_tflite
+from mlia.nn.tensorflow.utils import convert_to_tflite
+from mlia.nn.tensorflow.utils import is_keras_model
+from mlia.nn.tensorflow.utils import is_tf_model
+from mlia.nn.tensorflow.utils import is_tflite_model
+from mlia.nn.tensorflow.utils import save_tflite_model
+
+logger = logging.getLogger(__name__)
+
+
+class ModelConfiguration:
+ """Base class for model configuration."""
+
+ def __init__(self, model_path: Union[str, Path]) -> None:
+ """Init model configuration instance."""
+ self.model_path = str(model_path)
+
+ def convert_to_tflite(
+ self, tflite_model_path: Union[str, Path], quantized: bool = False
+ ) -> "TFLiteModel":
+ """Convert model to TFLite format."""
+ raise NotImplementedError()
+
+ def convert_to_keras(self, keras_model_path: Union[str, Path]) -> "KerasModel":
+ """Convert model to Keras format."""
+ raise NotImplementedError()
+
+
+class KerasModel(ModelConfiguration):
+ """Keras model configuration.
+
+ Supports all models supported by Keras API: saved model, H5, HDF5
+ """
+
+ def get_keras_model(self) -> tf.keras.Model:
+ """Return associated Keras model."""
+ return tf.keras.models.load_model(self.model_path)
+
+ def convert_to_tflite(
+ self, tflite_model_path: Union[str, Path], quantized: bool = False
+ ) -> "TFLiteModel":
+ """Convert model to TFLite format."""
+ logger.info("Converting Keras to TFLite ...")
+
+ converted_model = convert_to_tflite(self.get_keras_model(), quantized)
+ logger.info("Done\n")
+
+ save_tflite_model(converted_model, tflite_model_path)
+ logger.debug(
+ "Model %s converted and saved to %s", self.model_path, tflite_model_path
+ )
+
+ return TFLiteModel(tflite_model_path)
+
+ def convert_to_keras(self, keras_model_path: Union[str, Path]) -> "KerasModel":
+ """Convert model to Keras format."""
+ return self
+
+
+class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
+ """TFLite model configuration."""
+
+ def input_details(self) -> List[Dict]:
+ """Get model's input details."""
+ interpreter = tf.lite.Interpreter(model_path=self.model_path)
+ return cast(List[Dict], interpreter.get_input_details())
+
+ def convert_to_tflite(
+ self, tflite_model_path: Union[str, Path], quantized: bool = False
+ ) -> "TFLiteModel":
+ """Convert model to TFLite format."""
+ return self
+
+
+class TfModel(ModelConfiguration): # pylint: disable=abstract-method
+ """TensorFlow model configuration.
+
+ Supports models supported by TensorFlow API (not Keras)
+ """
+
+ def convert_to_tflite(
+ self, tflite_model_path: Union[str, Path], quantized: bool = False
+ ) -> "TFLiteModel":
+ """Convert model to TFLite format."""
+ converted_model = convert_tf_to_tflite(self.model_path, quantized)
+ save_tflite_model(converted_model, tflite_model_path)
+
+ return TFLiteModel(tflite_model_path)
+
+
+def get_model(model: Union[Path, str]) -> "ModelConfiguration":
+ """Return the model object."""
+ if is_tflite_model(model):
+ return TFLiteModel(model)
+
+ if is_keras_model(model):
+ return KerasModel(model)
+
+ if is_tf_model(model):
+ return TfModel(model)
+
+ raise Exception(
+ "The input model format is not supported"
+ "(supported formats: TFLite, Keras, TensorFlow saved model)!"
+ )
+
+
+def get_tflite_model(model: Union[str, Path], ctx: Context) -> "TFLiteModel":
+ """Convert input model to TFLite and returns TFLiteModel object."""
+ tflite_model_path = ctx.get_model_path("converted_model.tflite")
+ converted_model = get_model(model)
+
+ return converted_model.convert_to_tflite(tflite_model_path, True)
+
+
+def get_keras_model(model: Union[str, Path], ctx: Context) -> "KerasModel":
+ """Convert input model to Keras and returns KerasModel object."""
+ keras_model_path = ctx.get_model_path("converted_model.h5")
+ converted_model = get_model(model)
+
+ return converted_model.convert_to_keras(keras_model_path)
diff --git a/src/mlia/nn/tensorflow/optimizations/__init__.py b/src/mlia/nn/tensorflow/optimizations/__init__.py
new file mode 100644
index 0000000..201c130
--- /dev/null
+++ b/src/mlia/nn/tensorflow/optimizations/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Optimizations module."""
diff --git a/src/mlia/nn/tensorflow/optimizations/clustering.py b/src/mlia/nn/tensorflow/optimizations/clustering.py
new file mode 100644
index 0000000..16d9e4b
--- /dev/null
+++ b/src/mlia/nn/tensorflow/optimizations/clustering.py
@@ -0,0 +1,109 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""
+Contains class Clusterer that clusters unique weights per layer to a specified number.
+
+In order to do this, we need to have a base model and corresponding training data.
+We also have to specify a subset of layers we want to cluster. For more details,
+please refer to the documentation for TensorFlow Model Optimization Toolkit.
+"""
+from dataclasses import dataclass
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+
+import tensorflow as tf
+import tensorflow_model_optimization as tfmot
+from tensorflow_model_optimization.python.core.clustering.keras.experimental import ( # pylint: disable=no-name-in-module
+ cluster as experimental_cluster,
+)
+
+from mlia.nn.tensorflow.optimizations.common import Optimizer
+from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration
+
+
+@dataclass
+class ClusteringConfiguration(OptimizerConfiguration):
+ """Clustering configuration."""
+
+ optimization_target: int
+ layers_to_optimize: Optional[List[str]] = None
+
+ def __str__(self) -> str:
+ """Return string representation of the configuration."""
+ return f"clustering: {self.optimization_target}"
+
+
+class Clusterer(Optimizer):
+ """
+ Clusterer class.
+
+ Used to cluster a model to a specified number of unique weights per layer.
+
+ Sample usage:
+ clusterer = Clusterer(
+ base_model,
+ optimizer_configuration)
+
+ clusterer.apply_clustering()
+ clustered_model = clusterer.get_model()
+ """
+
+ def __init__(
+ self, model: tf.keras.Model, optimizer_configuration: ClusteringConfiguration
+ ):
+ """Init Clusterer instance."""
+ self.model = model
+ self.optimizer_configuration = optimizer_configuration
+
+ def optimization_config(self) -> str:
+ """Return string representation of the optimization config."""
+ return str(self.optimizer_configuration)
+
+ def _setup_clustering_params(self) -> Dict[str, Any]:
+ CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
+ return {
+ "number_of_clusters": self.optimizer_configuration.optimization_target,
+ "cluster_centroids_init": CentroidInitialization.LINEAR,
+ "preserve_sparsity": True,
+ }
+
+ def _apply_clustering_to_layer(
+ self, layer: tf.keras.layers.Layer
+ ) -> tf.keras.layers.Layer:
+ layers_to_optimize = self.optimizer_configuration.layers_to_optimize
+ assert layers_to_optimize, "List of the layers to optimize is empty"
+
+ if layer.name not in layers_to_optimize:
+ return layer
+
+ clustering_params = self._setup_clustering_params()
+ return experimental_cluster.cluster_weights(layer, **clustering_params)
+
+ def _init_for_clustering(self) -> None:
+ # Use `tf.keras.models.clone_model` to apply `apply_clustering_to_layer`
+ # to the layers of the model
+ if not self.optimizer_configuration.layers_to_optimize:
+ clustering_params = self._setup_clustering_params()
+ clustered_model = experimental_cluster.cluster_weights(
+ self.model, **clustering_params
+ )
+ else:
+ clustered_model = tf.keras.models.clone_model(
+ self.model, clone_function=self._apply_clustering_to_layer
+ )
+
+ self.model = clustered_model
+
+ def _strip_clustering(self) -> None:
+ self.model = tfmot.clustering.keras.strip_clustering(self.model)
+
+ def apply_optimization(self) -> None:
+ """Apply all steps of clustering at once."""
+ self._init_for_clustering()
+ self._strip_clustering()
+
+ def get_model(self) -> tf.keras.Model:
+ """Get model."""
+ return self.model
diff --git a/src/mlia/nn/tensorflow/optimizations/common.py b/src/mlia/nn/tensorflow/optimizations/common.py
new file mode 100644
index 0000000..1dce0b2
--- /dev/null
+++ b/src/mlia/nn/tensorflow/optimizations/common.py
@@ -0,0 +1,29 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Common items for the optimizations module."""
+from abc import ABC
+from abc import abstractmethod
+from dataclasses import dataclass
+
+import tensorflow as tf
+
+
+@dataclass
+class OptimizerConfiguration:
+ """Abstract optimizer configuration."""
+
+
+class Optimizer(ABC):
+ """Abstract class for the optimizer."""
+
+ @abstractmethod
+ def get_model(self) -> tf.keras.Model:
+ """Abstract method to return the model instance from the optimizer."""
+
+ @abstractmethod
+ def apply_optimization(self) -> None:
+ """Abstract method to apply optimization to the model."""
+
+ @abstractmethod
+ def optimization_config(self) -> str:
+ """Return string representation of the optimization config."""
diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py
new file mode 100644
index 0000000..f629ba1
--- /dev/null
+++ b/src/mlia/nn/tensorflow/optimizations/pruning.py
@@ -0,0 +1,168 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""
+Contains class Pruner to prune a model to a specified sparsity.
+
+In order to do this, we need to have a base model and corresponding training data.
+We also have to specify a subset of layers we want to prune. For more details,
+please refer to the documentation for TensorFlow Model Optimization Toolkit.
+"""
+from dataclasses import dataclass
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import tensorflow as tf
+import tensorflow_model_optimization as tfmot
+from tensorflow_model_optimization.python.core.sparsity.keras import ( # pylint: disable=no-name-in-module
+ pruning_wrapper,
+)
+
+from mlia.nn.tensorflow.optimizations.common import Optimizer
+from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration
+
+
+@dataclass
+class PruningConfiguration(OptimizerConfiguration):
+ """Pruning configuration."""
+
+ optimization_target: float
+ layers_to_optimize: Optional[List[str]] = None
+ x_train: Optional[np.array] = None
+ y_train: Optional[np.array] = None
+ batch_size: int = 1
+ num_epochs: int = 1
+
+ def __str__(self) -> str:
+ """Return string representation of the configuration."""
+ return f"pruning: {self.optimization_target}"
+
+ def has_training_data(self) -> bool:
+ """Return True if training data provided."""
+ return self.x_train is not None and self.y_train is not None
+
+
+class Pruner(Optimizer):
+ """
+ Pruner class. Used to prune a model to a specified sparsity.
+
+ Sample usage:
+ pruner = Pruner(
+ base_model,
+ optimizer_configuration)
+
+ pruner.apply_pruning()
+ pruned_model = pruner.get_model()
+ """
+
+ def __init__(
+ self, model: tf.keras.Model, optimizer_configuration: PruningConfiguration
+ ):
+ """Init Pruner instance."""
+ self.model = model
+ self.optimizer_configuration = optimizer_configuration
+
+ if not optimizer_configuration.has_training_data():
+ mock_x_train, mock_y_train = self._mock_train_data()
+
+ self.optimizer_configuration.x_train = mock_x_train
+ self.optimizer_configuration.y_train = mock_y_train
+
+ def optimization_config(self) -> str:
+ """Return string representation of the optimization config."""
+ return str(self.optimizer_configuration)
+
+ def _mock_train_data(self) -> Tuple[np.array, np.array]:
+ # get rid of the batch_size dimension in input and output shape
+ input_shape = tuple(x for x in self.model.input_shape if x is not None)
+ output_shape = tuple(x for x in self.model.output_shape if x is not None)
+
+ return (
+ np.random.rand(*input_shape),
+ np.random.randint(0, output_shape[-1], (output_shape[:-1])),
+ )
+
+ def _setup_pruning_params(self) -> dict:
+ return {
+ "pruning_schedule": tfmot.sparsity.keras.PolynomialDecay(
+ initial_sparsity=0,
+ final_sparsity=self.optimizer_configuration.optimization_target,
+ begin_step=0,
+ end_step=self.optimizer_configuration.num_epochs,
+ frequency=1,
+ ),
+ }
+
+ def _apply_pruning_to_layer(
+ self, layer: tf.keras.layers.Layer
+ ) -> tf.keras.layers.Layer:
+ layers_to_optimize = self.optimizer_configuration.layers_to_optimize
+ assert layers_to_optimize, "List of the layers to optimize is empty"
+
+ if layer.name not in layers_to_optimize:
+ return layer
+
+ pruning_params = self._setup_pruning_params()
+ return tfmot.sparsity.keras.prune_low_magnitude(layer, **pruning_params)
+
+ def _init_for_pruning(self) -> None:
+ # Use `tf.keras.models.clone_model` to apply `apply_pruning_to_layer`
+ # to the layers of the model
+ if not self.optimizer_configuration.layers_to_optimize:
+ pruning_params = self._setup_pruning_params()
+ prunable_model = tfmot.sparsity.keras.prune_low_magnitude(
+ self.model, **pruning_params
+ )
+ else:
+ prunable_model = tf.keras.models.clone_model(
+ self.model, clone_function=self._apply_pruning_to_layer
+ )
+
+ self.model = prunable_model
+
+ def _train_pruning(self) -> None:
+ loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
+ self.model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
+
+ # Model callbacks
+ callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]
+
+ # Fitting data
+ self.model.fit(
+ self.optimizer_configuration.x_train,
+ self.optimizer_configuration.y_train,
+ batch_size=self.optimizer_configuration.batch_size,
+ epochs=self.optimizer_configuration.num_epochs,
+ callbacks=callbacks,
+ verbose=0,
+ )
+
+ def _assert_sparsity_reached(self) -> None:
+ for layer in self.model.layers:
+ if not isinstance(layer, pruning_wrapper.PruneLowMagnitude):
+ continue
+
+ for weight in layer.layer.get_prunable_weights():
+ nonzero_weights = np.count_nonzero(tf.keras.backend.get_value(weight))
+ all_weights = tf.keras.backend.get_value(weight).size
+
+ np.testing.assert_approx_equal(
+ self.optimizer_configuration.optimization_target,
+ 1 - nonzero_weights / all_weights,
+ significant=2,
+ )
+
+ def _strip_pruning(self) -> None:
+ self.model = tfmot.sparsity.keras.strip_pruning(self.model)
+
+ def apply_optimization(self) -> None:
+ """Apply all steps of pruning sequentially."""
+ self._init_for_pruning()
+ self._train_pruning()
+ self._assert_sparsity_reached()
+ self._strip_pruning()
+
+ def get_model(self) -> tf.keras.Model:
+ """Get model."""
+ return self.model
diff --git a/src/mlia/nn/tensorflow/optimizations/select.py b/src/mlia/nn/tensorflow/optimizations/select.py
new file mode 100644
index 0000000..1b0c755
--- /dev/null
+++ b/src/mlia/nn/tensorflow/optimizations/select.py
@@ -0,0 +1,179 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for optimization selection."""
+import math
+from typing import List
+from typing import NamedTuple
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import tensorflow as tf
+
+from mlia.core.errors import ConfigurationError
+from mlia.nn.tensorflow.config import KerasModel
+from mlia.nn.tensorflow.optimizations.clustering import Clusterer
+from mlia.nn.tensorflow.optimizations.clustering import ClusteringConfiguration
+from mlia.nn.tensorflow.optimizations.common import Optimizer
+from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration
+from mlia.nn.tensorflow.optimizations.pruning import Pruner
+from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration
+from mlia.utils.types import is_list_of
+
+
+class OptimizationSettings(NamedTuple):
+ """Optimization settings."""
+
+ optimization_type: str
+ optimization_target: Union[int, float]
+ layers_to_optimize: Optional[List[str]]
+
+ @staticmethod
+ def create_from(
+ optimizer_params: List[Tuple[str, float]],
+ layers_to_optimize: Optional[List[str]] = None,
+ ) -> List["OptimizationSettings"]:
+ """Create optimization settings from the provided parameters."""
+ return [
+ OptimizationSettings(
+ optimization_type=opt_type,
+ optimization_target=opt_target,
+ layers_to_optimize=layers_to_optimize,
+ )
+ for opt_type, opt_target in optimizer_params
+ ]
+
+ def __str__(self) -> str:
+ """Return string representation."""
+ return f"{self.optimization_type}: {self.optimization_target}"
+
+ def next_target(self) -> "OptimizationSettings":
+ """Return next optimization target."""
+ if self.optimization_type == "pruning":
+ next_target = round(min(self.optimization_target + 0.1, 0.9), 2)
+ return OptimizationSettings(
+ self.optimization_type, next_target, self.layers_to_optimize
+ )
+
+ if self.optimization_type == "clustering":
+ # return next lowest power of two for clustering
+ next_target = math.log(self.optimization_target, 2)
+ if next_target.is_integer():
+ next_target -= 1
+
+ next_target = max(int(2 ** int(next_target)), 4)
+ return OptimizationSettings(
+ self.optimization_type, next_target, self.layers_to_optimize
+ )
+
+ raise Exception(f"Unknown optimization type {self.optimization_type}")
+
+
+class MultiStageOptimizer(Optimizer):
+ """Optimizer with multiply stages."""
+
+ def __init__(
+ self,
+ model: tf.keras.Model,
+ optimizations: List[OptimizerConfiguration],
+ ) -> None:
+ """Init MultiStageOptimizer instance."""
+ self.model = model
+ self.optimizations = optimizations
+
+ def optimization_config(self) -> str:
+ """Return string representation of the optimization config."""
+ return " - ".join(str(opt) for opt in self.optimizations)
+
+ def get_model(self) -> tf.keras.Model:
+ """Return optimized model."""
+ return self.model
+
+ def apply_optimization(self) -> None:
+ """Apply optimization to the model."""
+ for config in self.optimizations:
+ optimizer = get_optimizer(self.model, config)
+ optimizer.apply_optimization()
+ self.model = optimizer.get_model()
+
+
+def get_optimizer(
+ model: Union[tf.keras.Model, KerasModel],
+ config: Union[
+ OptimizerConfiguration, OptimizationSettings, List[OptimizationSettings]
+ ],
+) -> Optimizer:
+ """Get optimizer for provided configuration."""
+ if isinstance(model, KerasModel):
+ model = model.get_keras_model()
+
+ if isinstance(config, PruningConfiguration):
+ return Pruner(model, config)
+
+ if isinstance(config, ClusteringConfiguration):
+ return Clusterer(model, config)
+
+ if isinstance(config, OptimizationSettings) or is_list_of(
+ config, OptimizationSettings
+ ):
+ return _get_optimizer(model, config) # type: ignore
+
+ raise ConfigurationError(f"Unknown optimization configuration {config}")
+
+
+def _get_optimizer(
+ model: tf.keras.Model,
+ optimization_settings: Union[OptimizationSettings, List[OptimizationSettings]],
+) -> Optimizer:
+ if isinstance(optimization_settings, OptimizationSettings):
+ optimization_settings = [optimization_settings]
+
+ optimizer_configs = []
+ for opt_type, opt_target, layers_to_optimize in optimization_settings:
+ _check_optimizer_params(opt_type, opt_target)
+
+ opt_config = _get_optimizer_configuration(
+ opt_type, opt_target, layers_to_optimize
+ )
+ optimizer_configs.append(opt_config)
+
+ if len(optimizer_configs) == 1:
+ return get_optimizer(model, optimizer_configs[0])
+
+ return MultiStageOptimizer(model, optimizer_configs)
+
+
+def _get_optimizer_configuration(
+ optimization_type: str,
+ optimization_target: Union[int, float],
+ layers_to_optimize: Optional[List[str]] = None,
+) -> OptimizerConfiguration:
+ """Get optimizer configuration for provided parameters."""
+ _check_optimizer_params(optimization_type, optimization_target)
+
+ opt_type = optimization_type.lower()
+ if opt_type == "pruning":
+ return PruningConfiguration(optimization_target, layers_to_optimize)
+
+ if opt_type == "clustering":
+ # make sure an integer is given as clustering target
+ if optimization_target == int(optimization_target):
+ return ClusteringConfiguration(int(optimization_target), layers_to_optimize)
+
+ raise ConfigurationError(
+ "Optimization target should be a positive integer. "
+ f"Optimization target provided: {optimization_target}"
+ )
+
+ raise ConfigurationError(f"Unsupported optimization type: {optimization_type}")
+
+
+def _check_optimizer_params(
+ optimization_type: str, optimization_target: Union[int, float]
+) -> None:
+ """Check optimizer params."""
+ if not optimization_target:
+ raise ConfigurationError("Optimization target is not provided")
+
+ if not optimization_type:
+ raise ConfigurationError("Optimization type is not provided")
diff --git a/src/mlia/nn/tensorflow/tflite_metrics.py b/src/mlia/nn/tensorflow/tflite_metrics.py
new file mode 100644
index 0000000..b29fab3
--- /dev/null
+++ b/src/mlia/nn/tensorflow/tflite_metrics.py
@@ -0,0 +1,296 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""
+Contains class TFLiteMetrics to calculate metrics from a TFLite file.
+
+These metrics include:
+* Sparsity (per layer and overall)
+* Unique weights (clusters) (per layer)
+* gzip compression ratio
+"""
+import os
+from enum import Enum
+from pprint import pprint
+from typing import Any
+from typing import List
+from typing import Optional
+
+import numpy as np
+import tensorflow as tf
+
+DEFAULT_IGNORE_LIST = [
+ "relu",
+ "pooling",
+ "reshape",
+ "identity",
+ "input",
+ "add",
+ "flatten",
+ "StatefulPartitionedCall",
+ "bias",
+]
+
+
+def calculate_num_unique_weights(weights: np.array) -> int:
+ """Calculate the number of unique weights in the given weights."""
+ num_unique_weights = len(np.unique(weights))
+ return num_unique_weights
+
+
+def calculate_num_unique_weights_per_axis(weights: np.array, axis: int) -> List[int]:
+ """Calculate unique weights per quantization axis."""
+ # Make quantized dimension the first dimension
+ weights_trans = np.swapaxes(weights, 0, axis)
+ num_uniques_weights = [
+ calculate_num_unique_weights(weights_trans[i])
+ for i in range(weights_trans.shape[0])
+ ]
+ assert num_uniques_weights
+ return num_uniques_weights
+
+
+class SparsityAccumulator:
+ """Helper class to accumulate sparsity over several layers."""
+
+ def __init__(self) -> None:
+ """Create an empty accumulator."""
+ self.total_non_zero_weights: int = 0
+ self.total_weights: int = 0
+
+ def __call__(self, weights: np.array) -> None:
+ """Update the accumulator with the given weights."""
+ non_zero_weights = np.count_nonzero(weights)
+ self.total_non_zero_weights += non_zero_weights
+ self.total_weights += weights.size
+
+ def sparsity(self) -> float:
+ """Calculate the sparsity for all added weights."""
+ return 1.0 - self.total_non_zero_weights / float(self.total_weights)
+
+
+def calculate_sparsity(
+ weights: np.array, accumulator: Optional[SparsityAccumulator] = None
+) -> float:
+ """
+ Calculate the sparsity for the given weights.
+
+ If the accumulator is passed, it is updated as well.
+ """
+ non_zero_weights = np.count_nonzero(weights)
+ sparsity = 1.0 - float(non_zero_weights) / float(weights.size)
+ if accumulator is not None:
+ accumulator(weights)
+ return sparsity
+
+
+class ReportClusterMode(Enum):
+ """Specifies the way cluster values are aggregated and reported."""
+
+ NUM_CLUSTERS_HISTOGRAM = (
+ "A histogram of the number of clusters per axis. "
+ "I.e. the number of clusters is the index of the list (the bin) and "
+ "the value is the number of axes that have this number of clusters. "
+ "The first bin is 1."
+ )
+ NUM_CLUSTERS_PER_AXIS = "Number of clusters (unique weights) per axis."
+ NUM_CLUSTERS_MIN_MAX = "Min/max number of clusters over all axes."
+
+
+class TFLiteMetrics:
+ """Helper class to calculate metrics from a TFLite file.
+
+ Metrics include:
+ * sparsity (per-layer and overall)
+ * number of unique weights (clusters) per layer
+ * File compression via gzip
+ """
+
+ def __init__(
+ self, tflite_file: str, ignore_list: Optional[List[str]] = None
+ ) -> None:
+ """Load the TFLite file and filter layers."""
+ self.tflite_file = tflite_file
+ if ignore_list is None:
+ ignore_list = DEFAULT_IGNORE_LIST
+ self.ignore_list = [ignore.casefold() for ignore in ignore_list]
+ # Initialize the TFLite interpreter with the model file
+ self.interpreter = tf.lite.Interpreter(model_path=tflite_file)
+ self.interpreter.allocate_tensors()
+ self.details: dict = {}
+
+ def ignore(details: dict) -> bool:
+ name = details["name"].casefold()
+ if not name:
+ return True
+ for to_ignore in self.ignore_list:
+ if to_ignore in name:
+ return True
+ return False
+
+ self.filtered_details = {
+ details["name"]: details
+ for details in self.interpreter.get_tensor_details()
+ if not ignore(details)
+ }
+
+ def get_tensor(self, details: dict) -> Any:
+ """Return the weights/tensor specified in the given details map."""
+ return self.interpreter.tensor(details["index"])()
+
+ def sparsity_per_layer(self) -> dict:
+ """Return a dict of layer name and sparsity value."""
+ sparsity = {
+ name: calculate_sparsity(self.get_tensor(details))
+ for name, details in self.filtered_details.items()
+ }
+ return sparsity
+
+ def sparsity_overall(self) -> float:
+ """Return an instance of SparsityAccumulator for the filtered layers."""
+ acc = SparsityAccumulator()
+ for details in self.filtered_details.values():
+ acc(self.get_tensor(details))
+ return acc.sparsity()
+
+ def calc_num_clusters_per_axis(self, details: dict) -> List[int]:
+ """Calculate number of clusters per axis."""
+ quant_params = details["quantization_parameters"]
+ per_axis = len(quant_params["zero_points"]) > 1
+ if per_axis:
+ # Calculate unique weights along quantization axis
+ axis = quant_params["quantized_dimension"]
+ return calculate_num_unique_weights_per_axis(self.get_tensor(details), axis)
+
+ # Calculate unique weights over all axes/dimensions
+ return [calculate_num_unique_weights(self.get_tensor(details))]
+
+ def num_unique_weights(self, mode: ReportClusterMode) -> dict:
+ """Return a dict of layer name and number of unique weights."""
+ aggregation_func = None
+ if mode == ReportClusterMode.NUM_CLUSTERS_PER_AXIS:
+ aggregation_func = self.calc_num_clusters_per_axis
+ elif mode == ReportClusterMode.NUM_CLUSTERS_MIN_MAX:
+
+ def cluster_min_max(details: dict) -> List[int]:
+ num_clusters = self.calc_num_clusters_per_axis(details)
+ return [min(num_clusters), max(num_clusters)]
+
+ aggregation_func = cluster_min_max
+ elif mode == ReportClusterMode.NUM_CLUSTERS_HISTOGRAM:
+
+ def cluster_hist(details: dict) -> List[int]:
+ num_clusters = self.calc_num_clusters_per_axis(details)
+ max_num = max(num_clusters)
+ hist = [0] * (max_num)
+ for num in num_clusters:
+ idx = num - 1
+ hist[idx] += 1
+ return hist
+
+ aggregation_func = cluster_hist
+ else:
+ raise NotImplementedError(
+ "ReportClusterMode '{}' not implemented.".format(mode)
+ )
+ uniques = {
+ name: aggregation_func(details)
+ for name, details in self.filtered_details.items()
+ }
+ return uniques
+
+ @staticmethod
+ def _prettify_name(name: str) -> str:
+ if name.startswith("model"):
+ return name.split("/", 1)[1]
+ return name
+
+ def summary(
+ self,
+ report_sparsity: bool,
+ report_cluster_mode: ReportClusterMode = None,
+ max_num_clusters: int = 32,
+ verbose: bool = False,
+ ) -> None:
+ """Print a summary of all the model information."""
+ print("Model file: {}".format(self.tflite_file))
+ print("#" * 80)
+ print(" " * 28 + "### TFLITE SUMMARY ###")
+ print("File: {}".format(os.path.abspath(self.tflite_file)))
+ print("Input(s):")
+ self._print_in_outs(self.interpreter.get_input_details(), verbose)
+ print("Output(s):")
+ self._print_in_outs(self.interpreter.get_output_details(), verbose)
+ print()
+ header = ["Layer", "Index", "Type", "Num weights"]
+ if report_sparsity:
+ header.append("Sparsity")
+ rows = []
+ sparsity_accumulator = SparsityAccumulator()
+ for details in self.filtered_details.values():
+ name = details["name"]
+ weights = self.get_tensor(details)
+ row = [
+ self._prettify_name(name),
+ details["index"],
+ weights.dtype,
+ weights.size,
+ ]
+ if report_sparsity:
+ sparsity = calculate_sparsity(weights, sparsity_accumulator)
+ row.append("{:.2f}".format(sparsity))
+ rows.append(row)
+ if verbose:
+ # Print cluster centroids
+ print("{} cluster centroids:".format(name))
+ pprint(np.unique(weights))
+ # Add summary/overall values
+ empty_row = ["" for _ in range(len(header))]
+ summary_row = empty_row
+ summary_row[header.index("Layer")] = "=> OVERALL"
+ summary_row[header.index("Num weights")] = str(
+ sparsity_accumulator.total_weights
+ )
+ if report_sparsity:
+ summary_row[header.index("Sparsity")] = "{:.2f}".format(
+ sparsity_accumulator.sparsity()
+ )
+ rows.append(summary_row)
+ # Report detailed cluster info
+ if report_cluster_mode is not None:
+ print()
+ self._print_cluster_details(report_cluster_mode, max_num_clusters)
+ print("#" * 80)
+
+ def _print_cluster_details(
+ self, report_cluster_mode: ReportClusterMode, max_num_clusters: int
+ ) -> None:
+ print("{}:\n{}".format(report_cluster_mode.name, report_cluster_mode.value))
+ num_clusters = self.num_unique_weights(report_cluster_mode)
+ if (
+ report_cluster_mode == ReportClusterMode.NUM_CLUSTERS_HISTOGRAM
+ and max_num_clusters > 0
+ ):
+ # Only show cluster histogram if there are not more than
+ # max_num_clusters. This is a workaround for not showing a huge
+ # histogram for unclustered layers.
+ for name, value in num_clusters.items():
+ if len(value) > max_num_clusters:
+ num_clusters[name] = "More than {} unique values.".format(
+ max_num_clusters
+ )
+ for name, nums in num_clusters.items():
+ print("- {}: {}".format(self._prettify_name(name), nums))
+
+ @staticmethod
+ def _print_in_outs(ios: List[dict], verbose: bool = False) -> None:
+ for item in ios:
+ if verbose:
+ pprint(item)
+ else:
+ print(
+ "- {} ({}): {}".format(
+ item["name"],
+ np.dtype(item["dtype"]).name,
+ item["shape"],
+ )
+ )
diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py
new file mode 100644
index 0000000..4abf6cd
--- /dev/null
+++ b/src/mlia/nn/tensorflow/utils.py
@@ -0,0 +1,149 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright The TensorFlow Authors. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""Collection of useful functions for optimizations."""
+import logging
+from pathlib import Path
+from typing import Callable
+from typing import Iterable
+from typing import Union
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.lite.python.interpreter import Interpreter
+
+from mlia.utils.logging import redirect_output
+
+
+def representative_dataset(model: tf.keras.Model) -> Callable:
+ """Sample dataset used for quantization."""
+ input_shape = model.input_shape
+
+ def dataset() -> Iterable:
+ for _ in range(100):
+ if input_shape[0] != 1:
+ raise Exception("Only the input batch_size=1 is supported!")
+ data = np.random.rand(*input_shape)
+ yield [data.astype(np.float32)]
+
+ return dataset
+
+
+def get_tf_tensor_shape(model: str) -> list:
+ """Get input shape for the TensorFlow tensor model."""
+ # Loading the model
+ loaded = tf.saved_model.load(model)
+ # The model signature must have 'serving_default' as a key
+ if "serving_default" not in loaded.signatures.keys():
+ raise Exception(
+ "Unsupported TensorFlow model signature, must have 'serving_default'"
+ )
+ # Get the signature inputs
+ inputs_tensor_info = loaded.signatures["serving_default"].inputs
+ dims = []
+ # Build a list of all inputs shape sizes
+ for input_key in inputs_tensor_info:
+ if input_key.get_shape():
+ dims.extend(list(input_key.get_shape()))
+ return dims
+
+
+def representative_tf_dataset(model: str) -> Callable:
+ """Sample dataset used for quantization."""
+ if not (input_shape := get_tf_tensor_shape(model)):
+ raise Exception("Unable to get input shape")
+
+ def dataset() -> Iterable:
+ for _ in range(100):
+ data = np.random.rand(*input_shape)
+ yield [data.astype(np.float32)]
+
+ return dataset
+
+
+def convert_to_tflite(model: tf.keras.Model, quantized: bool = False) -> Interpreter:
+ """Convert Keras model to TFLite."""
+ if not isinstance(model, tf.keras.Model):
+ raise Exception("Invalid model type")
+
+ converter = tf.lite.TFLiteConverter.from_keras_model(model)
+
+ if quantized:
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ converter.representative_dataset = representative_dataset(model)
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+ converter.inference_input_type = tf.int8
+ converter.inference_output_type = tf.int8
+
+ with redirect_output(logging.getLogger("tensorflow")):
+ tflite_model = converter.convert()
+
+ return tflite_model
+
+
+def convert_tf_to_tflite(model: str, quantized: bool = False) -> Interpreter:
+ """Convert TensorFlow model to TFLite."""
+ if not isinstance(model, str):
+ raise Exception("Invalid model type")
+
+ converter = tf.lite.TFLiteConverter.from_saved_model(model)
+
+ if quantized:
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ converter.representative_dataset = representative_tf_dataset(model)
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+ converter.inference_input_type = tf.int8
+ converter.inference_output_type = tf.int8
+
+ with redirect_output(logging.getLogger("tensorflow")):
+ tflite_model = converter.convert()
+
+ return tflite_model
+
+
+def save_keras_model(model: tf.keras.Model, save_path: Union[str, Path]) -> None:
+ """Save Keras model at provided path."""
+ # Checkpoint: saving the optimizer is necessary.
+ model.save(save_path, include_optimizer=True)
+
+
+def save_tflite_model(
+ model: tf.lite.TFLiteConverter, save_path: Union[str, Path]
+) -> None:
+ """Save TFLite model at provided path."""
+ with open(save_path, "wb") as file:
+ file.write(model)
+
+
+def is_tflite_model(model: Union[Path, str]) -> bool:
+ """Check if model type is supported by TFLite API.
+
+ TFLite model is indicated by the model file extension .tflite
+ """
+ model_path = Path(model)
+ return model_path.suffix == ".tflite"
+
+
+def is_keras_model(model: Union[Path, str]) -> bool:
+ """Check if model type is supported by Keras API.
+
+ Keras model is indicated by:
+ 1. if it's a directory (meaning saved model),
+ it should contain keras_metadata.pb file
+ 2. or if the model file extension is .h5/.hdf5
+ """
+ model_path = Path(model)
+
+ if model_path.is_dir():
+ return (model_path / "keras_metadata.pb").exists()
+ return model_path.suffix in (".h5", ".hdf5")
+
+
+def is_tf_model(model: Union[Path, str]) -> bool:
+ """Check if model type is supported by TensorFlow API.
+
+ TensorFlow model is indicated if its directory (meaning saved model)
+ doesn't contain keras_metadata.pb file
+ """
+ model_path = Path(model)
+ return model_path.is_dir() and not is_keras_model(model)