diff options
31 files changed, 1521 insertions, 283 deletions
@@ -42,7 +42,9 @@ Information on reporting security issues can be found in ## License -ML Inference Advisor is licensed under [Apache License 2.0](LICENSES/Apache-2.0.txt). +ML Inference Advisor is licensed under [Apache License 2.0](LICENSES/Apache-2.0.txt) +unless otherwise indicated. This project contains software under a range of +permissive licenses, see [LICENSES](LICENSES/). ## Trademarks and copyrights @@ -181,6 +183,16 @@ documentation, e.g. in the candidates from the rewrite library, with or without training using a small portion of the training data, to achieve local performance gains. +The following rewrites are supported: + +* fully-connected - replaces a subgraph with a fully connected layer +* fully-connected-sparsity - replaces a subgraph with a pruned 2:4 sparse fully connected layer +* fully-connected-unstructured-sparsity - replaces a subgraph with an unstructured pruned fully connected layer +* fully-connected-clustering - replaces a subgraph with a clustered fully connected layer +* conv2d-sparsity - replaces a subgraph with a pruned 2:4 sparse conv2d layer +* conv2d-unstructured-sparsity - replaces a subgraph with an unstructured pruned conv2d layer +* conv2d-clustering - replaces a subgraph with a clustered conv2d layer + **Note:** A ***Keras model*** (.h5 or SavedModel) is required as input to perform pruning and clustering. A ***TensorFlow Lite model*** is required as input to perform a rewrite. @@ -209,15 +221,72 @@ mlia optimize ~/models/ds_cnn_large_fp32.tflite \ --rewrite-end MobileNet/fc1/BiasAdd ``` -### optimization Profiles +### Optimization Profiles Training parameters for rewrites can be specified. 
-There are a number of predefined profiles: +There are a number of predefined profiles for rewrites shown below: + +| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Augmentations | +| :----------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :-------------: | +| optimization | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | "gaussian" | + +| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Num Clusters | Cluster Centroids Init | +| :-------------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :----------: | :--------------------------------: | +| optimization-fully-connected-clustering | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 16 | "CentroidInitialization.LINEAR" | + +| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Sparsity M | Sparsity N | +| :-----------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :--------: | :--------: | +| optimization-fully-connected-pruning | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 2 | 4 | + +| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Initial Sparsity | End Sparsity | End Step | +| :-----------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :--------: | :--------: | :--------: | +| optimization-fully-connected-unstructured-pruning | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 0.25 | 0.5 | 48000 | + +| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Num Clusters | Cluster Centroids Init | Activation | Kernel Size | +| 
:-------------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :----------: | :--------------------------------: | :--------: | :---------: | +| optimization-conv2d-clustering | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 16 | "CentroidInitialization.LINEAR" | "relu" | 3x3 | + +| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Sparsity M | Sparsity N | Activation | Kernel Size | +| :-----------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :--------: | :--------: | :--------: | :---------: | +| optimization-conv2d-pruning | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 2 | 4 | "relu" | 3x3 | + +| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Initial Sparsity | End Sparsity | End Step | Activation | Kernel Size | +| :-----------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :--------: | :--------: | :--------: | :--------:| :---------: | +| optimization-conv2d-unstructured-pruning | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 0.25 | 0.5 | 48000 | "relu" | 3x3 | + +These are summarized below: + +* optimization - Provides training parameters for rewrites +* optimization-fully-connected-clustering - Provides training parameters for rewrites and cluster specific parameters for the fully-connected-clustering rewrite +* optimization-fully-connected-pruning - Provides training parameters for rewrites and pruning specific parameters for the fully-connected-sparsity rewrite +* optimization-conv2d-clustering - Provides training parameters for rewrites and cluster specific parameters for the conv2d-clustering rewrite +* optimization-conv2d-pruning - Provides training parameters for rewrites and pruning specific 
parameters for the conv2d-sparsity rewrite + +Note: for convolutional rewrites (e.g. optimization-conv2d-pruning), the activation function for the rewrite can be selected in the optimization profile from the following list: -| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | -| :----------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | -| optimization | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | +* "relu" - Standard ReLU activation function +* "relu6" - ReLU6 activation function, i.e. the ReLU activation function capped at 6 +* "none" - No activation function + +The user can also specify custom augmentations as part of the training parameters. An example of this can be found in the following optimization profile: + +| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Augmentations - gaussian_strength | Augmentations - mixup_strength | +| :------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :-------------------------------: | :----------------------------: | +| optimization-custom-augmentation | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 0.1 | 0.1 | + +The augmentations consist of 2 parameters: mixup strength and gaussian strength. 
+ +Augmentations can be selected from a number of pre-defined profiles (see the table below), or each individual parameter can be chosen (see optimization-custom-augmentation above for an example): + +| Name | MixUp Strength | Gaussian Strength | +| :------------------: | :------------: | :---------------: | +| "none" | None | None | +| "gaussian" | None | 1.0 | +| "mixup" | 1.0 | None | +| "mixout" | 1.6 | None | +| "mix_gaussian_large" | 2.0 | 1.0 | +| "mix_gaussian_small" | 1.6 | 0.3 | ```bash ##### An example for using optimization Profiles @@ -228,7 +297,7 @@ mlia optimize ~/models/ds_cnn_large_fp32.tflite \ --dataset input.tfrec \ --rewrite-target fully-connected \ --rewrite-start MobileNet/avg_pool/AvgPool \ - --rewrite-end MobileNet/fc1/BiasAdd_ + --rewrite-end MobileNet/fc1/BiasAdd ``` #### Custom optimization Profiles @@ -244,7 +313,17 @@ apply for each optimization. ``` bash # for custom profiles -mlia ops --optimization-profile ~/my_custom_optimization_profile.toml +mlia optimize --optimization-profile ~/my_custom_optimization_profile.toml +``` + +When providing rewrite-specific parameters e.g. 
for clustering, the rewrite name should be specified in the toml: + +For example, the following provides rewrite-specific parameters for the fully-connected-clustering rewrite + +``` bash +[rewrite.fully-connected-clustering] +num_clusters = 16 +cluster_centroids_init = "CentroidInitialization.LINEAR" ``` # Target profiles diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index e2c097c..c2ad364 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -8,13 +8,20 @@ import tempfile from abc import ABC from abc import abstractmethod from dataclasses import dataclass +from inspect import getfullargspec from pathlib import Path +from statistics import fmean from typing import Any from typing import Callable +from typing import Generator import numpy as np +import tensorflow as tf import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 +from tensorflow_model_optimization.python.core.sparsity.keras.pruning_utils import ( # pylint: disable=no-name-in-module + is_pruned_m_by_n, +) from mlia.core.errors import ConfigurationError from mlia.core.reporting import Column @@ -24,19 +31,18 @@ from mlia.nn.common import Optimizer from mlia.nn.common import OptimizerConfiguration from mlia.nn.rewrite.core.train import train from mlia.nn.rewrite.core.train import TrainingParameters -from mlia.nn.rewrite.library.fc_clustering_layer import ( - get_keras_model_clus as fc_clustering_rewrite, -) -from mlia.nn.rewrite.library.fc_layer import get_keras_model as fc_rewrite -from mlia.nn.rewrite.library.fc_sparsity24_layer import ( - get_keras_model as fc_rewrite_sparsity24, -) +from mlia.nn.rewrite.library.clustering import conv2d_clustering_rewrite +from mlia.nn.rewrite.library.clustering import fc_clustering_rewrite +from mlia.nn.rewrite.library.fc_layer import fc_rewrite +from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_rewrite +from 
mlia.nn.rewrite.library.sparsity import conv2d_sparsity_unstructured_rewrite +from mlia.nn.rewrite.library.sparsity import fc_sparsity_rewrite +from mlia.nn.rewrite.library.sparsity import fc_sparsity_unstructured_rewrite from mlia.nn.tensorflow.config import TFLiteModel from mlia.utils.registry import Registry - logger = logging.getLogger(__name__) -RewriteCallable = Callable[[Any, Any], keras.Model] +RewriteCallable = Callable[..., keras.Model] class Rewrite(ABC): @@ -47,10 +53,23 @@ class Rewrite(ABC): self.name = name self.function = rewrite_fn - def __call__(self, input_shape: Any, output_shape: Any) -> keras.Model: + def __call__( + self, input_shape: Any, output_shape: Any, **kwargs: Any + ) -> keras.Model: """Perform the rewrite operation using the configured function.""" try: - return self.function(input_shape, output_shape) + return self.function(input_shape, output_shape, **kwargs) + except TypeError as ex: + expected_args = self.return_rewrite_func_args() + if "input_shape" in expected_args: + expected_args.remove("input_shape") + if "output_shape" in expected_args: + expected_args.remove("output_shape") + raise KeyError( + f"Found unexpected parameters for rewrite. 
Expected (sub)set " + f"of {expected_args} found unexpected parameter(s) " + f"{list(set(list(kwargs.keys())) - set(expected_args))}" + ) from ex except Exception as ex: raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex @@ -58,21 +77,25 @@ class Rewrite(ABC): """Return a quantized model if required.""" return model + def return_rewrite_func_args(self) -> list[str]: + """Return the expected args of the rewrite function.""" + return getfullargspec(self.function).args + @abstractmethod def training_callbacks(self) -> list: - """Return default rewrite callbacks.""" + """Return rewrite callbacks.""" @abstractmethod def post_process(self, model: keras.Model) -> keras.Model: - """Return default post-processing rewrite options.""" + """Return post-processing rewrite option.""" @abstractmethod - def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool: + def check_optimization(self, model: keras.Model) -> bool: """Check if the optimization has produced the correct result.""" class GenericRewrite(Rewrite): - """Graph rewrite logic for fully-connected rewrite.""" + """Rewrite class for generic rewrites e.g. 
fully-connected.""" def quantize(self, model: keras.Model) -> keras.Model: """Return a quantized model if required.""" @@ -83,10 +106,10 @@ class GenericRewrite(Rewrite): return [] def post_process(self, model: keras.Model) -> keras.Model: - """Return default post-processing rewrite options.""" + """Return default post-processing rewrite option.""" return model - def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + def check_optimization(self, model: keras.Model) -> bool: """Not needed here.""" return True @@ -99,16 +122,27 @@ class QuantizeAwareTrainingRewrite(Rewrite, ABC): """Apply optimization-aware quantization to a given model.""" return model + def check_optimization_generator( + self, model: keras.Model + ) -> Generator[tuple[tf.Tensor, keras.layers.Layer], None, None]: + """Loop for check_optimization function.""" + for layer in model.layers: + for weight in layer.weights: + if "kernel" in weight.name: + if "kernel_min" in weight.name or "kernel_max" in weight.name: + continue + yield weight, layer + -class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): - """Graph rewrite logic for fully-connected-sparsity24 rewrite.""" +class SparsityRewrite(QuantizeAwareTrainingRewrite): + """Base rewrite class for sparsity rewrites.""" pruning_callback = tfmot.sparsity.keras.UpdatePruningStep strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning) def quantize(self, model: keras.Model) -> keras.Model: - """Skip quantization when using pruning rewrite.""" + """Skip quantization when using sparsity rewrite.""" return model def training_callbacks(self) -> list: @@ -116,7 +150,7 @@ class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): return [self.pruning_callback()] def post_process(self, model: keras.Model) -> keras.Model: - """Pruning-specific post-processing rewrite options.""" + """Pruning-specific post-processing rewrite option.""" return self.strip_pruning_wrapper(model) def preserved_quantize( @@ -129,16 +163,78 @@ class 
Sparsity24Rewrite(QuantizeAwareTrainingRewrite): model, tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(), ) - return model - def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + def check_optimization(self, model: keras.Model) -> bool: """Not needed here.""" return True +class UnstructuredSparsityRewrite(SparsityRewrite): + """ + Rewrite class for unstructured sparsity rewrite. + + e.g. fully-connected-unstructured-sparsity. + """ + + def check_optimization( + self, model: keras.Model, final_sparsity: float = 0.5, **_: Any + ) -> bool: + """Check if unstructured sparsity has produced the correct result.""" + found_sparsity_list = [] + num_dec_places = str(final_sparsity)[::-1].find(".") + for weight, _ in self.check_optimization_generator(model=model): + weight_np = weight.numpy() + found_sparsity_list.append( + round(np.count_nonzero(weight_np) / weight_np.size, num_dec_places) + ) + if len(found_sparsity_list) == 0: + logger.warning( + "\nWARNING: Could not find any layers " + "in rewrite that could be sparsely pruned" + ) + return False + found_sparsity = fmean(found_sparsity_list) + if found_sparsity != final_sparsity: + logger.warning( + "\nWARNING: Found total sparsity of " + "rewrite model: %.2f " + "expected total sparsity to be: " + "%.2f\n", + found_sparsity, + final_sparsity, + ) + return False + return True + + +class StructuredSparsityRewrite(SparsityRewrite): + """Rewrite class for structured sparsity rewrite e.g. 
fully-connected-sparsity.""" + + def check_optimization( + self, + model: keras.Model, + sparsity_m: int = 2, + sparsity_n: int = 4, + **_: Any, + ) -> bool: + """Check if sparsity has produced the correct result.""" + for weight, layer in self.check_optimization_generator(model=model): + if not is_pruned_m_by_n(weight, m_by_n=(sparsity_m, sparsity_n)): + logger.warning( + "\nWARNING: Could not find (%d, %d) sparsity, " + "in layer %s for weight %s \n", + sparsity_m, + sparsity_n, + layer.name, + weight.name, + ) + return False + return True + + class ClusteringRewrite(QuantizeAwareTrainingRewrite): - """Graph clustering rewrite logic to be used by RewritingOptimizer.""" + """Rewrite class for clustering rewrite e.g. fully-connected-clustering.""" _strip_clustering_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering) @@ -151,32 +247,22 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite): ) return cqat_model - def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + def check_optimization( + self, model: keras.Model, num_clusters: int = 2, **_: Any + ) -> bool: """Check if clustering has produced the correct result.""" - number_of_clusters = kwargs.get("number_of_clusters") - if not number_of_clusters: - raise ValueError( - """ - Expected check_preserved_quantize to have argument number_of_clusters. 
- """ - ) - - for layer in model.layers: - for weight in layer.weights: - if "kernel" in weight.name: - if "kernel_min" in weight.name or "kernel_max" in weight.name: - continue - number_of_found_clusters = len(np.unique(weight)) - if number_of_found_clusters != number_of_clusters: - logger.warning( - "\nWARNING: Expected %d cluster(s), found %d " - "cluster(s) in layer %s for weight %s \n", - number_of_clusters, - number_of_found_clusters, - layer.name, - weight.name, - ) - return False + for weight, layer in self.check_optimization_generator(model=model): + number_of_found_clusters = len(np.unique(weight)) + if number_of_found_clusters != num_clusters: + logger.warning( + "\nWARNING: Expected %d cluster(s), found %d " + "cluster(s) in layer %s for weight %s \n", + num_clusters, + number_of_found_clusters, + layer.name, + weight.name, + ) + return False return True def training_callbacks(self) -> list: @@ -184,7 +270,7 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite): return [] def post_process(self, model: keras.Model) -> keras.Model: - """Return the clustering stripped model.""" + """Clustering-specific post-processing rewrite option.""" return self._strip_clustering_wrapper(model) @@ -215,6 +301,7 @@ class RewriteConfiguration(OptimizerConfiguration): layers_to_optimize: list[str] | None = None dataset: Path | None = None train_params: TrainingParameters = TrainingParameters() + rewrite_specific_params: dict | None = None def __str__(self) -> str: """Return string representation of the configuration.""" @@ -227,8 +314,17 @@ class RewritingOptimizer(Optimizer): registry = RewriteRegistry( [ GenericRewrite("fully-connected", fc_rewrite), - Sparsity24Rewrite("fully-connected-sparsity24", fc_rewrite_sparsity24), + StructuredSparsityRewrite("fully-connected-sparsity", fc_sparsity_rewrite), ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite), + ClusteringRewrite("conv2d-clustering", conv2d_clustering_rewrite), + 
StructuredSparsityRewrite("conv2d-sparsity", conv2d_sparsity_rewrite), + UnstructuredSparsityRewrite( + "conv2d-unstructured-sparsity", conv2d_sparsity_unstructured_rewrite + ), + UnstructuredSparsityRewrite( + "fully-connected-unstructured-sparsity", + fc_sparsity_unstructured_rewrite, + ), ] ) @@ -250,7 +346,6 @@ class RewritingOptimizer(Optimizer): rewrite = RewritingOptimizer.registry.items[ self.optimizer_configuration.optimization_target ] - use_unmodified_model = True tflite_model = self.model.model_path tfrecord = str(self.optimizer_configuration.dataset) @@ -272,6 +367,10 @@ class RewritingOptimizer(Optimizer): input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], train_params=self.optimizer_configuration.train_params, + rewrite_specific_params=self.optimizer_configuration.rewrite_specific_params, # pylint: disable=line-too-long + detect_activation_function=( + "activation" in rewrite.return_rewrite_func_args() + ), ) if orig_vs_repl_stats: diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index 4204978..570968a 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -34,13 +34,13 @@ from mlia.nn.rewrite.core.graph_edit.record import record_model from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel +from mlia.nn.rewrite.library.helper_functions import ACTIVATION_FUNCTION_LIST from mlia.nn.tensorflow.config import TFLiteModel from mlia.nn.tensorflow.tflite_convert import convert_to_tflite from mlia.nn.tensorflow.tflite_graph import load_fb from mlia.nn.tensorflow.tflite_graph import save_fb from mlia.utils.logging import log_action - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) logger = 
logging.getLogger(__name__) @@ -83,6 +83,8 @@ def train( # pylint: disable=too-many-arguments input_tensors: list, output_tensors: list, train_params: TrainingParameters = TrainingParameters(), + rewrite_specific_params: dict | None = None, + detect_activation_function: bool = False, ) -> Any: """Extract and train a model, and return the results.""" if unmodified_model: @@ -122,6 +124,8 @@ def train( # pylint: disable=too-many-arguments rewrite=rewrite, is_qat=is_qat, train_params=train_params, + rewrite_specific_params=rewrite_specific_params, + detect_activation_function=detect_activation_function, ) for i, filename in enumerate(tflite_filenames): @@ -349,6 +353,41 @@ def set_up_data_pipeline( return dataset, steps_per_epoch +def detect_activation_from_rewrite_function(model_path: str) -> str: + """Given a rewrite model, choose the most common activation function.""" + interpreter = tf.lite.Interpreter(model_path=model_path) + interpreter.allocate_tensors() + act_func_match_list = [] + for tensor_details in interpreter.get_tensor_details(): + for act_func in ACTIVATION_FUNCTION_LIST: + tensor_name = tensor_details["name"].lower() + if act_func in tensor_name: + act_func_idx = tensor_name.index(act_func) + if ( + len(tensor_name) == act_func_idx + len(act_func) + or tensor_name[act_func_idx + len(act_func)] == ";" + ): + act_func_match_list.append( + tensor_name[ + act_func_idx : act_func_idx + len(act_func) # noqa: E203 + ] + ) + act_func_match = "relu" + if len(act_func_match_list) == 0: + logger.info( + "No activation function specified, setting activation function to ReLU" + ) + else: + act_func_match = max(set(act_func_match_list), key=act_func_match.count) + logger.info( + "No activation function specified, " + "setting activation function to most " + "common activation detected in rewrite graph: %s", + act_func_match, + ) + return act_func_match + + def train_in_dir( train_dir: str, baseline_dir: Any, @@ -356,6 +395,8 @@ def train_in_dir( rewrite: Callable, 
is_qat: bool, train_params: TrainingParameters = TrainingParameters(), + rewrite_specific_params: dict | None = None, + detect_activation_function: bool = False, ) -> list[str]: """Train a replacement for replace.tflite using the input.tfrec \ and output.tfrec in train_dir. @@ -372,6 +413,18 @@ def train_in_dir( ) replace = TFLiteModel(ExtractPaths.tflite.replace(train_dir)) + if detect_activation_function and ( + rewrite_specific_params is None + or "activation" not in list(rewrite_specific_params.keys()) + ): + detected_activation_function = detect_activation_from_rewrite_function( + ExtractPaths.tflite.replace(train_dir).as_posix() + ) + if rewrite_specific_params: + rewrite_specific_params["activation"] = detected_activation_function + else: + rewrite_specific_params = {"activation": detected_activation_function} + input_name, output_name = _get_io_tensors(teacher) model_is_quantized = replace.is_tensor_quantized(name=input_name) @@ -396,7 +449,13 @@ def train_in_dir( loss_fn = keras.losses.MeanSquaredError() model = create_model( - rewrite, input_shape, output_shape, optimizer, loss_fn, model_is_quantized + rewrite, + input_shape, + output_shape, + optimizer, + loss_fn, + model_is_quantized, + rewrite_specific_params=rewrite_specific_params, ) logger.info(model.summary()) @@ -462,11 +521,9 @@ def train_in_dir( steps_per_epoch, post_process=True, ) - - # Placeholder for now, will be parametrized later (MLIA-1114) - # rewrite.check_optimization( # type: ignore[attr-defined] - # model, number_of_clusters=32 - # ) + rewrite.check_optimization( # type: ignore[attr-defined] + model, **rewrite_specific_params if rewrite_specific_params else {} + ) if model_is_quantized and is_qat: model = rewrite.preserved_quantize(model) # type: ignore[attr-defined] checkpoints = ( @@ -501,11 +558,10 @@ def train_in_dir( loss_fn, steps_per_epoch, ) - # Placeholder for now, will be parametrized later (MLIA-1114) - # rewrite.check_optimization( # type: ignore[attr-defined] - # model, 
number_of_clusters=32 - # ) + rewrite.check_optimization( # type: ignore[attr-defined] + model, **rewrite_specific_params if rewrite_specific_params else {} + ) teacher.close() return output_filenames @@ -528,9 +584,13 @@ def create_model( # pylint: disable=too-many-arguments loss_fn: Callable, model_is_quantized: bool, model_to_load_from: keras.model | None = None, + rewrite_specific_params: dict | None = None, ) -> keras.Model: """Create a model, optionally from another.""" - model = rewrite(input_shape, output_shape) + if rewrite_specific_params: + model = rewrite(input_shape, output_shape, **rewrite_specific_params) + else: + model = rewrite(input_shape, output_shape) if model_is_quantized: model = rewrite.quantize(model) # type: ignore[attr-defined] model = model_compile(model, optimizer=optimizer, loss_fn=loss_fn) @@ -558,6 +618,7 @@ def model_fit( # pylint: disable=too-many-arguments loss_fn: Callable, steps_per_epoch: int, post_process: bool = False, + rewrite_specific_params: dict | None = None, ) -> keras.Model: """Train a tflite model.""" steps_so_far = 0 @@ -597,6 +658,7 @@ def model_fit( # pylint: disable=too-many-arguments loss_fn, model_is_quantized, model_to_load_from=model, + rewrite_specific_params=rewrite_specific_params, ) else: model_to_save = model diff --git a/src/mlia/nn/rewrite/library/clustering.py b/src/mlia/nn/rewrite/library/clustering.py new file mode 100644 index 0000000..48914dc --- /dev/null +++ b/src/mlia/nn/rewrite/library/clustering.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. 
+# SPDX-License-Identifier: Apache-2.0 +"""Rewrite functions used to return layers ready for clustering.""" +from typing import Any + +import tensorflow_model_optimization as tfmot +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 + +from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters +from mlia.nn.rewrite.library.helper_functions import get_activation_function + + +def fc_clustering_rewrite( + input_shape: Any, + output_shape: Any, + num_clusters: int = 2, + cluster_centroids_init: tfmot.clustering.keras.CentroidInitialization = tfmot.clustering.keras.CentroidInitialization( # pylint: disable=line-too-long + "CentroidInitialization.LINEAR" + ), +) -> keras.Model: + """Fully connected TensorFlow Lite model ready for clustering.""" + rewrite_params = { + "number_of_clusters": num_clusters, + "cluster_centroids_init": cluster_centroids_init, + } + model = tfmot.clustering.keras.cluster_weights( + to_cluster=keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Flatten(), + keras.layers.Dense(units=output_shape), + ] + ), + **rewrite_params + ) + return model + + +def conv2d_clustering_rewrite( # pylint: disable=dangerous-default-value + input_shape: Any, + output_shape: Any, + num_clusters: int = 2, + cluster_centroids_init: tfmot.clustering.keras.CentroidInitialization = tfmot.clustering.keras.CentroidInitialization( # pylint: disable=line-too-long + "CentroidInitialization.LINEAR" + ), + activation: str = "relu", + kernel_size: list[int] = [3, 3], +) -> keras.Model: + """Conv2d TensorFlow Lite model ready for clustering.""" + rewrite_params = { + "number_of_clusters": num_clusters, + "cluster_centroids_init": cluster_centroids_init, + } + conv2d_parameters = compute_conv2d_parameters( + input_shape=input_shape, + output_shape=output_shape, + kernel_size_input=kernel_size, + ) + activation_function, activation_function_extra_args = get_activation_function( + activation + ) + 
activation_func_found = ( + [activation_function(**activation_function_extra_args)] + if activation_function + else [] + ) + model = tfmot.clustering.keras.cluster_weights( + to_cluster=keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Conv2D(**conv2d_parameters), + keras.layers.BatchNormalization(), + *activation_func_found, + ] + ), + **rewrite_params + ) + return model diff --git a/src/mlia/nn/rewrite/library/fc_clustering_layer.py b/src/mlia/nn/rewrite/library/fc_clustering_layer.py deleted file mode 100644 index 7cc383e..0000000 --- a/src/mlia/nn/rewrite/library/fc_clustering_layer.py +++ /dev/null @@ -1,26 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Example rewrite with one fully connected clustered layer.""" -from typing import Any - -import tensorflow_model_optimization as tfmot -from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 - - -def get_keras_model_clus(input_shape: Any, output_shape: Any) -> keras.Model: - """Generate TensorFlow Lite model for clustering rewrite.""" - rewrite_params = { - "number_of_clusters": 32, - "cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization.LINEAR, - } - model = tfmot.clustering.keras.cluster_weights( - to_cluster=keras.Sequential( - [ - keras.layers.InputLayer(input_shape=input_shape), - keras.layers.Flatten(), - keras.layers.Dense(units=output_shape), - ] - ), - **rewrite_params - ) - return model diff --git a/src/mlia/nn/rewrite/library/fc_layer.py b/src/mlia/nn/rewrite/library/fc_layer.py index 041ce85..92195d1 100644 --- a/src/mlia/nn/rewrite/library/fc_layer.py +++ b/src/mlia/nn/rewrite/library/fc_layer.py @@ -1,13 +1,13 @@ # SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates. 
# SPDX-License-Identifier: Apache-2.0 -"""Example rewrite with one fully connected layer.""" +"""Rewrite function used to return regular layers.""" from typing import Any from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 -def get_keras_model(input_shape: Any, output_shape: Any) -> keras.Model: - """Generate TensorFlow Lite model for rewrite.""" +def fc_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: + """Fully connected TensorFlow Lite model for rewrite.""" model = keras.Sequential( ( keras.layers.InputLayer(input_shape=input_shape), diff --git a/src/mlia/nn/rewrite/library/fc_sparsity24_layer.py b/src/mlia/nn/rewrite/library/fc_sparsity24_layer.py deleted file mode 100644 index 531b34a..0000000 --- a/src/mlia/nn/rewrite/library/fc_sparsity24_layer.py +++ /dev/null @@ -1,23 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Example rewrite with one fully connected 2:4 sparsity layer.""" -from typing import Any - -import tensorflow_model_optimization as tfmot -from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 - - -def get_keras_model(input_shape: Any, output_shape: Any) -> keras.Model: - """Generate TensorFlow Lite model for rewrite.""" - model = tfmot.sparsity.keras.prune_low_magnitude( - to_prune=keras.Sequential( - [ - keras.layers.InputLayer(input_shape=input_shape), - keras.layers.Reshape([-1]), - keras.layers.Dense(output_shape), - ] - ), - sparsity_m_by_n=(2, 4), - ) - - return model diff --git a/src/mlia/nn/rewrite/library/helper_functions.py b/src/mlia/nn/rewrite/library/helper_functions.py new file mode 100644 index 0000000..1237c17 --- /dev/null +++ b/src/mlia/nn/rewrite/library/helper_functions.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. 
+# SPDX-License-Identifier: Apache-2.0 +"""Helper functions for the rewrite library.""" +import math +from typing import Any + +import numpy as np +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 + +ACTIVATION_FUNCTION_PRESETS = { + "relu": {"layer_func": keras.layers.ReLU, "extra_args": {}}, + "relu6": {"layer_func": keras.layers.ReLU, "extra_args": {"max_value": 6}}, + "none": {"layer_func": None, "extra_args": {}}, +} +ACTIVATION_FUNCTION_LIST = [ + act_func for act_func, _ in ACTIVATION_FUNCTION_PRESETS.items() +] + + +def get_activation_function( + activation: str = "relu", +) -> tuple[type, dict]: + """Get the activation function from a key.""" + if activation not in ACTIVATION_FUNCTION_LIST: + raise KeyError( + "Expected activation function to be " + f"in {ACTIVATION_FUNCTION_LIST}, found {activation}" + ) + activation_function = ACTIVATION_FUNCTION_PRESETS[activation]["layer_func"] + activation_function_extra_args = ACTIVATION_FUNCTION_PRESETS[activation][ + "extra_args" + ] + return activation_function, activation_function_extra_args + + +def compute_conv2d_parameters( # pylint: disable=dangerous-default-value + input_shape: np.ndarray, + output_shape: np.ndarray, + kernel_size_input: list[int] = [3, 3], +) -> dict[str, Any]: + """Compute needed kernel size and strides for a given input and output_shape.""" + input_shape = input_shape.tolist() + output_shape = output_shape.tolist() + assert len(kernel_size_input) == 2, "Kernel size should have 2 entries" + assert len(input_shape) == 3 + assert len(output_shape) == 3 + kernel_size = tuple(kernel_size_input) + num_filters = (output_shape[-1] - input_shape[-1]) + input_shape[-1] + padding = "valid" + stride_h = round(input_shape[0] / output_shape[0]) + check_output_size_h = math.floor((input_shape[0] - kernel_size[0]) / stride_h) + 1 + stride_w = round(input_shape[1] / output_shape[1]) + check_output_size_w = math.floor((input_shape[1] - kernel_size[1]) / stride_w) + 1 + if 
check_output_size_h != output_shape[0] or check_output_size_w != output_shape[1]: + padding = "same" + return { + "filters": num_filters, + "kernel_size": kernel_size, + "padding": padding, + "strides": (stride_h, stride_w), + } diff --git a/src/mlia/nn/rewrite/library/sparsity.py b/src/mlia/nn/rewrite/library/sparsity.py new file mode 100644 index 0000000..1e53254 --- /dev/null +++ b/src/mlia/nn/rewrite/library/sparsity.py @@ -0,0 +1,144 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Rewrite functions used to return layers ready for sparse pruning.""" +from __future__ import annotations + +from typing import Any + +import tensorflow_model_optimization as tfmot +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 + +from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters +from mlia.nn.rewrite.library.helper_functions import get_activation_function + + +def fc_sparsity_unstructured_rewrite( + input_shape: Any, + output_shape: Any, + initial_sparsity: float = 0.5, + final_sparsity: float = 0.5, + begin_step: int = 0, + end_step: int = 48000, +) -> keras.Model: + """Fully connected TensorFlow Lite model ready for unstructured sparse pruning.""" + model = tfmot.sparsity.keras.prune_low_magnitude( + to_prune=keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Reshape([-1]), + keras.layers.Dense(output_shape), + ] + ), + pruning_schedule=tfmot.sparsity.keras.PolynomialDecay( + initial_sparsity=initial_sparsity, + final_sparsity=final_sparsity, + begin_step=begin_step, + end_step=end_step, + ), + ) + + return model + + +def conv2d_sparsity_unstructured_rewrite( # pylint: disable=dangerous-default-value + input_shape: Any, + output_shape: Any, + initial_sparsity: float = 0.5, + final_sparsity: float = 0.5, + begin_step: int = 0, + end_step: int = 48000, + activation: str = "relu", + kernel_size: list[int] = [3, 3], 
+) -> keras.Model: + """Conv2d TensorFlow Lite model ready for unstructured sparse pruning.""" + conv2d_parameters = compute_conv2d_parameters( + input_shape=input_shape, + output_shape=output_shape, + kernel_size_input=kernel_size, + ) + activation_function, activation_function_extra_args = get_activation_function( + activation + ) + activation_func_found = ( + [activation_function(**activation_function_extra_args)] + if activation_function + else [] + ) + model = tfmot.sparsity.keras.prune_low_magnitude( + to_prune=keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Conv2D(**conv2d_parameters), + keras.layers.BatchNormalization(), + *activation_func_found, + ] + ), + pruning_schedule=tfmot.sparsity.keras.PolynomialDecay( + initial_sparsity=initial_sparsity, + final_sparsity=final_sparsity, + begin_step=begin_step, + end_step=end_step, + ), + ) + + return model + + +def fc_sparsity_rewrite( + input_shape: Any, output_shape: Any, sparsity_m: int = 2, sparsity_n: int = 4 +) -> keras.Model: + """Fully connected TensorFlow Lite model ready for sparse pruning.""" + model = tfmot.sparsity.keras.prune_low_magnitude( + to_prune=keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Reshape([-1]), + keras.layers.Dense(output_shape), + ] + ), + sparsity_m_by_n=( + sparsity_m, + sparsity_n, + ), + ) + + return model + + +def conv2d_sparsity_rewrite( # pylint: disable=dangerous-default-value + input_shape: Any, + output_shape: Any, + sparsity_m: int = 2, + sparsity_n: int = 4, + activation: str = "relu", + kernel_size: list[int] = [3, 3], +) -> keras.Model: + """Conv2d TensorFlow Lite model ready for sparse pruning.""" + conv2d_parameters = compute_conv2d_parameters( + input_shape=input_shape, + output_shape=output_shape, + kernel_size_input=kernel_size, + ) + activation_function, activation_function_extra_args = get_activation_function( + activation + ) + activation_func_found = ( + 
[activation_function(**activation_function_extra_args)] + if activation_function + else [] + ) + model = tfmot.sparsity.keras.prune_low_magnitude( + to_prune=keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Conv2D(**conv2d_parameters), + keras.layers.BatchNormalization(), + *activation_func_found, + ] + ), + sparsity_m_by_n=( + sparsity_m, + sparsity_n, + ), + ) + return model diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py index b61e713..d5470d1 100644 --- a/src/mlia/nn/select.py +++ b/src/mlia/nn/select.py @@ -17,7 +17,7 @@ from mlia.nn.common import Optimizer from mlia.nn.common import OptimizerConfiguration from mlia.nn.rewrite.core.rewrite import RewriteConfiguration from mlia.nn.rewrite.core.rewrite import RewritingOptimizer -from mlia.nn.rewrite.core.rewrite import TrainingParameters +from mlia.nn.rewrite.core.train import TrainingParameters from mlia.nn.tensorflow.config import KerasModel from mlia.nn.tensorflow.config import TFLiteModel from mlia.nn.tensorflow.optimizations.clustering import Clusterer @@ -109,7 +109,7 @@ class MultiStageOptimizer(Optimizer): def apply_optimization(self) -> None: """Apply optimization to the model.""" for config in self.optimizations: - optimizer = get_optimizer(self.model, config) + optimizer = get_optimizer(self.model, config, {}) optimizer.apply_optimization() self.model = optimizer.get_model() @@ -117,7 +117,7 @@ class MultiStageOptimizer(Optimizer): def get_optimizer( model: keras.Model | KerasModel | TFLiteModel, config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings], - training_parameters: dict | None = None, + rewrite_parameters: dict, ) -> Optimizer: """Get optimizer for provided configuration.""" if isinstance(model, KerasModel): @@ -137,12 +137,12 @@ def get_optimizer( if isinstance(config, OptimizationSettings): return _get_optimizer( - model, cast(OptimizationSettings, config), training_parameters + model, cast(OptimizationSettings, 
config), rewrite_parameters ) if is_list_of(config, OptimizationSettings): return _get_optimizer( - model, cast(List[OptimizationSettings], config), training_parameters + model, cast(List[OptimizationSettings], config), rewrite_parameters ) raise ConfigurationError(f"Unknown optimization configuration {config}") @@ -151,7 +151,7 @@ def get_optimizer( def _get_optimizer( model: keras.Model | Path, optimization_settings: OptimizationSettings | list[OptimizationSettings], - training_parameters: dict | None = None, + rewrite_parameters: dict, ) -> Optimizer: if isinstance(optimization_settings, OptimizationSettings): optimization_settings = [optimization_settings] @@ -162,12 +162,12 @@ def _get_optimizer( _check_optimizer_params(opt_type, opt_target) opt_config = _get_optimizer_configuration( - opt_type, opt_target, layers_to_optimize, dataset, training_parameters + opt_type, opt_target, rewrite_parameters, layers_to_optimize, dataset ) optimizer_configs.append(opt_config) if len(optimizer_configs) == 1: - return get_optimizer(model, optimizer_configs[0]) + return get_optimizer(model, optimizer_configs[0], {}) return MultiStageOptimizer(model, optimizer_configs) @@ -189,9 +189,9 @@ def _get_rewrite_params( def _get_optimizer_configuration( optimization_type: str, optimization_target: int | float | str, + rewrite_parameters: dict, layers_to_optimize: list[str] | None = None, dataset: Path | None = None, - training_parameters: dict | None = None, ) -> OptimizerConfiguration: """Get optimizer configuration for provided parameters.""" _check_optimizer_params(optimization_type, optimization_target) @@ -212,12 +212,14 @@ def _get_optimizer_configuration( if opt_type == "rewrite": if isinstance(optimization_target, str): - rewrite_params = _get_rewrite_params(training_parameters) return RewriteConfiguration( optimization_target=str(optimization_target), layers_to_optimize=layers_to_optimize, dataset=dataset, - train_params=rewrite_params, + 
train_params=_get_rewrite_params(rewrite_parameters["train_params"]), + rewrite_specific_params=rewrite_parameters.get( + "rewrite_specific_params" + ), ) raise ConfigurationError( diff --git a/src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml b/src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml new file mode 100644 index 0000000..3d8adfa --- /dev/null +++ b/src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 + +[rewrite.training_parameters] +batch_size = 32 +learning_rate = 1e-3 +show_progress = true +steps = 48000 +learning_rate_schedule = "cosine" +num_procs = 1 +num_threads = 0 +augmentations.gaussian_strength = 0.0 +augmentations.mixup_strength = 0.0 + +[rewrite.conv2d-clustering] +num_clusters = 16 +cluster_centroids_init = "CentroidInitialization.LINEAR" +activation = "relu" +kernel_size = [3, 3] diff --git a/src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml b/src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml new file mode 100644 index 0000000..aa7f982 --- /dev/null +++ b/src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. 
+# SPDX-License-Identifier: Apache-2.0 + +[rewrite.training_parameters] +batch_size = 32 +learning_rate = 1e-3 +show_progress = true +steps = 48000 +learning_rate_schedule = "cosine" +num_procs = 1 +num_threads = 0 +augmentations.gaussian_strength = 0.0 +augmentations.mixup_strength = 0.0 + +[rewrite.conv2d-sparsity] +sparsity_m = 2 +sparsity_n = 4 +activation = "relu" +kernel_size = [3, 3] diff --git a/src/mlia/resources/optimization_profiles/optimization-conv2d-unstructured-pruning.toml b/src/mlia/resources/optimization_profiles/optimization-conv2d-unstructured-pruning.toml new file mode 100644 index 0000000..67740ca --- /dev/null +++ b/src/mlia/resources/optimization_profiles/optimization-conv2d-unstructured-pruning.toml @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 + +[rewrite.training_parameters] +batch_size = 32 +learning_rate = 1e-3 +show_progress = true +steps = 48000 +learning_rate_schedule = "cosine" +num_procs = 1 +num_threads = 0 +augmentations.gaussian_strength = 0.0 +augmentations.mixup_strength = 0.0 + +[rewrite.conv2d-unstructured-sparsity] +initial_sparsity = 0.25 +final_sparsity = 0.5 +end_step = 48000 +activation = "relu" +kernel_size = [3, 3] diff --git a/src/mlia/resources/optimization_profiles/optimization-custom-augmentation.toml b/src/mlia/resources/optimization_profiles/optimization-custom-augmentation.toml new file mode 100644 index 0000000..96d9742 --- /dev/null +++ b/src/mlia/resources/optimization_profiles/optimization-custom-augmentation.toml @@ -0,0 +1,13 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. 
+# SPDX-License-Identifier: Apache-2.0 + +[rewrite.training_parameters] +batch_size = 32 +learning_rate = 1e-3 +show_progress = true +steps = 48000 +learning_rate_schedule = "cosine" +num_procs = 1 +num_threads = 0 +augmentations.gaussian_strength = 0.1 +augmentations.mixup_strength = 0.1 diff --git a/src/mlia/resources/optimization_profiles/optimization-fully-connected-clustering.toml b/src/mlia/resources/optimization_profiles/optimization-fully-connected-clustering.toml new file mode 100644 index 0000000..c5d460b --- /dev/null +++ b/src/mlia/resources/optimization_profiles/optimization-fully-connected-clustering.toml @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 + +[rewrite.training_parameters] +batch_size = 32 +learning_rate = 1e-3 +show_progress = true +steps = 48000 +learning_rate_schedule = "cosine" +num_procs = 1 +num_threads = 0 +augmentations.gaussian_strength = 0.0 +augmentations.mixup_strength = 0.0 + +[rewrite.fully-connected-clustering] +num_clusters = 16 +cluster_centroids_init = "CentroidInitialization.LINEAR" diff --git a/src/mlia/resources/optimization_profiles/optimization-fully-connected-pruning.toml b/src/mlia/resources/optimization_profiles/optimization-fully-connected-pruning.toml new file mode 100644 index 0000000..f7f91ec --- /dev/null +++ b/src/mlia/resources/optimization_profiles/optimization-fully-connected-pruning.toml @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. 
+# SPDX-License-Identifier: Apache-2.0 + +[rewrite.training_parameters] +batch_size = 32 +learning_rate = 1e-3 +show_progress = true +steps = 48000 +learning_rate_schedule = "cosine" +num_procs = 1 +num_threads = 0 +augmentations.gaussian_strength = 0.0 +augmentations.mixup_strength = 0.0 + +[rewrite.fully-connected-sparsity] +sparsity_m = 2 +sparsity_n = 4 diff --git a/src/mlia/resources/optimization_profiles/optimization-fully-connected-unstructured-pruning.toml b/src/mlia/resources/optimization_profiles/optimization-fully-connected-unstructured-pruning.toml new file mode 100644 index 0000000..cd5f745 --- /dev/null +++ b/src/mlia/resources/optimization_profiles/optimization-fully-connected-unstructured-pruning.toml @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 + +[rewrite.training_parameters] +batch_size = 32 +learning_rate = 1e-3 +show_progress = true +steps = 48000 +learning_rate_schedule = "cosine" +num_procs = 1 +num_threads = 0 +augmentations.gaussian_strength = 0.0 +augmentations.mixup_strength = 0.0 + +[rewrite.fully-connected-unstructured-sparsity] +initial_sparsity = 0.25 +final_sparsity = 0.5 +end_step = 48000 diff --git a/src/mlia/resources/optimization_profiles/optimization.toml b/src/mlia/resources/optimization_profiles/optimization.toml index 623a763..6f2800e 100644 --- a/src/mlia/resources/optimization_profiles/optimization.toml +++ b/src/mlia/resources/optimization_profiles/optimization.toml @@ -1,11 +1,12 @@ # SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. 
# SPDX-License-Identifier: Apache-2.0 -[training] +[rewrite.training_parameters] batch_size = 32 learning_rate = 1e-3 show_progress = true steps = 48000 learning_rate_schedule = "cosine" +augmentations = "gaussian" num_procs = 1 num_threads = 0 diff --git a/src/mlia/target/common/optimization.py b/src/mlia/target/common/optimization.py index 1423189..69d3a24 100644 --- a/src/mlia/target/common/optimization.py +++ b/src/mlia/target/common/optimization.py @@ -17,6 +17,7 @@ from mlia.core.errors import FunctionalityNotSupportedError from mlia.core.performance import estimate_performance from mlia.core.performance import P from mlia.core.performance import PerformanceEstimator +from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS from mlia.nn.select import get_optimizer from mlia.nn.select import OptimizationSettings from mlia.nn.tensorflow.config import get_keras_model @@ -50,7 +51,7 @@ class OptimizingDataCollector(ContextAwareDataCollector): optimizations = self._get_optimization_settings(self.context) - training_parameters = self._get_training_settings(self.context) + rewrite_parameters = self._get_rewrite_settings(self.context) if not optimizations or optimizations == [[]]: raise FunctionalityNotSupportedError( @@ -77,7 +78,7 @@ class OptimizingDataCollector(ContextAwareDataCollector): model = self.model # type: ignore optimizers: list[Callable] = [ - partial(self.optimize_model, opts, training_parameters) + partial(self.optimize_model, opts, rewrite_parameters) for opts in opt_settings ] @@ -86,12 +87,12 @@ class OptimizingDataCollector(ContextAwareDataCollector): def optimize_model( self, opt_settings: list[OptimizationSettings], - training_parameters: dict | None, + rewrite_parameters: dict, model: KerasModel | TFLiteModel, ) -> Any: """Run optimization.""" optimizer = get_optimizer( - model, opt_settings, training_parameters=training_parameters + model, opt_settings, rewrite_parameters=rewrite_parameters ) opts_as_str = ", ".join(str(opt) for opt in 
opt_settings) logger.info("Applying model optimizations - [%s]", opts_as_str) @@ -123,11 +124,11 @@ class OptimizingDataCollector(ContextAwareDataCollector): context=context, ) - def _get_training_settings(self, context: Context) -> dict: + def _get_rewrite_settings(self, context: Context) -> list[dict]: """Get optimization settings.""" return self.get_parameter( # type: ignore OptimizingDataCollector.name(), - "training_parameters", + "rewrite_parameters", expected_type=dict, expected=False, context=context, @@ -218,7 +219,53 @@ _DEFAULT_OPTIMIZATION_TARGETS = [ ] -def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) -> None: +def parse_augmentations( + augmentations: dict | str | None, +) -> tuple[float | None, float | None]: + """Parse augmentations from optimization-profile and return a valid tuple.""" + if isinstance(augmentations, str): + match_augmentation = AUGMENTATION_PRESETS.get(augmentations) + if not match_augmentation: + match_augmentation = AUGMENTATION_PRESETS["none"] + return match_augmentation + if isinstance(augmentations, dict): + augmentation_keys_test_for_valid = list(augmentations.keys()) + augmentation_keys_test_for_float = list(augmentations.keys()) + valid_keys = ["mixup_strength", "gaussian_strength"] + tuple_to_return = [] + for valid_key in valid_keys.copy(): + if augmentations.get(valid_key) is not None: + del augmentation_keys_test_for_valid[ + augmentation_keys_test_for_valid.index(valid_key) + ] + if isinstance(augmentations.get(valid_key), float): + tuple_to_return.append(augmentations[valid_key]) + del augmentation_keys_test_for_float[ + augmentation_keys_test_for_float.index(valid_key) + ] + else: + tuple_to_return.append(None) + else: + tuple_to_return.append(None) + if len(augmentation_keys_test_for_valid) > 0: + logger.warning( + "Warning! Expected augmentation parameters to be 'gaussian_strength' " + "and/or 'mixup_strength' got %s. 
" + "Removing invalid augmentations", + str(list(augmentations.keys())), + ) + elif len(augmentation_keys_test_for_float) > 0: + logger.warning( + "Warning! Not all augmentation parameters were floats, " + "removing non-float augmentations" + ) + return (tuple_to_return[0], tuple_to_return[1]) + return AUGMENTATION_PRESETS["none"] + + +def add_common_optimization_params( # pylint: disable=too-many-branches + advisor_parameters: dict, extra_args: dict +) -> None: """Add common optimization parameters.""" optimization_targets = extra_args.get("optimization_targets") if not optimization_targets: @@ -227,18 +274,32 @@ def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) - if not is_list_of(optimization_targets, dict): raise TypeError("Optimization targets value has wrong format.") - rewrite_parameters = extra_args.get("optimization_profile") training_parameters = None - if rewrite_parameters: - if not isinstance(rewrite_parameters, dict): - raise TypeError("Training Parameter values has wrong format.") - training_parameters = extra_args["optimization_profile"].get("training") + rewrite_specific_parameters = None + + optimization_parameters = extra_args.get("optimization_profile") + if optimization_parameters: # pylint: disable=too-many-nested-blocks + if not isinstance(optimization_parameters, dict): + raise TypeError("Optimization Parameter values has wrong format.") + + if optimization_parameters.get("rewrite"): + rewrite_params = optimization_parameters["rewrite"] + training_parameters = rewrite_params.get("training_parameters") + if training_parameters: + training_parameters["augmentations"] = parse_augmentations( + training_parameters.get("augmentations") + ) + optimization_target = optimization_targets[0]["optimization_target"] + rewrite_specific_parameters = rewrite_params.get(optimization_target) advisor_parameters.update( { "common_optimizations": { "optimizations": [optimization_targets], - "training_parameters": training_parameters, 
+ "rewrite_parameters": { + "train_params": training_parameters, + "rewrite_specific_params": rewrite_specific_parameters, + }, }, } ) diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py index 8492086..8a5b360 100644 --- a/src/mlia/target/config.py +++ b/src/mlia/target/config.py @@ -71,7 +71,16 @@ def is_builtin_target_profile(profile_name: str | Path) -> bool: return profile_name in BUILTIN_SUPPORTED_PROFILE_NAMES -BUILTIN_SUPPORTED_OPTIMIZATION_NAMES = ["optimization"] +BUILTIN_SUPPORTED_OPTIMIZATION_NAMES = [ + "optimization", + "optimization-custom-augmentation", + "optimization-fully-connected-clustering", + "optimization-fully-connected-pruning", + "optimization-fully-connected-unstructured-pruning", + "optimization-conv2d-clustering", + "optimization-conv2d-pruning", + "optimization-conv2d-unstructured-pruning", +] def is_builtin_optimization_profile(optimization_name: str | Path) -> bool: diff --git a/tests/conftest.py b/tests/conftest.py index 3d0b832..a64f320 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -126,9 +126,28 @@ def get_test_keras_model() -> keras.Model: return model +def get_test_keras_model_no_activation() -> keras.Model: + """Return test Keras model.""" + model = keras.Sequential( + [ + keras.Input(shape=(28, 28, 1), batch_size=1, name="input"), + keras.layers.Reshape((28, 28, 1)), + keras.layers.Conv2D(filters=12, kernel_size=(3, 3), name="conv1"), + keras.layers.Conv2D(filters=12, kernel_size=(3, 3), name="conv2"), + keras.layers.MaxPool2D(2, 2), + keras.layers.Flatten(), + keras.layers.Dense(10, name="output"), + ] + ) + + model.compile(optimizer="sgd", loss="mean_squared_error") + return model + + TEST_MODEL_KERAS_FILE = "test_model.h5" TEST_MODEL_TFLITE_FP32_FILE = "test_model_fp32.tflite" TEST_MODEL_TFLITE_INT8_FILE = "test_model_int8.tflite" +TEST_MODEL_TFLITE_NO_ACT_FILE = "test_model_no_act.tflite" TEST_MODEL_TFLITE_VELA_FILE = "test_model_vela.tflite" TEST_MODEL_TF_SAVED_MODEL_FILE = "tf_model_test_model" 
TEST_MODEL_INVALID_FILE = "invalid.tflite" @@ -153,6 +172,13 @@ def fixture_test_models_path( keras_model, quantized=False, output_path=tmp_path / TEST_MODEL_TFLITE_FP32_FILE ) + # Un-quantized TensorFlow Lite model with ReLU activation (fp32) + convert_to_tflite( + get_test_keras_model_no_activation(), + quantized=False, + output_path=tmp_path / TEST_MODEL_TFLITE_NO_ACT_FILE, + ) + # Quantized TensorFlow Lite model (int8) tflite_model_path = tmp_path / TEST_MODEL_TFLITE_INT8_FILE convert_to_tflite(keras_model, quantized=True, output_path=tflite_model_path) @@ -195,6 +221,12 @@ def fixture_test_tflite_vela_model(test_models_path: Path) -> Path: return test_models_path / TEST_MODEL_TFLITE_VELA_FILE +@pytest.fixture(scope="session", name="test_tflite_no_act_model") +def fixture_test_tflite_no_act_model(test_models_path: Path) -> Path: + """Return test TensorFlow Lite model with relu activation.""" + return test_models_path / TEST_MODEL_TFLITE_NO_ACT_FILE + + @pytest.fixture(scope="session", name="test_tf_model") def fixture_test_tf_model(test_models_path: Path) -> Path: """Return test TensorFlow Lite model.""" @@ -257,17 +289,17 @@ def fixture_test_tfrecord_fp32( yield from create_tfrecord(tmp_path_factory, random_data) -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="function", autouse=True) def set_training_steps( request: _pytest.fixtures.SubRequest, ) -> Generator[None, None, None]: """Speed up tests by using MockTrainingParameters.""" - if "set_training_steps" == request.fixturename: - yield - else: + if "skip_set_training_steps" not in request.keywords: with pytest.MonkeyPatch.context() as monkeypatch: monkeypatch.setattr( "mlia.nn.select._get_rewrite_params", - MagicMock(return_value=[MockTrainingParameters(), None, None]), + MagicMock(return_value=MockTrainingParameters()), ) yield + else: + yield diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py index 93a05bd..5a91cd7 100644 --- a/tests/test_cli_commands.py +++ 
b/tests/test_cli_commands.py @@ -90,7 +90,7 @@ def test_performance_unknown_target( None, None, True, - "fully-connected-sparsity24", + "fully-connected-sparsity", "sequential/flatten/Reshape", "StatefulPartitionedCall:0", does_not_raise(), @@ -139,8 +139,10 @@ def test_performance_unknown_target( Exception, match=re.escape( "Invalid rewrite target: 'random'. " - "Supported rewrites: ['fully-connected'," - " 'fully-connected-clustering', 'fully-connected-sparsity24']" + "Supported rewrites: ['conv2d-clustering', 'conv2d-sparsity', " + "'conv2d-unstructured-sparsity', 'fully-connected', " + "'fully-connected-clustering', 'fully-connected-sparsity', " + "'fully-connected-unstructured-sparsity']" ), ), ], @@ -195,6 +197,58 @@ def test_performance_unknown_target( "StatefulPartitionedCall:0", does_not_raise(), ], + [ + "ethos-u55-256", + False, + False, + None, + None, + None, + True, + "fully-connected-unstructured-sparsity", + "sequential/flatten/Reshape", + "StatefulPartitionedCall:0", + does_not_raise(), + ], + [ + "ethos-u55-256", + False, + False, + None, + None, + None, + True, + "conv2d-sparsity", + "sequential/conv1/Relu;sequential/conv1/Conv2D", + "sequential/conv2/Relu;sequential/conv2/Conv2D", + does_not_raise(), + ], + [ + "ethos-u55-256", + False, + False, + None, + None, + None, + True, + "conv2d-unstructured-sparsity", + "sequential/conv1/Relu;sequential/conv1/Conv2D", + "sequential/conv2/Relu;sequential/conv2/Conv2D", + does_not_raise(), + ], + [ + "ethos-u55-256", + False, + False, + None, + None, + None, + True, + "conv2d-clustering", + "sequential/conv1/Relu;sequential/conv1/Conv2D", + "sequential/conv2/Relu;sequential/conv2/Conv2D", + does_not_raise(), + ], ], ) def test_opt_valid_optimization_target( # pylint: disable=too-many-locals,too-many-arguments diff --git a/tests/test_cli_helpers.py b/tests/test_cli_helpers.py index 0e9f0d6..69e6ffe 100644 --- a/tests/test_cli_helpers.py +++ b/tests/test_cli_helpers.py @@ -156,7 +156,7 @@ def 
test_copy_profile_file_to_output_dir(tmp_path: Path) -> None: def test_copy_optimization_file_to_output_dir(tmp_path: Path) -> None: - """Test if the optimization profile file is copied into the output directory.""" + """Test if the profile file is copied into the output directory.""" test_target_profile_name = "optimization" test_file_path = Path(f"{tmp_path}/{test_target_profile_name}.toml") diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py index 341e0d2..bdcf034 100644 --- a/tests/test_common_optimization.py +++ b/tests/test_common_optimization.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the common optimization module.""" +from __future__ import annotations + from contextlib import ExitStack as does_not_raises from pathlib import Path from typing import Any @@ -15,6 +17,7 @@ from mlia.nn.tensorflow.config import TFLiteModel from mlia.target.common.optimization import _DEFAULT_OPTIMIZATION_TARGETS from mlia.target.common.optimization import add_common_optimization_params from mlia.target.common.optimization import OptimizingDataCollector +from mlia.target.common.optimization import parse_augmentations from mlia.target.config import load_profile from mlia.target.config import TargetProfile @@ -57,7 +60,10 @@ def test_optimizing_data_collector( config_parameters={ "common_optimizations": { "optimizations": optimizations, - "training_parameters": training_parameters, + "rewrite_parameters": { + "train_params": training_parameters, + "rewrite_specific_params": None, + }, } } ) @@ -94,12 +100,15 @@ def test_optimizing_data_collector( collector.set_context(context) collector.collect_data() assert optimize_model_mock.call_args.args[0] == opt_settings[0] - assert optimize_model_mock.call_args.args[1] == training_parameters + assert optimize_model_mock.call_args.args[1] == { + "train_params": training_parameters, + 
"rewrite_specific_params": None, + } assert fake_optimizer.invocation_count == 1 @pytest.mark.parametrize( - "extra_args, error_to_raise", + "extra_args, error_to_raise, rewrite_parameter_type", [ ( { @@ -112,14 +121,39 @@ def test_optimizing_data_collector( ], }, does_not_raises(), + type(None), ), ( { + "optimization_targets": [ + { + "optimization_type": "rewrite", + "optimization_target": "fully-connected-clustering", + } + ], "optimization_profile": load_profile( - "src/mlia/resources/optimization_profiles/optimization.toml" - ) + "src/mlia/resources/optimization_profiles/" + "optimization-fully-connected-clustering.toml" + ), }, does_not_raises(), + dict, + ), + ( + { + "optimization_targets": [ + { + "optimization_type": "rewrite", + "optimization_target": "fully-connected-sparsity", + } + ], + "optimization_profile": load_profile( + "src/mlia/resources/optimization_profiles/" + "optimization-fully-connected-pruning.toml" + ), + }, + does_not_raises(), + dict, ), ( { @@ -132,16 +166,22 @@ def test_optimizing_data_collector( pytest.raises( TypeError, match="Optimization targets value has wrong format." ), + type(None), ), ( {"optimization_profile": [32, 1e-3, True, 48000, "cosine", 1, 0]}, pytest.raises( - TypeError, match="Training Parameter values has wrong format." + TypeError, match="Optimization Parameter values has wrong format." 
), + type(None), ), ], ) -def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) -> None: +def test_add_common_optimization_params( + extra_args: dict, + error_to_raise: Any, + rewrite_parameter_type: dict | None, +) -> None: """Test to check that optimization_targets and optimization_profiles are correctly parsed.""" advisor_parameters: dict = {} @@ -158,12 +198,93 @@ def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) - ] if not extra_args.get("optimization_profile"): - assert ( - advisor_parameters["common_optimizations"]["training_parameters"] - is None - ) + assert advisor_parameters["common_optimizations"]["rewrite_parameters"] == { + "train_params": None, + "rewrite_specific_params": None, + } else: - assert ( - advisor_parameters["common_optimizations"]["training_parameters"] - == extra_args["optimization_profile"]["training"] + if not extra_args["optimization_profile"].get("rewrite"): + assert isinstance( + advisor_parameters["common_optimizations"]["rewrite_parameters"][ + "train_params" + ], + type(None), + ) + elif not extra_args["optimization_profile"]["rewrite"].get( + "training_parameters" + ): + assert isinstance( + advisor_parameters["common_optimizations"]["rewrite_parameters"][ + "train_params" + ], + type(None), + ) + else: + assert isinstance( + advisor_parameters["common_optimizations"]["rewrite_parameters"][ + "train_params" + ], + dict, + ) + + assert isinstance( + advisor_parameters["common_optimizations"]["rewrite_parameters"][ + "rewrite_specific_params" + ], + rewrite_parameter_type, # type: ignore ) + + +@pytest.mark.parametrize( + "augmentations, expected_output", + [ + ( + {"gaussian_strength": 1.0, "mixup_strength": 1.0}, + (1.0, 1.0), + ), + ( + {"gaussian_strength": 1.0}, + (None, 1.0), + ), + ( + {"Wrong param": 1.0, "mixup_strength": 1.0}, + (1.0, None), + ), + ( + {"Wrong param1": 1.0, "Wrong param2": 1.0}, + (None, None), + ), + ( + "gaussian", + (None, 1.0), + ), + ( + 
"mix_gaussian_large", + (2.0, 1.0), + ), + ( + "not in presets", + (None, None), + ), + ( + {"gaussian_strength": 1.0, "mixup_strength": 1.0, "mix2": 1.0}, + (1.0, 1.0), + ), + ( + {"gaussian_strength": "not a float", "mixup_strength": 1.0}, + (1.0, None), + ), + ( + None, + (None, None), + ), + ], +) +def test_parse_augmentations( + augmentations: dict | str | None, expected_output: tuple +) -> None: + """Check that augmentation parameters in optimization_profiles are + correctly parsed.""" + + augmentation_output = parse_augmentations(augmentations) + assert augmentation_output == expected_output diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py index e502842..9e3287e 100644 --- a/tests/test_nn_rewrite_core_rewrite.py +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -3,18 +3,23 @@ """Tests for module mlia.nn.rewrite.core.rewrite.""" from __future__ import annotations +import re from contextlib import ExitStack as does_not_raise from pathlib import Path from typing import Any from typing import cast from unittest.mock import MagicMock +import numpy as np import pytest import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from tensorflow_model_optimization.python.core.clustering.keras.cluster_wrapper import ( # pylint: disable=no-name-in-module ClusterWeights, ) +from tensorflow_model_optimization.python.core.sparsity.keras.pruning_wrapper import ( # pylint: disable=no-name-in-module + PruneLowMagnitude, +) from mlia.nn.rewrite.core.rewrite import ClusteringRewrite from mlia.nn.rewrite.core.rewrite import GenericRewrite @@ -23,40 +28,19 @@ from mlia.nn.rewrite.core.rewrite import RewriteCallable from mlia.nn.rewrite.core.rewrite import RewriteConfiguration from mlia.nn.rewrite.core.rewrite import RewriteRegistry from mlia.nn.rewrite.core.rewrite import RewritingOptimizer -from mlia.nn.rewrite.core.rewrite import Sparsity24Rewrite +from 
mlia.nn.rewrite.core.rewrite import StructuredSparsityRewrite from mlia.nn.rewrite.core.rewrite import TrainingParameters +from mlia.nn.rewrite.core.rewrite import UnstructuredSparsityRewrite from mlia.nn.rewrite.core.train import train_in_dir -from mlia.nn.rewrite.library.fc_clustering_layer import ( - get_keras_model_clus as fc_clustering_rewrite, -) +from mlia.nn.rewrite.library.clustering import fc_clustering_rewrite +from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_rewrite +from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_unstructured_rewrite +from mlia.nn.rewrite.library.sparsity import fc_sparsity_rewrite +from mlia.nn.rewrite.library.sparsity import fc_sparsity_unstructured_rewrite from mlia.nn.tensorflow.config import TFLiteModel from tests.utils.rewrite import MockTrainingParameters -class TestRewrite(Rewrite): - """Test rewrite class.""" - - def quantize(self, model: keras.Model) -> keras.Model: - """Return a quantized model if required.""" - return tfmot.quantization.keras.quantize_model(model) - - def preserved_quantize(self, model: keras.Model) -> keras.Model: - """Not needed.""" - return model - - def training_callbacks(self) -> list: - """Return default rewrite callbacks.""" - return [] - - def post_process(self, model: keras.Model) -> keras.Model: - """Return default post-processing rewrite options.""" - return model - - def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool: - """Not needed here.""" - return True - - def mock_rewrite_function(*_: Any) -> Any: """Mock function to test autoloading of rewrite functions.""" @@ -67,10 +51,10 @@ def test_rewrite() -> None: def bad_rewrite_func() -> Any: raise NotImplementedError() - rewrite = TestRewrite( + rewrite = GenericRewrite( "BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func) ) - with pytest.raises(RuntimeError): + with pytest.raises(KeyError): rewrite((1, 2), (1, 2)) @@ -79,7 +63,9 @@ def test_rewrite() -> None: [ ("fully-connected", 0, 
GenericRewrite), ("fully-connected-clustering", 0, ClusteringRewrite), - ("fully-connected-sparsity24", 1, Sparsity24Rewrite), + ("fully-connected-sparsity", 1, StructuredSparsityRewrite), + ("conv2d-clustering", 0, ClusteringRewrite), + ("conv2d-sparsity", 1, StructuredSparsityRewrite), ], ) def test_rewrite_selection( @@ -96,8 +82,10 @@ def test_rewrite_selection( "rewrite_name, expected_error", [ ("fully-connected", does_not_raise()), - ("fully-connected-sparsity24", does_not_raise()), + ("fully-connected-sparsity", does_not_raise()), ("fully-connected-clustering", does_not_raise()), + ("conv2d-clustering", does_not_raise()), + ("conv2d-sparsity", does_not_raise()), ("random", does_not_raise()), ], ) @@ -105,7 +93,8 @@ def test_rewrite_configuration( test_tflite_model_fp32: Path, rewrite_name: str, expected_error: Any ) -> None: """Test get_rewrite function only supports rewrite type fully-connected, - fully-connected-clustering and fully-connected-sparsity24.""" + fully-connected-clustering, fully-connected-sparsity, conv2d-clustering + and conv2d-sparsity.""" with expected_error: config_obj = RewriteConfiguration( rewrite_name, @@ -120,29 +109,195 @@ def test_rewrite_configuration( assert isinstance(rewriter_obj, RewritingOptimizer) +def train_rewrite_model( + input_shape: tuple | np.ndarray, + output_shape: int | np.ndarray, + rewrite_model: keras.Model, + epochs: int = 1, +) -> keras.Model: + """Helper function to quickly train a rewrite model.""" + rewrite_model.compile( + optimizer=keras.optimizers.Nadam(learning_rate=0.01), + loss=keras.losses.MeanSquaredError(), + metrics=["mae"], + ) + if isinstance(output_shape, int): + output_shape_list = [output_shape] + else: + output_shape_list = output_shape.tolist() + rewrite_model.fit( + x=np.random.rand(16, *input_shape), + y=np.random.rand(16, *output_shape_list), + batch_size=1, + epochs=epochs, + callbacks=[tfmot.sparsity.keras.UpdatePruningStep()], + ) + return rewrite_model + + def 
test_rewrite_fully_connected_clustering() -> None: - """Check that model has the set number of clusters""" + """Check that fully connected clustering rewrite model + has the set number of clusters.""" + + rewrite = ClusteringRewrite( + "fully-connected-clustering", + fc_clustering_rewrite, + ) + + model = rewrite(input_shape=(28, 28), output_shape=10, num_clusters=2) + model = rewrite.post_process(model) + assert rewrite.check_optimization( + model, + num_clusters=2, + ) + + +def test_rewrite_fully_connected_sparsity(caplog: pytest.LogCaptureFixture) -> None: + """ + Check that sparse fully connected + rewrite model is correctly sparse. + """ + + rewrite = StructuredSparsityRewrite("fully-connected-sparsity", fc_sparsity_rewrite) + input_shape = (28, 28) + output_shape = 10 + model = rewrite( + input_shape=tuple(input_shape), + output_shape=output_shape, + sparsity_m=2, + sparsity_n=4, + ) + model = rewrite.post_process(model) + assert not rewrite.check_optimization(model) + log_records = caplog.records + warning_messages = [x.message for x in log_records if x.levelno == 30] + assert ( + re.search( + r"\nWARNING: Could not find \(2, 4\) sparsity, in " + r"layer dense_?\d? 
for weight dense_?\d?\/kernel:0 \n", + warning_messages[0], + ) + is not None + ) + model = rewrite( + input_shape=input_shape, output_shape=output_shape, sparsity_m=2, sparsity_n=4 + ) + train_rewrite_model( + input_shape=input_shape, output_shape=output_shape, rewrite_model=model + ) + model = rewrite.post_process(model) + assert rewrite.check_optimization(model) + + +def test_rewrite_conv2d_sparsity(caplog: pytest.LogCaptureFixture) -> None: + """Check that sparse conv2d rewrite model is correctly sparse.""" + + rewrite = StructuredSparsityRewrite("conv2d-sparsity", conv2d_sparsity_rewrite) + input_shape = np.array([28, 28, 3]) + output_shape = np.array([14, 14, 3]) + model = rewrite( + input_shape=input_shape, output_shape=output_shape, sparsity_m=2, sparsity_n=4 + ) + model = rewrite.post_process(model) + assert not rewrite.check_optimization(model) + log_records = caplog.records + warning_messages = [x.message for x in log_records if x.levelno == 30] + assert ( + re.search( + r"\nWARNING: Could not find \(2, 4\) sparsity, in " + r"layer conv2d_?\d? 
for weight conv2d_?\d?\/kernel:0 \n", + warning_messages[0], + ) + is not None + ) + model = rewrite( + input_shape=input_shape, output_shape=output_shape, sparsity_m=2, sparsity_n=4 + ) + train_rewrite_model( + input_shape=input_shape, output_shape=output_shape, rewrite_model=model + ) + model = rewrite.post_process(model) + assert rewrite.check_optimization(model) + - rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite) - model = rewrite(input_shape=(28, 28), output_shape=10) +def test_rewrite_conv2d_unstructured_sparsity(caplog: pytest.LogCaptureFixture) -> None: + """Check that an unstructured sparse conv2d rewrite is correctly sparse.""" + + rewrite = UnstructuredSparsityRewrite( + "conv2d-unstructured-sparsity", conv2d_sparsity_unstructured_rewrite + ) + input_shape = np.array([28, 28, 3]) + output_shape = np.array([14, 14, 3]) + model = rewrite( + input_shape=input_shape, output_shape=output_shape, final_sparsity=0.50 + ) + model = rewrite.post_process(model) + assert not rewrite.check_optimization(model) + log_records = caplog.records + warning_messages = [x.message for x in log_records if x.levelno == 30] + assert ( + re.search( + r"\nWARNING: Found total sparsity of rewrite model: \d.\d\d " + r"expected total sparsity to be: 0.50\n", + warning_messages[0], + ) + is not None + ) + model = rewrite( + input_shape=input_shape, + output_shape=output_shape, + final_sparsity=0.5, + end_step=120, + ) + train_rewrite_model( + input_shape=input_shape, + output_shape=output_shape, + rewrite_model=model, + epochs=10, + ) model = rewrite.post_process(model) - assert rewrite.check_optimization(model, number_of_clusters=32) + assert rewrite.check_optimization(model) -def test_rewrite_fully_connected_clustering_error_handling() -> None: - """Check that model has the set number of clusters - and that when quantized the number of clusters - remain.""" +def test_rewrite_fully_connected_unstructured_sparsity( + caplog: pytest.LogCaptureFixture, +) 
-> None: + """Check that an unstructured sparse FC rewrite is correctly sparse.""" - rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite) - model = rewrite(input_shape=(28, 28), output_shape=10) - with pytest.raises( - ValueError, - match=( - r"Expected check_preserved_quantize to have argument number_of_clusters" - ), - ): - rewrite.check_optimization(model, bad_arg_name=25) + rewrite = UnstructuredSparsityRewrite( + "fully-connected-unstructured-sparsity", fc_sparsity_unstructured_rewrite + ) + input_shape = (28, 28) + output_shape = 10 + model = rewrite( + input_shape=tuple(input_shape), output_shape=output_shape, final_sparsity=0.5 + ) + model = rewrite.post_process(model) + assert not rewrite.check_optimization(model) + log_records = caplog.records + warning_messages = [x.message for x in log_records if x.levelno == 30] + assert ( + re.search( + r"\nWARNING: Found total sparsity of rewrite model: \d.\d\d " + r"expected total sparsity to be: 0.50\n", + warning_messages[0], + ) + is not None + ) + model = rewrite( + input_shape=input_shape, + output_shape=output_shape, + final_sparsity=0.5, + end_step=120, + ) + train_rewrite_model( + input_shape=input_shape, + output_shape=output_shape, + rewrite_model=model, + epochs=10, + ) + model = rewrite.post_process(model) + assert rewrite.check_optimization(model) @pytest.mark.parametrize( @@ -151,6 +306,40 @@ def test_rewrite_fully_connected_clustering_error_handling() -> None: ["fully-connected", [keras.layers.Reshape, keras.layers.Dense], False], ["fully-connected-clustering", [ClusterWeights, ClusterWeights], False], ["fully-connected-clustering", [ClusterWeights, ClusterWeights], True], + ["fully-connected-sparsity", [PruneLowMagnitude, PruneLowMagnitude], False], + ["fully-connected-sparsity", [PruneLowMagnitude, PruneLowMagnitude], True], + ["conv2d-clustering", [ClusterWeights, ClusterWeights, ClusterWeights], False], + ["conv2d-clustering", [ClusterWeights, ClusterWeights, 
ClusterWeights], True], + [ + "conv2d-sparsity", + [PruneLowMagnitude, PruneLowMagnitude, PruneLowMagnitude], + False, + ], + [ + "conv2d-sparsity", + [PruneLowMagnitude, PruneLowMagnitude, PruneLowMagnitude], + True, + ], + [ + "fully-connected-unstructured-sparsity", + [PruneLowMagnitude, PruneLowMagnitude], + False, + ], + [ + "fully-connected-unstructured-sparsity", + [PruneLowMagnitude, PruneLowMagnitude], + True, + ], + [ + "conv2d-unstructured-sparsity", + [PruneLowMagnitude, PruneLowMagnitude, PruneLowMagnitude], + False, + ], + [ + "conv2d-unstructured-sparsity", + [PruneLowMagnitude, PruneLowMagnitude, PruneLowMagnitude], + True, + ], ], ) def test_rewriting_optimizer( # pylint: disable=too-many-locals @@ -162,24 +351,32 @@ def test_rewriting_optimizer( # pylint: disable=too-many-locals expected_layers: list[object], quant: bool, ) -> None: - """Test fc_layer rewrite process with rewrite type fully-connected.""" + """Test the rewrite process with all rewrite types.""" tfrecord = test_tfrecord if quant else test_tfrecord_fp32 tflite_model = test_tflite_model if quant else test_tflite_model_fp32 + rewrite_function = RewritingOptimizer.registry.items[rewrite_type] config_obj = RewriteConfiguration( rewrite_type, - ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"], + ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"] + if "fully-connected" in rewrite_type + else [ + "sequential/conv1/Relu;sequential/conv1/Conv2D", + "sequential/conv2/Relu;sequential/conv2/Conv2D", + ], tfrecord, train_params=MockTrainingParameters(), ) - test_obj = RewritingOptimizer(tflite_model, config_obj) - rewrite_function = RewritingOptimizer.registry.items[ - test_obj.optimizer_configuration.optimization_target - ] # Input, output shape does not matter, just need the test the layers are as expected - rewrite_model = rewrite_function(input_shape=(28, 28, 1), output_shape=12) + rewrite_model = ( + rewrite_function(input_shape=(28, 28, 1), output_shape=12) + if 
"fully-connected" in rewrite_type + else rewrite_function( + input_shape=np.array([28, 28, 3]), output_shape=np.array([14, 14, 3]) + ) + ) for idx, layer in enumerate(rewrite_model.layers): assert isinstance(layer, expected_layers[idx]) # type: ignore @@ -197,8 +394,14 @@ def test_register_rewrite_function() -> None: """Test adding rewrite functions and verify they are reported via the registry.""" registry = RewriteRegistry() - rewrite1 = TestRewrite("r1", cast(RewriteCallable, lambda: 1)) - rewrite2 = TestRewrite("r2", cast(RewriteCallable, lambda: 2)) + rewrite1 = GenericRewrite( + "r1", + cast(RewriteCallable, lambda: 1), + ) + rewrite2 = GenericRewrite( + "r2", + cast(RewriteCallable, lambda: 2), + ) registry.register_rewrite(rewrite1) registry.register_rewrite(rewrite2) @@ -207,11 +410,15 @@ def test_register_rewrite_function() -> None: def test_builtin_rewrite_names() -> None: """Test if all builtin rewrites are properly registered and returned.""" - assert RewritingOptimizer.builtin_rewrite_names() == [ + assert set(RewritingOptimizer.builtin_rewrite_names()) == { + "conv2d-clustering", + "conv2d-sparsity", + "conv2d-unstructured-sparsity", "fully-connected", "fully-connected-clustering", - "fully-connected-sparsity24", - ] + "fully-connected-sparsity", + "fully-connected-unstructured-sparsity", + } def test_rewrite_configuration_train_params( diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py index 94c99ff..03b230f 100644 --- a/tests/test_nn_rewrite_core_train.py +++ b/tests/test_nn_rewrite_core_train.py @@ -16,11 +16,12 @@ from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from mlia.nn.rewrite.core.train import augment_fn_twins from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS +from mlia.nn.rewrite.core.train import detect_activation_from_rewrite_function from mlia.nn.rewrite.core.train import LearningRateSchedule from mlia.nn.rewrite.core.train import mixup from 
mlia.nn.rewrite.core.train import train from mlia.nn.rewrite.core.train import TrainingParameters -from tests.test_nn_rewrite_core_rewrite import TestRewrite +from tests.test_nn_rewrite_core_rewrite import GenericRewrite from tests.utils.rewrite import MockTrainingParameters @@ -54,7 +55,7 @@ def check_train( """Test the train() function.""" with TemporaryDirectory() as tmp_dir: output_file = Path(tmp_dir, "out.tflite") - mock_rewrite = TestRewrite("replace", replace_fully_connected_with_conv) + mock_rewrite = GenericRewrite("replace", replace_fully_connected_with_conv) result = train( source_model=str(tflite_model), unmodified_model=str(tflite_model) if use_unmodified_model else None, @@ -65,6 +66,7 @@ def check_train( input_tensors=["sequential/flatten/Reshape"], output_tensors=["StatefulPartitionedCall:0"], train_params=train_params, + rewrite_specific_params={}, ) assert len(result[0][0]) == 2 @@ -249,3 +251,41 @@ def test_train_checkpoint( use_unmodified_model=False, quantized=True, ) + + +def test_detect_activation_from_rewrite_function_no_activation( + caplog: pytest.LogCaptureFixture, test_tflite_no_act_model: Path +) -> None: + """ + Test function detect_activation_from_rewrite_function() + with a model with no activation functions. + """ + caplog.set_level(level=20) + activation = detect_activation_from_rewrite_function( + test_tflite_no_act_model.as_posix() + ) + log_records = caplog.get_records(when="call") + logging_messages = [x.message for x in log_records if x.levelno == 20] + assert activation == "relu" + assert ( + "No activation function specified, setting activation function to ReLU" + in logging_messages + ) + + +def test_detect_activation_from_rewrite_function_relu_activation( + caplog: pytest.LogCaptureFixture, test_tflite_model: Path +) -> None: + """ + Test function detect_activation_from_rewrite_function() + with a model with ReLU activation functions. 
+ """ + caplog.set_level(level=20) + activation = detect_activation_from_rewrite_function(test_tflite_model.as_posix()) + log_records = caplog.get_records(when="call") + logging_messages = [x.message for x in log_records if x.levelno == 20] + assert activation == "relu" + assert ( + "No activation function specified, setting activation function " + "to most common activation detected in rewrite graph: relu" in logging_messages + ) diff --git a/tests/test_nn_rewrite_library_helper_functions.py b/tests/test_nn_rewrite_library_helper_functions.py new file mode 100644 index 0000000..a0dd7b9 --- /dev/null +++ b/tests/test_nn_rewrite_library_helper_functions.py @@ -0,0 +1,103 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module mlia.nn.rewrite.library.helper_functions.""" +from __future__ import annotations + +from contextlib import ExitStack as does_not_raise +from typing import Any + +import numpy as np +import pytest +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 + +from mlia.nn.rewrite.library.helper_functions import ACTIVATION_FUNCTION_LIST +from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters +from mlia.nn.rewrite.library.helper_functions import get_activation_function + + +def compute_conv_output( + input_data: np.ndarray, input_shape: np.ndarray, conv_parameters: dict[str, Any] +) -> np.ndarray: + """Compute the output of a conv layer for testing.""" + test_model = keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Conv2D(**conv_parameters), + ] + ) + output = test_model(input_data) + return np.array(output.shape[1:]) + + +@pytest.mark.parametrize( + "input_shape, output_shape, kernel_size", + [ + (np.array([32, 32, 3]), np.array([16, 16, 3]), [3, 3]), + (np.array([32, 32, 3]), np.array([8, 8, 3]), [3, 3]), + (np.array([32, 32, 3]), np.array([8, 16, 3]), [3, 3]), + (np.array([25, 10, 
3]), np.array([13, 5, 3]), [3, 3]), + (np.array([25, 10, 3]), np.array([7, 5, 3]), [3, 3]), + (np.array([25, 10, 3]), np.array([6, 4, 3]), [3, 3]), + (np.array([25, 10, 3]), np.array([5, 5, 3]), [3, 3]), + (np.array([32, 32, 3]), np.array([16, 16, 3]), [1, 3]), + (np.array([32, 32, 3]), np.array([16, 16, 3]), [1, 1]), + (np.array([32, 32, 3]), np.array([16, 16, 3]), [5, 5]), + ], +) +def test_compute_conv2d_parameters( + input_shape: np.ndarray, output_shape: np.ndarray, kernel_size: list[int] +) -> None: + """Test to check compute_conv2d_parameters works as expected.""" + conv_parameters = compute_conv2d_parameters( + input_shape=input_shape, + output_shape=output_shape, + kernel_size_input=kernel_size, + ) + computed_output_shape = compute_conv_output( + np.random.rand(1, *input_shape), input_shape, conv_parameters + ) + assert np.equal(computed_output_shape, output_shape).all() + + +@pytest.mark.parametrize( + "activation, expected_function_type, expected_extra_args, expected_error", + [ + ("relu", keras.layers.ReLU, {}, does_not_raise()), + ("relu6", keras.layers.ReLU, {"max_value": 6}, does_not_raise()), + ("none", None, {}, does_not_raise()), + ( + "wrong_key", + keras.layers.Identity, + {}, + pytest.raises( + KeyError, + match=( + "Expected activation function to be " + rf"in \{ACTIVATION_FUNCTION_LIST}\, found wrong_key" + ), + ), + ), + ], +) +def test_get_activation_functions( + activation: str, + expected_function_type: type, + expected_extra_args: dict, + expected_error: Any, +) -> None: + """ + Check the get_activation_function returns + the expected layer and extra arguments. 
+ """ + with expected_error: + activation_function, activation_function_extra_args = get_activation_function( + activation + ) + if activation_function: + assert isinstance( + activation_function(**activation_function_extra_args), + expected_function_type, + ) + else: + assert activation_function is None + assert expected_extra_args == activation_function_extra_args diff --git a/tests/test_nn_select.py b/tests/test_nn_select.py index 4095076..08752bd 100644 --- a/tests/test_nn_select.py +++ b/tests/test_nn_select.py @@ -4,12 +4,12 @@ from __future__ import annotations from contextlib import ExitStack as does_not_raise -from dataclasses import asdict from pathlib import Path from typing import Any from typing import cast import pytest +import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from mlia.core.errors import ConfigurationError @@ -176,23 +176,50 @@ def test_get_optimizer( model = test_tflite_model else: model = keras.models.load_model(str(test_keras_model)) - optimizer = get_optimizer(model, config) + optimizer = get_optimizer( + model, config, {"train_params": None, "rewrite_specific_params": None} + ) assert isinstance(optimizer, expected_type) assert optimizer.optimization_config() == expected_config +# pylint: disable=line-too-long @pytest.mark.parametrize( - "rewrite_parameters", - [None, {"batch_size": 64, "learning_rate": 0.003}], + "rewrite_parameters, optimization_target", + [ + [ + {"train_params": None, "rewrite_specific_params": None}, + "fully-connected-clustering", + ], + [ + { + "train_params": None, + "rewrite_specific_params": { + "num_clusters": 5, + "cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization( + "CentroidInitialization.LINEAR" + ), + }, + }, + "fully-connected-clustering", + ], + [ + {"train_params": None, "rewrite_specific_params": None}, + "fully-connected", + ], + ], ) +# pylint: enable=line-too-long @pytest.mark.skip_set_training_steps def 
test_get_optimizer_training_parameters( - rewrite_parameters: dict | None, test_tflite_model: Path + rewrite_parameters: dict, + optimization_target: str, + test_tflite_model: Path, ) -> None: """Test function get_optimzer with various combinations of parameters.""" config = OptimizationSettings( optimization_type="rewrite", - optimization_target="fully-connected", # type: ignore + optimization_target=optimization_target, # type: ignore layers_to_optimize=None, dataset=None, ) @@ -200,18 +227,20 @@ def test_get_optimizer_training_parameters( RewritingOptimizer, get_optimizer(test_tflite_model, config, rewrite_parameters), ) + assert len(list(rewrite_parameters.items())) == 2 + if rewrite_parameters.get("rewrite_specific_params"): + assert isinstance( + rewrite_parameters["rewrite_specific_params"], + type(optimizer.optimizer_configuration.rewrite_specific_params), + ) + assert ( + optimizer.optimizer_configuration.rewrite_specific_params + == rewrite_parameters["rewrite_specific_params"] + ) assert isinstance( optimizer.optimizer_configuration.train_params, TrainingParameters ) - if not rewrite_parameters: - assert asdict(TrainingParameters()) == asdict( - optimizer.optimizer_configuration.train_params - ) - else: - assert asdict(TrainingParameters()) | rewrite_parameters == asdict( - optimizer.optimizer_configuration.train_params - ) @pytest.mark.parametrize( diff --git a/tests/test_target_cortex_a_advisor.py b/tests/test_target_cortex_a_advisor.py index 7bb57c3..2f06f54 100644 --- a/tests/test_target_cortex_a_advisor.py +++ b/tests/test_target_cortex_a_advisor.py @@ -8,6 +8,7 @@ import pytest from mlia.core.common import AdviceCategory from mlia.core.context import ExecutionContext from mlia.core.workflow import DefaultWorkflowExecutor +from mlia.target.common.optimization import _DEFAULT_OPTIMIZATION_TARGETS from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor from mlia.target.cortex_a.advisor import CortexAInferenceAdvisor @@ -33,21 +34,11 
@@ def test_configure_and_get_cortex_a_advisor(test_tflite_model: Path) -> None: "target_profile": "cortex-a", }, "common_optimizations": { - "optimizations": [ - [ - { - "layers_to_optimize": None, - "optimization_target": 0.5, - "optimization_type": "pruning", - }, - { - "layers_to_optimize": None, - "optimization_target": 32, - "optimization_type": "clustering", - }, - ] - ], - "training_parameters": None, + "optimizations": [_DEFAULT_OPTIMIZATION_TARGETS], + "rewrite_parameters": { + "train_params": None, + "rewrite_specific_params": None, + }, }, } diff --git a/tests/test_target_tosa_advisor.py b/tests/test_target_tosa_advisor.py index 020acc5..d0b42b9 100644 --- a/tests/test_target_tosa_advisor.py +++ b/tests/test_target_tosa_advisor.py @@ -9,6 +9,7 @@ import pytest from mlia.core.common import AdviceCategory from mlia.core.context import ExecutionContext from mlia.core.workflow import DefaultWorkflowExecutor +from mlia.target.common.optimization import _DEFAULT_OPTIMIZATION_TARGETS from mlia.target.tosa.advisor import configure_and_get_tosa_advisor from mlia.target.tosa.advisor import TOSAInferenceAdvisor @@ -33,21 +34,11 @@ def test_configure_and_get_tosa_advisor( assert ctx.event_handlers is not None assert ctx.config_parameters == { "common_optimizations": { - "optimizations": [ - [ - { - "layers_to_optimize": None, - "optimization_target": 0.5, - "optimization_type": "pruning", - }, - { - "layers_to_optimize": None, - "optimization_target": 32, - "optimization_type": "clustering", - }, - ] - ], - "training_parameters": None, + "optimizations": [_DEFAULT_OPTIMIZATION_TARGETS], + "rewrite_parameters": { + "train_params": None, + "rewrite_specific_params": None, + }, }, "tosa_inference_advisor": { "model": str(test_tflite_model), diff --git a/tests_e2e/optimization_e2e_test.toml b/tests_e2e/optimization_e2e_test.toml index 099247c..f075ec4 100644 --- a/tests_e2e/optimization_e2e_test.toml +++ b/tests_e2e/optimization_e2e_test.toml @@ -1,5 +1,5 @@ # 
SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 -[training] -steps = 1000 +[rewrite.training_parameters] +steps = 100 |