diff options
Diffstat (limited to 'src/mlia/nn')
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/record.py | 10 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/rewrite.py | 220 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/train.py | 209 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/library/clustering.py | 51 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/library/fc_layer.py | 6 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/library/helper_functions.py | 32 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/library/sparsity.py | 45 | ||||
-rw-r--r-- | src/mlia/nn/select.py | 23 |
8 files changed, 510 insertions, 86 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py index f85433d..7d9f219 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/record.py +++ b/src/mlia/nn/rewrite/core/graph_edit/record.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Save subgraph data.""" # pylint: disable=too-many-locals @@ -32,7 +32,7 @@ def dequantized_path(filename: str | Path) -> Path: return path -def record_model( +def record_model( # pylint: disable=too-many-arguments input_filename: str | Path, model_filename: str | Path, output_filename: str | Path, @@ -41,6 +41,7 @@ def record_model( num_procs: int = 1, num_threads: int = 0, dequantize_output: bool = False, + quantize_input: bool = False, ) -> None: """Model recorder. @@ -92,7 +93,10 @@ def record_model( for _, named_x in enumerate( track(dataset.as_numpy_iterator(), total=total, disable=not show_progress) ): - named_y = model(named_x) + if quantize_input: + named_y = model(model.quantize_inputs(named_x)) + else: + named_y = model(named_x) write(writer, named_y) if dequantize_output: diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index c7d13ba..6d915c6 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -3,16 +3,21 @@ """Contains class RewritingOptimizer to replace a subgraph/layer of a model.""" from __future__ import annotations -import importlib import logging import tempfile +from abc import ABC +from abc import abstractmethod from dataclasses import dataclass from pathlib import Path from typing import Any from typing import Callable -from typing import cast +import numpy as np +import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 +from tensorflow_model_optimization.python.core.sparsity.keras.pruning_utils import ( # pylint: disable=no-name-in-module + is_pruned_m_by_n, +) from mlia.core.errors import ConfigurationError from mlia.core.reporting import Column @@ -22,16 +27,20 @@ from mlia.nn.common import Optimizer from mlia.nn.common import OptimizerConfiguration from mlia.nn.rewrite.core.train import train from mlia.nn.rewrite.core.train import TrainingParameters +from mlia.nn.rewrite.library.clustering import conv2d_clustering_rewrite +from mlia.nn.rewrite.library.clustering import fc_clustering_rewrite +from mlia.nn.rewrite.library.fc_layer import fc_rewrite +from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_rewrite +from mlia.nn.rewrite.library.sparsity import fc_sparsity_rewrite from mlia.nn.tensorflow.config import TFLiteModel from mlia.utils.registry import Registry - logger = logging.getLogger(__name__) RewriteCallable = Callable[[Any, Any], keras.Model] -class Rewrite: - """Graph rewrite logic to be used by RewritingOptimizer.""" +class Rewrite(ABC): + """Abstract class for rewrite logic to be used by RewritingOptimizer.""" def __init__(self, name: str, rewrite_fn: RewriteCallable): """Initialize a Rewrite instance with a given name and an optional function.""" @@ -39,40 +48,157 @@ class Rewrite: self.function = rewrite_fn def __call__(self, input_shape: Any, output_shape: Any) -> keras.Model: - """Perform the rewrite operation using the configured function.""" + """Return an instance of the rewrite model.""" try: return self.function(input_shape, output_shape) except Exception as ex: raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex + def quantize(self, model: keras.Model) -> keras.Model: + """Return a quantized model if required.""" + return model -@dataclass -class DynamicallyLoadedRewrite(Rewrite): - """A rewrite which can load logic from a function loaded dynamically.""" + @abstractmethod + def training_callbacks(self) -> list: + """Return rewrite callbacks.""" - def __init__(self, name: str, function_name: str): - """Initialize.""" + @abstractmethod + def post_process(self, model: keras.Model) -> keras.Model: + """Return post-processing rewrite option.""" - def load_and_run(input_shape: Any, output_shape: Any) -> keras.Model: - """Load the function from a file dynamically.""" - self.load_function(function_name) - return self.function(input_shape, output_shape) + @abstractmethod + def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool: + """Check if the optimization has produced the correct result.""" - super().__init__(name, load_and_run) - def load_function(self, function_name: str) -> RewriteCallable: - """Return the rewrite function. Import using the auto_load attr if necessary.""" - try: - name_parts = function_name.split(".") - module_name = ".".join(name_parts[:-1]) - fn_name = name_parts[-1] - module = importlib.import_module(module_name) - self.function = cast(RewriteCallable, getattr(module, fn_name)) - return self.function - except Exception as ex: - raise RuntimeError( - f"Unable to load rewrite function '{function_name}' for '{self.name}'." - ) from ex +class GenericRewrite(Rewrite): + """Rewrite class for generic rewrites e.g. fully-connected.""" + + def quantize(self, model: keras.Model) -> keras.Model: + """Return a quantized model if required.""" + return tfmot.quantization.keras.quantize_model(model) + + def training_callbacks(self) -> list: + """Return default rewrite callbacks.""" + return [] + + def post_process(self, model: keras.Model) -> keras.Model: + """Return default post-processing rewrite option.""" + return model + + def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + """Not needed here.""" + return True + + +class QuantizeAwareTrainingRewrite(Rewrite, ABC): + """Abstract class for rewrites that perform QAT.""" + + @abstractmethod + def preserved_quantize(self, model: keras.Model) -> keras.Model: + """Apply optimization-aware quantization to a given model.""" + return model + + +class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): + """Rewrite class for sparsity rewrite e.g. fully-connected-sparsity24.""" + + pruning_callback = tfmot.sparsity.keras.UpdatePruningStep + + strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning) + + def quantize(self, model: keras.Model) -> keras.Model: + """Skip quantization when using sparsity rewrite.""" + return model + + def training_callbacks(self) -> list: + """Return pruning-specific rewrite callback.""" + return [self.pruning_callback()] + + def post_process(self, model: keras.Model) -> keras.Model: + """Pruning-specific post-processing rewrite option.""" + return self.strip_pruning_wrapper(model) + + def preserved_quantize( + self, + model: keras.Model, + ) -> keras.Model: + """Apply pruning-preserved quantization training to a given model.""" + model = tfmot.quantization.keras.quantize_annotate_model(model) + model = tfmot.quantization.keras.quantize_apply( + model, + tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(), + ) + + return model + + def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + """Check if sparity has produced the correct result.""" + for layer in model.layers: + for weight in layer.weights: + if "kernel" in weight.name: + if "kernel_min" in weight.name or "kernel_max" in weight.name: + continue + if not is_pruned_m_by_n(weight, m_by_n=(2, 4)): + logger.warning( + "\nWARNING: Could not find (2,4) sparsity, " + "in layer %s for weight %s \n", + layer.name, + weight.name, + ) + return False + return True + + +class ClusteringRewrite(QuantizeAwareTrainingRewrite): + """Rewrite class for clustering rewrite e.g. fully-connected-clustering.""" + + _strip_clustering_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering) + + def preserved_quantize(self, model: keras.Model) -> keras.Model: + """Apply clustering-preserved quantization to a given model.""" + quant_aware_model = tfmot.quantization.keras.quantize_annotate_model(model) + cqat_model = tfmot.quantization.keras.quantize_apply( + quant_aware_model, + tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme(), + ) + return cqat_model + + def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + """Check if clustering has produced the correct result.""" + number_of_clusters = kwargs.get("number_of_clusters") + if not number_of_clusters: + raise ValueError( + """ + Expected check_optimization to have argument number_of_clusters. + """ + ) + + for layer in model.layers: + for weight in layer.weights: + if "kernel" in weight.name: + if "kernel_min" in weight.name or "kernel_max" in weight.name: + continue + number_of_found_clusters = len(np.unique(weight)) + if number_of_found_clusters != number_of_clusters: + logger.warning( + "\nWARNING: Expected %d cluster(s), found %d " + "cluster(s) in layer %s for weight %s \n", + number_of_clusters, + number_of_found_clusters, + layer.name, + weight.name, + ) + return False + return True + + def training_callbacks(self) -> list: + """Return default rewrite callbacks.""" + return [] + + def post_process(self, model: keras.Model) -> keras.Model: + """Clustering-specific post-processing rewrite option.""" + return self._strip_clustering_wrapper(model) class RewriteRegistry(Registry[Rewrite]): @@ -113,9 +239,11 @@ class RewritingOptimizer(Optimizer): registry = RewriteRegistry( [ - DynamicallyLoadedRewrite( - "fully-connected", "mlia.nn.rewrite.library.fc_layer.get_keras_model" - ) + GenericRewrite("fully-connected", fc_rewrite), + Sparsity24Rewrite("fully-connected-sparsity24", fc_sparsity_rewrite), + ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite), + ClusteringRewrite("conv2d-clustering", conv2d_clustering_rewrite), + Sparsity24Rewrite("conv2d-sparsity24", conv2d_sparsity_rewrite), ] ) @@ -149,22 +277,35 @@ class RewritingOptimizer(Optimizer): raise ConfigurationError( "Input and output tensor names need to be set for rewrite." ) - orig_vs_repl_stats, total_stats = train( source_model=tflite_model, unmodified_model=tflite_model if use_unmodified_model else None, output_model=str(tmp_output), input_tfrec=str(tfrecord), - replace_fn=rewrite, + rewrite=rewrite, + is_qat=isinstance(rewrite, QuantizeAwareTrainingRewrite), input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], train_params=self.optimizer_configuration.train_params, ) if orig_vs_repl_stats: - orig_vs_repl = ["Replaced sub-graph only"] + [ - f"{stat:.3f}" for stat in orig_vs_repl_stats - ] + model_stats: list = [] + cp_param = self.optimizer_configuration.train_params.checkpoint_at + checkpoints = ( + [ + "At checkpoint " + str(checkpoint) + " steps" + for checkpoint in cp_param + ] + if cp_param + else [] + ) + checkpoints.append("All Steps") + for checkpoint, orig_vs_repl_stat in zip(checkpoints, orig_vs_repl_stats): + model_stats.append( + ["Replaced sub-graph: " + checkpoint] + + [f"{stat:.3f}" for stat in orig_vs_repl_stat] + ) total = ["Total model"] + [f"{stat:.3f}" for stat in total_stats] notes = ( "These metrics show the difference between original model\n" @@ -178,19 +319,20 @@ class RewritingOptimizer(Optimizer): table = Table( columns=[ Column( - "Original vs. optimized", + "Original vs. Optimized", alias="metric", fmt=Format(wrap_width=40), ), Column("MAE", alias="value", fmt=Format(wrap_width=15)), Column("NRMSE", alias="value", fmt=Format(wrap_width=15)), ], - rows=[orig_vs_repl, total], + rows=[*model_stats, total], name="Rewrite performance metrics", alias="rewrite_performance_metrics", notes=notes, ) logger.info(table.to_plain_text(show_title=True)) + self.model = TFLiteModel(tmp_output) def get_model(self) -> TFLiteModel: """Return optimized model.""" diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index 60c39ae..4204978 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Sequential trainer.""" +# pylint: disable=too-many-arguments # pylint: disable=too-many-locals # pylint: disable=too-many-statements from __future__ import annotations @@ -22,7 +23,6 @@ from typing import Literal import numpy as np import tensorflow as tf -import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from numpy.random import Generator @@ -62,7 +62,7 @@ LEARNING_RATE_SCHEDULES = get_args(LearningRateSchedule) class TrainingParameters: """Define default parameters for the training.""" - augmentations: tuple[float | None, float | None] = AUGMENTATION_PRESETS["gaussian"] + augmentations: tuple[float | None, float | None] = AUGMENTATION_PRESETS["none"] batch_size: int = 32 steps: int = 48000 learning_rate: float = 1e-3 @@ -73,12 +73,13 @@ class TrainingParameters: checkpoint_at: list | None = None -def train( +def train( # pylint: disable=too-many-arguments source_model: str, unmodified_model: Any, output_model: str, input_tfrec: str, - replace_fn: Callable, + rewrite: Callable, + is_qat: bool, input_tensors: list, output_tensors: list, train_params: TrainingParameters = TrainingParameters(), @@ -118,7 +119,8 @@ def train( train_dir=train_dir, baseline_dir=unmodified_model_dir_path, output_filename=Path(train_dir, "new.tflite"), - replace_fn=replace_fn, + rewrite=rewrite, + is_qat=is_qat, train_params=train_params, ) @@ -145,7 +147,8 @@ def train( # Assess the output diff between the parts after the rewrite subgraph # in original and optimized model optimized_end_path = Path(train_dir, "optimized_end.tfrec") - end_path = Path(train_dir, "end.tfrec") + optimized_end_path_dequant = Path(train_dir, "optimized_end_dequant.tfrec") + end_path = Path(train_dir, "end_dequant.tfrec") record_model( str(input_tfrec), @@ -153,16 +156,18 @@ def train( optimized_end_path, num_procs=train_params.num_procs, num_threads=train_params.num_threads, + dequantize_output=True, ) - mae, nrmse = diff_stats(end_path, str(optimized_end_path)) + + mae, nrmse = diff_stats(end_path, optimized_end_path_dequant) if unmodified_model_dir: cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup() - return (results if train_params.checkpoint_at else results[0]), [ + return results, [ mae, nrmse, - ] # only return a list if multiple checkpoints are asked for + ] def eval_in_dir( @@ -177,24 +182,27 @@ def eval_in_dir( model_input = ( model_input_path if model_input_path.exists() - else ExtractPaths.tfrec.input(target_dir, False) + else ExtractPaths.tfrec.input(target_dir, True) ) output = ( model_output_path if model_output_path.exists() - else ExtractPaths.tfrec.output(target_dir, False) + else ExtractPaths.tfrec.output(target_dir, True) ) with tempfile.TemporaryDirectory() as tmp_dir: predict = Path(tmp_dir, "predict.tfrec") + predict_dequant = Path(tmp_dir, "predict_dequant.tfrec") record_model( str(model_input), new_part, str(predict), num_procs=num_procs, num_threads=num_threads, + dequantize_output=True, + quantize_input=True, ) - mae, nrmse = diff_stats(str(output), str(predict)) + mae, nrmse = diff_stats(str(output), predict_dequant) return mae, nrmse @@ -247,7 +255,7 @@ def set_up_data_pipeline( augmentations: tuple[float | None, float | None], steps: int, batch_size: int = 32, -) -> tf.data.Dataset: +) -> tuple[tf.data.Dataset, int]: """Create a data pipeline for training of the replacement model.""" _check_model_compatibility(teacher, replace) @@ -338,14 +346,15 @@ def set_up_data_pipeline( dataset = dataset.map(restore_shapes) dataset = dataset.prefetch(tf.data.AUTOTUNE) - return dataset + return dataset, steps_per_epoch def train_in_dir( train_dir: str, baseline_dir: Any, output_filename: Path, - replace_fn: Callable, + rewrite: Callable, + is_qat: bool, train_params: TrainingParameters = TrainingParameters(), ) -> list[str]: """Train a replacement for replace.tflite using the input.tfrec \ @@ -370,7 +379,7 @@ def train_in_dir( if model_is_quantized: replace.check_datatypes(np.int8) - dataset = set_up_data_pipeline( + dataset, steps_per_epoch = set_up_data_pipeline( teacher, replace, train_dir, @@ -380,15 +389,15 @@ def train_in_dir( ) input_shape = teacher.shape_from_name[input_name][1:] - output_shape = teacher.shape_from_name[output_name][1:] - model = replace_fn(input_shape, output_shape) + output_shape = teacher.shape_from_name[output_name][1:] optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate) loss_fn = keras.losses.MeanSquaredError() - if model_is_quantized: - model = tfmot.quantization.keras.quantize_model(model) - model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"]) + + model = create_model( + rewrite, input_shape, output_shape, optimizer, loss_fn, model_is_quantized + ) logger.info(model.summary()) @@ -428,11 +437,130 @@ def train_in_dir( elif train_params.learning_rate_schedule == "constant": callbacks = [] - output_filenames = [] + callbacks.extend(rewrite.training_callbacks()) # type: ignore[attr-defined] + output_filenames: list = [] checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [ train_params.steps ] + model, output_filenames = model_fit( + model, + train_params, + checkpoints, + optimizer, + dataset, + callbacks, + output_filename, + rewrite, + replace, + input_name, + output_name, + model_is_quantized, + output_filenames, + input_shape, + output_shape, + loss_fn, + steps_per_epoch, + post_process=True, + ) + + # Placeholder for now, will be parametrized later (MLIA-1114) + # rewrite.check_optimization( # type: ignore[attr-defined] + # model, number_of_clusters=32 + # ) + if model_is_quantized and is_qat: + model = rewrite.preserved_quantize(model) # type: ignore[attr-defined] + checkpoints = ( + train_params.checkpoint_at if train_params.checkpoint_at else [] + ) + [train_params.steps] + output_filenames = [] + + if len(rewrite.training_callbacks()) > 0 and set( # type: ignore[attr-defined] + rewrite.training_callbacks() # type: ignore[attr-defined] + ).issubset(callbacks): + callbacks.pop(-1) + + optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate) + model = model_compile(model, optimizer, loss_fn) + + model, output_filenames = model_fit( + model, + train_params, + checkpoints, + optimizer, + dataset, + callbacks, + output_filename, + rewrite, + replace, + input_name, + output_name, + model_is_quantized, + output_filenames, + input_shape, + output_shape, + loss_fn, + steps_per_epoch, + ) + # Placeholder for now, will be parametrized later (MLIA-1114) + # rewrite.check_optimization( # type: ignore[attr-defined] + # model, number_of_clusters=32 + # ) + + teacher.close() + return output_filenames + +def model_compile( + model: keras.Model, + optimizer: keras.optimizers.Nadam, + loss_fn: keras.losses.Loss, +) -> keras.Model: + """Compiles a tflite model.""" + model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"]) + return model + + +def create_model( # pylint: disable=too-many-arguments + rewrite: Callable, + input_shape: int, + output_shape: int, + optimizer: Callable, + loss_fn: Callable, + model_is_quantized: bool, + model_to_load_from: keras.model | None = None, +) -> keras.Model: + """Create a model, optionally from another.""" + model = rewrite(input_shape, output_shape) + if model_is_quantized: + model = rewrite.quantize(model) # type: ignore[attr-defined] + model = model_compile(model, optimizer=optimizer, loss_fn=loss_fn) + if model_to_load_from: + model.set_weights(model_to_load_from.get_weights()) + return model + + +def model_fit( # pylint: disable=too-many-arguments + model: keras.Model, + train_params: TrainingParameters, + checkpoints: list, + optimizer: tf.optimizers.Nadam, + dataset: tf.data.Dataset, + callbacks: list, + output_filename: Path, + rewrite: Callable, + replace: TFLiteModel, + input_name: str, + output_name: str, + model_is_quantized: bool, + output_filenames: list, + input_shape: int, + output_shape: int, + loss_fn: Callable, + steps_per_epoch: int, + post_process: bool = False, +) -> keras.Model: + """Train a tflite model.""" + steps_so_far = 0 while steps_so_far < train_params.steps: steps_to_train = checkpoints.pop(0) - steps_so_far lr_start = optimizer.learning_rate.numpy() @@ -452,15 +580,43 @@ def train_in_dir( ) if steps_so_far < train_params.steps: - filename, ext = Path(output_filename).parts[1:] - checkpoint_filename = filename + (f"_@{steps_so_far}") + ext + filename = Path(output_filename).stem + filename_dir = Path(output_filename).parent.as_posix() + ext = Path(output_filename).suffix + checkpoint_filename = ( + filename_dir + "/" + filename + (f"_@{steps_so_far}") + ext + ) + # If post processing we are stripping the clustering/pruning layers below + # Thus copy the model before saving, so training can continue + if post_process: + model_to_save = create_model( + rewrite, + input_shape, + output_shape, + optimizer, + loss_fn, + model_is_quantized, + model_to_load_from=model, + ) + else: + model_to_save = model else: checkpoint_filename = str(output_filename) + logger.info("Evaluate final Keras Model using %d steps", steps_per_epoch) + model.evaluate( + dataset, + steps=steps_per_epoch, + ) + model_to_save = model with log_action( f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}" ): + if post_process: + model_to_save = rewrite.post_process( # type: ignore[attr-defined] + model_to_save + ) save_as_tflite( - model, + model_to_save, checkpoint_filename, input_name, replace.shape_from_name[input_name], @@ -470,8 +626,7 @@ def train_in_dir( ) output_filenames.append(checkpoint_filename) - teacher.close() - return output_filenames + return model_to_save, output_filenames def save_as_tflite( diff --git a/src/mlia/nn/rewrite/library/clustering.py b/src/mlia/nn/rewrite/library/clustering.py new file mode 100644 index 0000000..81bfd90 --- /dev/null +++ b/src/mlia/nn/rewrite/library/clustering.py @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Rewrite functions used to return layers ready for clustering.""" +from typing import Any + +import tensorflow_model_optimization as tfmot +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 + +from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters + + +def fc_clustering_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: + """Fully connected TensorFlow Lite model ready for clustering.""" + rewrite_params = { + "number_of_clusters": 4, + "cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization.LINEAR, + } + model = tfmot.clustering.keras.cluster_weights( + to_cluster=keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Flatten(), + keras.layers.Dense(units=output_shape), + ] + ), + **rewrite_params + ) + return model + + +def conv2d_clustering_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: + """Conv2d TensorFlow Lite model ready for clustering.""" + rewrite_params = { + "number_of_clusters": 4, + "cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization.LINEAR, + } + conv2d_parameters = compute_conv2d_parameters( + input_shape=input_shape, output_shape=output_shape + ) + model = tfmot.clustering.keras.cluster_weights( + to_cluster=keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Conv2D(**conv2d_parameters), + keras.layers.BatchNormalization(), + keras.layers.ReLU(), + ] + ), + **rewrite_params + ) + return model diff --git a/src/mlia/nn/rewrite/library/fc_layer.py b/src/mlia/nn/rewrite/library/fc_layer.py index 041ce85..92195d1 100644 --- a/src/mlia/nn/rewrite/library/fc_layer.py +++ b/src/mlia/nn/rewrite/library/fc_layer.py @@ -1,13 +1,13 @@ # SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 -"""Example rewrite with one fully connected layer.""" +"""Rewrite function used to return regular layers.""" from typing import Any from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 -def get_keras_model(input_shape: Any, output_shape: Any) -> keras.Model: - """Generate TensorFlow Lite model for rewrite.""" +def fc_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: + """Fully connected TensorFlow Lite model for rewrite.""" model = keras.Sequential( ( keras.layers.InputLayer(input_shape=input_shape), diff --git a/src/mlia/nn/rewrite/library/helper_functions.py b/src/mlia/nn/rewrite/library/helper_functions.py new file mode 100644 index 0000000..4f08170 --- /dev/null +++ b/src/mlia/nn/rewrite/library/helper_functions.py @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Helper functions for the rewrite library.""" +import math +from typing import Any + +import numpy as np + + +def compute_conv2d_parameters( + input_shape: np.ndarray, output_shape: np.ndarray +) -> dict[str, Any]: + """Compute needed kernel size and strides for a given input and output_shape.""" + input_shape = input_shape.tolist() + output_shape = output_shape.tolist() + assert len(input_shape) == 3 + assert len(output_shape) == 3 + num_filters = (output_shape[-1] - input_shape[-1]) + input_shape[-1] + padding = "valid" + kernel_size = (3, 3) + stride_h = round(input_shape[0] / output_shape[0]) + check_output_size_h = math.floor((input_shape[0] - kernel_size[0]) / stride_h) + 1 + stride_w = round(input_shape[1] / output_shape[1]) + check_output_size_w = math.floor((input_shape[1] - kernel_size[1]) / stride_w) + 1 + if check_output_size_h != output_shape[0] or check_output_size_w != output_shape[1]: + padding = "same" + return { + "filters": num_filters, + "kernel_size": kernel_size, + "padding": padding, + "strides": (stride_h, stride_w), + } diff --git a/src/mlia/nn/rewrite/library/sparsity.py b/src/mlia/nn/rewrite/library/sparsity.py new file mode 100644 index 0000000..745fa8b --- /dev/null +++ b/src/mlia/nn/rewrite/library/sparsity.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Rewrite functions used to return layers ready for sparse pruning.""" +from typing import Any + +import tensorflow_model_optimization as tfmot +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 + +from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters + + +def fc_sparsity_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: + """Fully connected TensorFlow Lite model ready for sparse pruning.""" + model = tfmot.sparsity.keras.prune_low_magnitude( + to_prune=keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Reshape([-1]), + keras.layers.Dense(output_shape), + ] + ), + sparsity_m_by_n=(2, 4), + ) + + return model + + +def conv2d_sparsity_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: + """Conv2d TensorFlow Lite model ready for sparse pruning.""" + conv2d_parameters = compute_conv2d_parameters( + input_shape=input_shape, output_shape=output_shape + ) + model = tfmot.sparsity.keras.prune_low_magnitude( + to_prune=keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Conv2D(**conv2d_parameters), + keras.layers.BatchNormalization(), + keras.layers.ReLU(), + ] + ), + sparsity_m_by_n=(2, 4), + ) + + return model diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py index 81a614f..b61e713 100644 --- a/src/mlia/nn/select.py +++ b/src/mlia/nn/select.py @@ -117,7 +117,7 @@ class MultiStageOptimizer(Optimizer): def get_optimizer( model: keras.Model | KerasModel | TFLiteModel, config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings], - training_parameters: list[dict | None] | None = None, + training_parameters: dict | None = None, ) -> Optimizer: """Get optimizer for provided configuration.""" if isinstance(model, KerasModel): @@ -151,7 +151,7 @@ def get_optimizer( def _get_optimizer( model: keras.Model | Path, optimization_settings: OptimizationSettings | list[OptimizationSettings], - training_parameters: list[dict | None] | None = None, + training_parameters: dict | None = None, ) -> Optimizer: if isinstance(optimization_settings, OptimizationSettings): optimization_settings = [optimization_settings] @@ -173,22 +173,17 @@ def _get_optimizer( def _get_rewrite_params( - training_parameters: list[dict | None] | None = None, -) -> list: + training_parameters: dict | None = None, +) -> TrainingParameters: """Get the rewrite TrainingParameters. Return the default constructed TrainingParameters() per default, but can be overwritten in the unit tests. """ - if training_parameters is None: - return [TrainingParameters()] + if not training_parameters: + return TrainingParameters() - if training_parameters[0] is None: - train_params = TrainingParameters() - else: - train_params = TrainingParameters(**training_parameters[0]) - - return [train_params] + return TrainingParameters(**training_parameters) def _get_optimizer_configuration( @@ -196,7 +191,7 @@ def _get_optimizer_configuration( optimization_target: int | float | str, layers_to_optimize: list[str] | None = None, dataset: Path | None = None, - training_parameters: list[dict | None] | None = None, + training_parameters: dict | None = None, ) -> OptimizerConfiguration: """Get optimizer configuration for provided parameters.""" _check_optimizer_params(optimization_type, optimization_target) @@ -222,7 +217,7 @@ def _get_optimizer_configuration( optimization_target=str(optimization_target), layers_to_optimize=layers_to_optimize, dataset=dataset, - train_params=rewrite_params[0], + train_params=rewrite_params, ) raise ConfigurationError( |