diff options
Diffstat (limited to 'src')
17 files changed, 465 insertions, 135 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index e2c097c..a802c51 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -8,6 +8,7 @@ import tempfile from abc import ABC from abc import abstractmethod from dataclasses import dataclass +from inspect import getfullargspec from pathlib import Path from typing import Any from typing import Callable @@ -15,6 +16,9 @@ from typing import Callable import numpy as np import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 +from tensorflow_model_optimization.python.core.sparsity.keras.pruning_utils import ( # pylint: disable=no-name-in-module + is_pruned_m_by_n, +) from mlia.core.errors import ConfigurationError from mlia.core.reporting import Column @@ -24,19 +28,16 @@ from mlia.nn.common import Optimizer from mlia.nn.common import OptimizerConfiguration from mlia.nn.rewrite.core.train import train from mlia.nn.rewrite.core.train import TrainingParameters -from mlia.nn.rewrite.library.fc_clustering_layer import ( - get_keras_model_clus as fc_clustering_rewrite, -) -from mlia.nn.rewrite.library.fc_layer import get_keras_model as fc_rewrite -from mlia.nn.rewrite.library.fc_sparsity24_layer import ( - get_keras_model as fc_rewrite_sparsity24, -) +from mlia.nn.rewrite.library.clustering import conv2d_clustering_rewrite +from mlia.nn.rewrite.library.clustering import fc_clustering_rewrite +from mlia.nn.rewrite.library.fc_layer import fc_rewrite +from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_rewrite +from mlia.nn.rewrite.library.sparsity import fc_sparsity_rewrite from mlia.nn.tensorflow.config import TFLiteModel from mlia.utils.registry import Registry - logger = logging.getLogger(__name__) -RewriteCallable = Callable[[Any, Any], keras.Model] +RewriteCallable = Callable[..., keras.Model] class Rewrite(ABC): @@ -47,10 +48,23 @@ class Rewrite(ABC): self.name = name self.function = rewrite_fn - def __call__(self, input_shape: Any, output_shape: Any) -> keras.Model: + def __call__( + self, input_shape: Any, output_shape: Any, **kwargs: Any + ) -> keras.Model: """Perform the rewrite operation using the configured function.""" try: - return self.function(input_shape, output_shape) + return self.function(input_shape, output_shape, **kwargs) + except TypeError as ex: + expected_args = self.return_rewrite_func_args() + if "input_shape" in expected_args: + expected_args.remove("input_shape") + if "output_shape" in expected_args: + expected_args.remove("output_shape") + raise KeyError( + f"Found unexpected parameters for rewrite. Expected (sub)set " + f"of {expected_args} found unexpected parameter(s) " + f"{list(set(list(kwargs.keys())) - set(expected_args))}" + ) from ex except Exception as ex: raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex @@ -58,21 +72,25 @@ class Rewrite(ABC): """Return a quantized model if required.""" return model + def return_rewrite_func_args(self) -> list[str]: + """Return the expected args of the rewrite function.""" + return getfullargspec(self.function).args + @abstractmethod def training_callbacks(self) -> list: - """Return default rewrite callbacks.""" + """Return rewrite callbacks.""" @abstractmethod def post_process(self, model: keras.Model) -> keras.Model: - """Return default post-processing rewrite options.""" + """Return post-processing rewrite option.""" @abstractmethod - def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool: + def check_optimization(self, model: keras.Model) -> bool: """Check if the optimization has produced the correct result.""" class GenericRewrite(Rewrite): - """Graph rewrite logic for fully-connected rewrite.""" + """Rewrite class for generic rewrites e.g. fully-connected.""" def quantize(self, model: keras.Model) -> keras.Model: """Return a quantized model if required.""" @@ -83,10 +101,10 @@ class GenericRewrite(Rewrite): return [] def post_process(self, model: keras.Model) -> keras.Model: - """Return default post-processing rewrite options.""" + """Return default post-processing rewrite option.""" return model - def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + def check_optimization(self, model: keras.Model) -> bool: """Not needed here.""" return True @@ -100,15 +118,15 @@ class QuantizeAwareTrainingRewrite(Rewrite, ABC): return model -class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): - """Graph rewrite logic for fully-connected-sparsity24 rewrite.""" +class SparsityRewrite(QuantizeAwareTrainingRewrite): + """Rewrite class for sparsity rewrite e.g. fully-connected-sparsity.""" pruning_callback = tfmot.sparsity.keras.UpdatePruningStep strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning) def quantize(self, model: keras.Model) -> keras.Model: - """Skip quantization when using pruning rewrite.""" + """Skip quantization when using sparsity rewrite.""" return model def training_callbacks(self) -> list: @@ -116,7 +134,7 @@ class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): return [self.pruning_callback()] def post_process(self, model: keras.Model) -> keras.Model: - """Pruning-specific post-processing rewrite options.""" + """Pruning-specific post-processing rewrite option.""" return self.strip_pruning_wrapper(model) def preserved_quantize( @@ -132,13 +150,34 @@ class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): return model - def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: - """Not needed here.""" + def check_optimization( + self, + model: keras.Model, + sparsity_m: int = 2, + sparsity_n: int = 4, + **_: Any, + ) -> bool: + """Check if sparity has produced the correct result.""" + for layer in model.layers: + for weight in layer.weights: + if "kernel" in weight.name: + if "kernel_min" in weight.name or "kernel_max" in weight.name: + continue + if not is_pruned_m_by_n(weight, m_by_n=(sparsity_m, sparsity_n)): + logger.warning( + "\nWARNING: Could not find (%d, %d) sparsity, " + "in layer %s for weight %s \n", + sparsity_m, + sparsity_n, + layer.name, + weight.name, + ) + return False return True class ClusteringRewrite(QuantizeAwareTrainingRewrite): - """Graph clustering rewrite logic to be used by RewritingOptimizer.""" + """Rewrite class for clustering rewrite e.g. fully-connected-clustering.""" _strip_clustering_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering) @@ -151,27 +190,21 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite): ) return cqat_model - def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + def check_optimization( + self, model: keras.Model, num_clusters: int = 2, **_: Any + ) -> bool: """Check if clustering has produced the correct result.""" - number_of_clusters = kwargs.get("number_of_clusters") - if not number_of_clusters: - raise ValueError( - """ - Expected check_preserved_quantize to have argument number_of_clusters. - """ - ) - for layer in model.layers: for weight in layer.weights: if "kernel" in weight.name: if "kernel_min" in weight.name or "kernel_max" in weight.name: continue number_of_found_clusters = len(np.unique(weight)) - if number_of_found_clusters != number_of_clusters: + if number_of_found_clusters != num_clusters: logger.warning( "\nWARNING: Expected %d cluster(s), found %d " "cluster(s) in layer %s for weight %s \n", - number_of_clusters, + num_clusters, number_of_found_clusters, layer.name, weight.name, @@ -184,7 +217,7 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite): return [] def post_process(self, model: keras.Model) -> keras.Model: - """Return the clustering stripped model.""" + """Clustering-specific post-processing rewrite option.""" return self._strip_clustering_wrapper(model) @@ -215,6 +248,7 @@ class RewriteConfiguration(OptimizerConfiguration): layers_to_optimize: list[str] | None = None dataset: Path | None = None train_params: TrainingParameters = TrainingParameters() + rewrite_specific_params: dict | None = None def __str__(self) -> str: """Return string representation of the configuration.""" @@ -227,8 +261,10 @@ class RewritingOptimizer(Optimizer): registry = RewriteRegistry( [ GenericRewrite("fully-connected", fc_rewrite), - Sparsity24Rewrite("fully-connected-sparsity24", fc_rewrite_sparsity24), + SparsityRewrite("fully-connected-sparsity", fc_sparsity_rewrite), ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite), + ClusteringRewrite("conv2d-clustering", conv2d_clustering_rewrite), + SparsityRewrite("conv2d-sparsity", conv2d_sparsity_rewrite), ] ) @@ -250,7 +286,6 @@ class RewritingOptimizer(Optimizer): rewrite = RewritingOptimizer.registry.items[ self.optimizer_configuration.optimization_target ] - use_unmodified_model = True tflite_model = self.model.model_path tfrecord = str(self.optimizer_configuration.dataset) @@ -272,6 +307,10 @@ class RewritingOptimizer(Optimizer): input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], train_params=self.optimizer_configuration.train_params, + rewrite_specific_params=self.optimizer_configuration.rewrite_specific_params, # pylint: disable=line-too-long + detect_activation_function=( + "activation" in rewrite.return_rewrite_func_args() + ), ) if orig_vs_repl_stats: diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index 4204978..570968a 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -34,13 +34,13 @@ from mlia.nn.rewrite.core.graph_edit.record import record_model from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel +from mlia.nn.rewrite.library.helper_functions import ACTIVATION_FUNCTION_LIST from mlia.nn.tensorflow.config import TFLiteModel from mlia.nn.tensorflow.tflite_convert import convert_to_tflite from mlia.nn.tensorflow.tflite_graph import load_fb from mlia.nn.tensorflow.tflite_graph import save_fb from mlia.utils.logging import log_action - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) logger = logging.getLogger(__name__) @@ -83,6 +83,8 @@ def train( # pylint: disable=too-many-arguments input_tensors: list, output_tensors: list, train_params: TrainingParameters = TrainingParameters(), + rewrite_specific_params: dict | None = None, + detect_activation_function: bool = False, ) -> Any: """Extract and train a model, and return the results.""" if unmodified_model: @@ -122,6 +124,8 @@ def train( # pylint: disable=too-many-arguments rewrite=rewrite, is_qat=is_qat, train_params=train_params, + rewrite_specific_params=rewrite_specific_params, + detect_activation_function=detect_activation_function, ) for i, filename in enumerate(tflite_filenames): @@ -349,6 +353,41 @@ def set_up_data_pipeline( return dataset, steps_per_epoch +def detect_activation_from_rewrite_function(model_path: str) -> str: + """Given a rewrite model, choose the most common activation function.""" + interpreter = tf.lite.Interpreter(model_path=model_path) + interpreter.allocate_tensors() + act_func_match_list = [] + for tensor_details in interpreter.get_tensor_details(): + for act_func in ACTIVATION_FUNCTION_LIST: + tensor_name = tensor_details["name"].lower() + if act_func in tensor_name: + act_func_idx = tensor_name.index(act_func) + if ( + len(tensor_name) == act_func_idx + len(act_func) + or tensor_name[act_func_idx + len(act_func)] == ";" + ): + act_func_match_list.append( + tensor_name[ + act_func_idx : act_func_idx + len(act_func) # noqa: E203 + ] + ) + act_func_match = "relu" + if len(act_func_match_list) == 0: + logger.info( + "No activation function specified, setting activation function to ReLU" + ) + else: + act_func_match = max(set(act_func_match_list), key=act_func_match.count) + logger.info( + "No activation function specified, " + "setting activation function to most " + "common activation detected in rewrite graph: %s", + act_func_match, + ) + return act_func_match + + def train_in_dir( train_dir: str, baseline_dir: Any, @@ -356,6 +395,8 @@ def train_in_dir( rewrite: Callable, is_qat: bool, train_params: TrainingParameters = TrainingParameters(), + rewrite_specific_params: dict | None = None, + detect_activation_function: bool = False, ) -> list[str]: """Train a replacement for replace.tflite using the input.tfrec \ and output.tfrec in train_dir. @@ -372,6 +413,18 @@ def train_in_dir( ) replace = TFLiteModel(ExtractPaths.tflite.replace(train_dir)) + if detect_activation_function and ( + rewrite_specific_params is None + or "activation" not in list(rewrite_specific_params.keys()) + ): + detected_activation_function = detect_activation_from_rewrite_function( + ExtractPaths.tflite.replace(train_dir).as_posix() + ) + if rewrite_specific_params: + rewrite_specific_params["activation"] = detected_activation_function + else: + rewrite_specific_params = {"activation": detected_activation_function} + input_name, output_name = _get_io_tensors(teacher) model_is_quantized = replace.is_tensor_quantized(name=input_name) @@ -396,7 +449,13 @@ def train_in_dir( loss_fn = keras.losses.MeanSquaredError() model = create_model( - rewrite, input_shape, output_shape, optimizer, loss_fn, model_is_quantized + rewrite, + input_shape, + output_shape, + optimizer, + loss_fn, + model_is_quantized, + rewrite_specific_params=rewrite_specific_params, ) logger.info(model.summary()) @@ -462,11 +521,9 @@ def train_in_dir( steps_per_epoch, post_process=True, ) - - # Placeholder for now, will be parametrized later (MLIA-1114) - # rewrite.check_optimization( # type: ignore[attr-defined] - # model, number_of_clusters=32 - # ) + rewrite.check_optimization( # type: ignore[attr-defined] + model, **rewrite_specific_params if rewrite_specific_params else {} + ) if model_is_quantized and is_qat: model = rewrite.preserved_quantize(model) # type: ignore[attr-defined] checkpoints = ( @@ -501,11 +558,10 @@ def train_in_dir( loss_fn, steps_per_epoch, ) - # Placeholder for now, will be parametrized later (MLIA-1114) - # rewrite.check_optimization( # type: ignore[attr-defined] - # model, number_of_clusters=32 - # ) + rewrite.check_optimization( # type: ignore[attr-defined] + model, **rewrite_specific_params if rewrite_specific_params else {} + ) teacher.close() return output_filenames @@ -528,9 +584,13 @@ def create_model( # pylint: disable=too-many-arguments loss_fn: Callable, model_is_quantized: bool, model_to_load_from: keras.model | None = None, + rewrite_specific_params: dict | None = None, ) -> keras.Model: """Create a model, optionally from another.""" - model = rewrite(input_shape, output_shape) + if rewrite_specific_params: + model = rewrite(input_shape, output_shape, **rewrite_specific_params) + else: + model = rewrite(input_shape, output_shape) if model_is_quantized: model = rewrite.quantize(model) # type: ignore[attr-defined] model = model_compile(model, optimizer=optimizer, loss_fn=loss_fn) @@ -558,6 +618,7 @@ def model_fit( # pylint: disable=too-many-arguments loss_fn: Callable, steps_per_epoch: int, post_process: bool = False, + rewrite_specific_params: dict | None = None, ) -> keras.Model: """Train a tflite model.""" steps_so_far = 0 @@ -597,6 +658,7 @@ def model_fit( # pylint: disable=too-many-arguments loss_fn, model_is_quantized, model_to_load_from=model, + rewrite_specific_params=rewrite_specific_params, ) else: model_to_save = model diff --git a/src/mlia/nn/rewrite/library/clustering.py b/src/mlia/nn/rewrite/library/clustering.py new file mode 100644 index 0000000..b159763 --- /dev/null +++ b/src/mlia/nn/rewrite/library/clustering.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Rewrite functions used to return layers ready for clustering.""" +from typing import Any + +import tensorflow_model_optimization as tfmot +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 + +from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters +from mlia.nn.rewrite.library.helper_functions import get_activation_function + + +def fc_clustering_rewrite( + input_shape: Any, + output_shape: Any, + num_clusters: int = 2, + cluster_centroids_init: tfmot.clustering.keras.CentroidInitialization = tfmot.clustering.keras.CentroidInitialization( # pylint: disable=line-too-long + "CentroidInitialization.LINEAR" + ), +) -> keras.Model: + """Fully connected TensorFlow Lite model ready for clustering.""" + rewrite_params = { + "number_of_clusters": num_clusters, + "cluster_centroids_init": cluster_centroids_init, + } + model = tfmot.clustering.keras.cluster_weights( + to_cluster=keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Flatten(), + keras.layers.Dense(units=output_shape), + ] + ), + **rewrite_params + ) + return model + + +def conv2d_clustering_rewrite( + input_shape: Any, + output_shape: Any, + num_clusters: int = 2, + cluster_centroids_init: tfmot.clustering.keras.CentroidInitialization = tfmot.clustering.keras.CentroidInitialization( # pylint: disable=line-too-long + "CentroidInitialization.LINEAR" + ), + activation: str = "relu", +) -> keras.Model: + """Conv2d TensorFlow Lite model ready for clustering.""" + rewrite_params = { + "number_of_clusters": num_clusters, + "cluster_centroids_init": cluster_centroids_init, + } + conv2d_parameters = compute_conv2d_parameters( + input_shape=input_shape, output_shape=output_shape + ) + activation_function, activation_function_extra_args = get_activation_function( + activation + ) + model = tfmot.clustering.keras.cluster_weights( + to_cluster=keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Conv2D(**conv2d_parameters), + keras.layers.BatchNormalization(), + activation_function(**activation_function_extra_args), + ] + ), + **rewrite_params + ) + return model diff --git a/src/mlia/nn/rewrite/library/fc_clustering_layer.py b/src/mlia/nn/rewrite/library/fc_clustering_layer.py deleted file mode 100644 index 7cc383e..0000000 --- a/src/mlia/nn/rewrite/library/fc_clustering_layer.py +++ /dev/null @@ -1,26 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Example rewrite with one fully connected clustered layer.""" -from typing import Any - -import tensorflow_model_optimization as tfmot -from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 - - -def get_keras_model_clus(input_shape: Any, output_shape: Any) -> keras.Model: - """Generate TensorFlow Lite model for clustering rewrite.""" - rewrite_params = { - "number_of_clusters": 32, - "cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization.LINEAR, - } - model = tfmot.clustering.keras.cluster_weights( - to_cluster=keras.Sequential( - [ - keras.layers.InputLayer(input_shape=input_shape), - keras.layers.Flatten(), - keras.layers.Dense(units=output_shape), - ] - ), - **rewrite_params - ) - return model diff --git a/src/mlia/nn/rewrite/library/fc_layer.py b/src/mlia/nn/rewrite/library/fc_layer.py index 041ce85..92195d1 100644 --- a/src/mlia/nn/rewrite/library/fc_layer.py +++ b/src/mlia/nn/rewrite/library/fc_layer.py @@ -1,13 +1,13 @@ # SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 -"""Example rewrite with one fully connected layer.""" +"""Rewrite function used to return regular layers.""" from typing import Any from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 -def get_keras_model(input_shape: Any, output_shape: Any) -> keras.Model: - """Generate TensorFlow Lite model for rewrite.""" +def fc_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: + """Fully connected TensorFlow Lite model for rewrite.""" model = keras.Sequential( ( keras.layers.InputLayer(input_shape=input_shape), diff --git a/src/mlia/nn/rewrite/library/fc_sparsity24_layer.py b/src/mlia/nn/rewrite/library/fc_sparsity24_layer.py deleted file mode 100644 index 531b34a..0000000 --- a/src/mlia/nn/rewrite/library/fc_sparsity24_layer.py +++ /dev/null @@ -1,23 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Example rewrite with one fully connected 2:4 sparsity layer.""" -from typing import Any - -import tensorflow_model_optimization as tfmot -from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 - - -def get_keras_model(input_shape: Any, output_shape: Any) -> keras.Model: - """Generate TensorFlow Lite model for rewrite.""" - model = tfmot.sparsity.keras.prune_low_magnitude( - to_prune=keras.Sequential( - [ - keras.layers.InputLayer(input_shape=input_shape), - keras.layers.Reshape([-1]), - keras.layers.Dense(output_shape), - ] - ), - sparsity_m_by_n=(2, 4), - ) - - return model diff --git a/src/mlia/nn/rewrite/library/helper_functions.py b/src/mlia/nn/rewrite/library/helper_functions.py new file mode 100644 index 0000000..58d84b1 --- /dev/null +++ b/src/mlia/nn/rewrite/library/helper_functions.py @@ -0,0 +1,58 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Helper functions for the rewrite library.""" +import math +from typing import Any + +import numpy as np +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 + +ACTIVATION_FUNCTION_PRESETS = { + "relu": {"layer_func": keras.layers.ReLU, "extra_args": {}}, + "relu6": {"layer_func": keras.layers.ReLU, "extra_args": {"max_value": 6}}, + "none": {"layer_func": keras.layers.Identity, "extra_args": {}}, +} +ACTIVATION_FUNCTION_LIST = [ + act_func for act_func, _ in ACTIVATION_FUNCTION_PRESETS.items() +] + + +def get_activation_function( + activation: str = "relu", +) -> tuple[type[keras.layers.Layer], dict]: + """Get the activation function from a key.""" + if activation not in ACTIVATION_FUNCTION_LIST: + raise KeyError( + "Expected activation function to be " + f"in {ACTIVATION_FUNCTION_LIST}, found {activation}" + ) + activation_function = ACTIVATION_FUNCTION_PRESETS[activation]["layer_func"] + activation_function_extra_args = ACTIVATION_FUNCTION_PRESETS[activation][ + "extra_args" + ] + return activation_function, activation_function_extra_args + + +def compute_conv2d_parameters( + input_shape: np.ndarray, output_shape: np.ndarray +) -> dict[str, Any]: + """Compute needed kernel size and strides for a given input and output_shape.""" + input_shape = input_shape.tolist() + output_shape = output_shape.tolist() + assert len(input_shape) == 3 + assert len(output_shape) == 3 + num_filters = (output_shape[-1] - input_shape[-1]) + input_shape[-1] + padding = "valid" + kernel_size = (3, 3) + stride_h = round(input_shape[0] / output_shape[0]) + check_output_size_h = math.floor((input_shape[0] - kernel_size[0]) / stride_h) + 1 + stride_w = round(input_shape[1] / output_shape[1]) + check_output_size_w = math.floor((input_shape[1] - kernel_size[1]) / stride_w) + 1 + if check_output_size_h != output_shape[0] or check_output_size_w != output_shape[1]: + padding = "same" + return { + "filters": num_filters, + "kernel_size": kernel_size, + "padding": padding, + "strides": (stride_h, stride_w), + } diff --git a/src/mlia/nn/rewrite/library/sparsity.py b/src/mlia/nn/rewrite/library/sparsity.py new file mode 100644 index 0000000..95f99a7 --- /dev/null +++ b/src/mlia/nn/rewrite/library/sparsity.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Rewrite functions used to return layers ready for sparse pruning.""" +from typing import Any + +import tensorflow_model_optimization as tfmot +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 + +from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters +from mlia.nn.rewrite.library.helper_functions import get_activation_function + + +def fc_sparsity_rewrite( + input_shape: Any, output_shape: Any, sparsity_m: int = 2, sparsity_n: int = 4 +) -> keras.Model: + """Fully connected TensorFlow Lite model ready for sparse pruning.""" + model = tfmot.sparsity.keras.prune_low_magnitude( + to_prune=keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Reshape([-1]), + keras.layers.Dense(output_shape), + ] + ), + sparsity_m_by_n=( + sparsity_m, + sparsity_n, + ), + ) + + return model + + +def conv2d_sparsity_rewrite( + input_shape: Any, + output_shape: Any, + sparsity_m: int = 2, + sparsity_n: int = 4, + activation: str = "relu", +) -> keras.Model: + """Conv2d TensorFlow Lite model ready for sparse pruning.""" + conv2d_parameters = compute_conv2d_parameters( + input_shape=input_shape, output_shape=output_shape + ) + activation_function, activation_function_extra_args = get_activation_function( + activation + ) + model = tfmot.sparsity.keras.prune_low_magnitude( + to_prune=keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Conv2D(**conv2d_parameters), + keras.layers.BatchNormalization(), + activation_function(**activation_function_extra_args), + ] + ), + sparsity_m_by_n=( + sparsity_m, + sparsity_n, + ), + ) + + return model diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py index b61e713..d5470d1 100644 --- a/src/mlia/nn/select.py +++ b/src/mlia/nn/select.py @@ -17,7 +17,7 @@ from mlia.nn.common import Optimizer from mlia.nn.common import OptimizerConfiguration from mlia.nn.rewrite.core.rewrite import RewriteConfiguration from mlia.nn.rewrite.core.rewrite import RewritingOptimizer -from mlia.nn.rewrite.core.rewrite import TrainingParameters +from mlia.nn.rewrite.core.train import TrainingParameters from mlia.nn.tensorflow.config import KerasModel from mlia.nn.tensorflow.config import TFLiteModel from mlia.nn.tensorflow.optimizations.clustering import Clusterer @@ -109,7 +109,7 @@ class MultiStageOptimizer(Optimizer): def apply_optimization(self) -> None: """Apply optimization to the model.""" for config in self.optimizations: - optimizer = get_optimizer(self.model, config) + optimizer = get_optimizer(self.model, config, {}) optimizer.apply_optimization() self.model = optimizer.get_model() @@ -117,7 +117,7 @@ class MultiStageOptimizer(Optimizer): def get_optimizer( model: keras.Model | KerasModel | TFLiteModel, config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings], - training_parameters: dict | None = None, + rewrite_parameters: dict, ) -> Optimizer: """Get optimizer for provided configuration.""" if isinstance(model, KerasModel): @@ -137,12 +137,12 @@ def get_optimizer( if isinstance(config, OptimizationSettings): return _get_optimizer( - model, cast(OptimizationSettings, config), training_parameters + model, cast(OptimizationSettings, config), rewrite_parameters ) if is_list_of(config, OptimizationSettings): return _get_optimizer( - model, cast(List[OptimizationSettings], config), training_parameters + model, cast(List[OptimizationSettings], config), rewrite_parameters ) raise ConfigurationError(f"Unknown optimization configuration {config}") @@ -151,7 +151,7 @@ def get_optimizer( def _get_optimizer( model: keras.Model | Path, optimization_settings: OptimizationSettings | list[OptimizationSettings], - training_parameters: dict | None = None, + rewrite_parameters: dict, ) -> Optimizer: if isinstance(optimization_settings, OptimizationSettings): optimization_settings = [optimization_settings] @@ -162,12 +162,12 @@ def _get_optimizer( _check_optimizer_params(opt_type, opt_target) opt_config = _get_optimizer_configuration( - opt_type, opt_target, layers_to_optimize, dataset, training_parameters + opt_type, opt_target, rewrite_parameters, layers_to_optimize, dataset ) optimizer_configs.append(opt_config) if len(optimizer_configs) == 1: - return get_optimizer(model, optimizer_configs[0]) + return get_optimizer(model, optimizer_configs[0], {}) return MultiStageOptimizer(model, optimizer_configs) @@ -189,9 +189,9 @@ def _get_rewrite_params( def _get_optimizer_configuration( optimization_type: str, optimization_target: int | float | str, + rewrite_parameters: dict, layers_to_optimize: list[str] | None = None, dataset: Path | None = None, - training_parameters: dict | None = None, ) -> OptimizerConfiguration: """Get optimizer configuration for provided parameters.""" _check_optimizer_params(optimization_type, optimization_target) @@ -212,12 +212,14 @@ def _get_optimizer_configuration( if opt_type == "rewrite": if isinstance(optimization_target, str): - rewrite_params = _get_rewrite_params(training_parameters) return RewriteConfiguration( optimization_target=str(optimization_target), layers_to_optimize=layers_to_optimize, dataset=dataset, - train_params=rewrite_params, + train_params=_get_rewrite_params(rewrite_parameters["train_params"]), + rewrite_specific_params=rewrite_parameters.get( + "rewrite_specific_params" + ), ) raise ConfigurationError( diff --git a/src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml b/src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml new file mode 100644 index 0000000..fe50c31 --- /dev/null +++ b/src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 + +[rewrite.training_parameters] +batch_size = 32 +learning_rate = 1e-3 +show_progress = true +steps = 48000 +learning_rate_schedule = "cosine" +num_procs = 1 +num_threads = 0 +augmentations.gaussian_strength = 0.0 +augmentations.mixup_strength = 0.0 + +[rewrite.conv2d-clustering] +num_clusters = 16 +cluster_centroids_init = "CentroidInitialization.LINEAR" +activation = "relu" diff --git a/src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml b/src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml new file mode 100644 index 0000000..d0e05a7 --- /dev/null +++ b/src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 + +[rewrite.training_parameters] +batch_size = 32 +learning_rate = 1e-3 +show_progress = true +steps = 48000 +learning_rate_schedule = "cosine" +num_procs = 1 +num_threads = 0 +augmentations.gaussian_strength = 0.0 +augmentations.mixup_strength = 0.0 + +[rewrite.conv2d-sparsity] +sparsity_m = 2 +sparsity_n = 4 +activation = "relu" diff --git a/src/mlia/resources/optimization_profiles/optimization_custom_augmentation.toml b/src/mlia/resources/optimization_profiles/optimization-custom-augmentation.toml index 5d1f917..96d9742 100644 --- a/src/mlia/resources/optimization_profiles/optimization_custom_augmentation.toml +++ b/src/mlia/resources/optimization_profiles/optimization-custom-augmentation.toml @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 -[training] +[rewrite.training_parameters] batch_size = 32 learning_rate = 1e-3 show_progress = true diff --git a/src/mlia/resources/optimization_profiles/optimization-fully-connected-clustering.toml b/src/mlia/resources/optimization_profiles/optimization-fully-connected-clustering.toml new file mode 100644 index 0000000..c5d460b --- /dev/null +++ b/src/mlia/resources/optimization_profiles/optimization-fully-connected-clustering.toml @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 + +[rewrite.training_parameters] +batch_size = 32 +learning_rate = 1e-3 +show_progress = true +steps = 48000 +learning_rate_schedule = "cosine" +num_procs = 1 +num_threads = 0 +augmentations.gaussian_strength = 0.0 +augmentations.mixup_strength = 0.0 + +[rewrite.fully-connected-clustering] +num_clusters = 16 +cluster_centroids_init = "CentroidInitialization.LINEAR" diff --git a/src/mlia/resources/optimization_profiles/optimization-fully-connected-pruning.toml b/src/mlia/resources/optimization_profiles/optimization-fully-connected-pruning.toml new file mode 100644 index 0000000..f7f91ec --- /dev/null +++ b/src/mlia/resources/optimization_profiles/optimization-fully-connected-pruning.toml @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 + +[rewrite.training_parameters] +batch_size = 32 +learning_rate = 1e-3 +show_progress = true +steps = 48000 +learning_rate_schedule = "cosine" +num_procs = 1 +num_threads = 0 +augmentations.gaussian_strength = 0.0 +augmentations.mixup_strength = 0.0 + +[rewrite.fully-connected-sparsity] +sparsity_m = 2 +sparsity_n = 4 diff --git a/src/mlia/resources/optimization_profiles/optimization.toml b/src/mlia/resources/optimization_profiles/optimization.toml index 42b64f0..6f2800e 100644 --- a/src/mlia/resources/optimization_profiles/optimization.toml +++ b/src/mlia/resources/optimization_profiles/optimization.toml @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 -[training] +[rewrite.training_parameters] batch_size = 32 learning_rate = 1e-3 show_progress = true diff --git a/src/mlia/target/common/optimization.py b/src/mlia/target/common/optimization.py index a139a7d..69d3a24 100644 --- a/src/mlia/target/common/optimization.py +++ b/src/mlia/target/common/optimization.py @@ -51,7 +51,7 @@ class OptimizingDataCollector(ContextAwareDataCollector): optimizations = self._get_optimization_settings(self.context) - training_parameters = self._get_training_settings(self.context) + rewrite_parameters = self._get_rewrite_settings(self.context) if not optimizations or optimizations == [[]]: raise FunctionalityNotSupportedError( @@ -78,7 +78,7 @@ class OptimizingDataCollector(ContextAwareDataCollector): model = self.model # type: ignore optimizers: list[Callable] = [ - partial(self.optimize_model, opts, training_parameters) + partial(self.optimize_model, opts, rewrite_parameters) for opts in opt_settings ] @@ -87,12 +87,12 @@ class OptimizingDataCollector(ContextAwareDataCollector): def optimize_model( self, opt_settings: list[OptimizationSettings], - training_parameters: dict | None, + rewrite_parameters: dict, model: KerasModel | TFLiteModel, ) -> Any: """Run optimization.""" optimizer = get_optimizer( - model, opt_settings, training_parameters=training_parameters + model, opt_settings, rewrite_parameters=rewrite_parameters ) opts_as_str = ", ".join(str(opt) for opt in opt_settings) logger.info("Applying model optimizations - [%s]", opts_as_str) @@ -124,11 +124,11 @@ class OptimizingDataCollector(ContextAwareDataCollector): context=context, ) - def _get_training_settings(self, context: Context) -> dict: + def _get_rewrite_settings(self, context: Context) -> list[dict]: """Get optimization settings.""" return self.get_parameter( # type: ignore OptimizingDataCollector.name(), - "training_parameters", + "rewrite_parameters", expected_type=dict, expected=False, context=context, @@ -234,7 +234,7 @@ def parse_augmentations( valid_keys = ["mixup_strength", "gaussian_strength"] tuple_to_return = [] for valid_key in valid_keys.copy(): - if augmentations.get(valid_key): + if augmentations.get(valid_key) is not None: del augmentation_keys_test_for_valid[ augmentation_keys_test_for_valid.index(valid_key) ] @@ -247,7 +247,6 @@ def parse_augmentations( tuple_to_return.append(None) else: tuple_to_return.append(None) - if len(augmentation_keys_test_for_valid) > 0: logger.warning( "Warning! Expected augmentation parameters to be 'gaussian_strength' " @@ -275,23 +274,32 @@ def add_common_optimization_params( # pylint: disable=too-many-branches if not is_list_of(optimization_targets, dict): raise TypeError("Optimization targets value has wrong format.") - rewrite_parameters = extra_args.get("optimization_profile") training_parameters = None - if rewrite_parameters: - if not isinstance(rewrite_parameters, dict): - raise TypeError("Training Parameter values has wrong format.") - training_parameters = extra_args["optimization_profile"].get("training") - - if training_parameters: - training_parameters["augmentations"] = parse_augmentations( - training_parameters.get("augmentations") - ) + rewrite_specific_parameters = None + + optimization_parameters = extra_args.get("optimization_profile") + if optimization_parameters: # pylint: disable=too-many-nested-blocks + if not isinstance(optimization_parameters, dict): + raise TypeError("Optimization Parameter values has wrong format.") + + if optimization_parameters.get("rewrite"): + rewrite_params = optimization_parameters["rewrite"] + training_parameters = rewrite_params.get("training_parameters") + if training_parameters: + training_parameters["augmentations"] = parse_augmentations( + training_parameters.get("augmentations") + ) + optimization_target = optimization_targets[0]["optimization_target"] + rewrite_specific_parameters = rewrite_params.get(optimization_target) advisor_parameters.update( { "common_optimizations": { "optimizations": [optimization_targets], - "training_parameters": training_parameters, + "rewrite_parameters": { + "train_params": training_parameters, + "rewrite_specific_params": rewrite_specific_parameters, + }, }, } ) diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py index 8492086..236511c 100644 --- a/src/mlia/target/config.py +++ b/src/mlia/target/config.py @@ -71,7 +71,14 @@ def is_builtin_target_profile(profile_name: str | Path) -> bool: return profile_name in BUILTIN_SUPPORTED_PROFILE_NAMES -BUILTIN_SUPPORTED_OPTIMIZATION_NAMES = ["optimization"] +BUILTIN_SUPPORTED_OPTIMIZATION_NAMES = [ + "optimization", + "optimization-custom-augmentation", + "optimization-fully-connected-clustering", + "optimization-fully-connected-pruning", + "optimization-conv2d-clustering", + "optimization-conv2d-pruning", +] def is_builtin_optimization_profile(optimization_name: str | Path) -> bool: |