diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core')
-rw-r--r-- | src/mlia/nn/rewrite/core/rewrite.py | 200 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/train.py | 86 |
2 files changed, 216 insertions, 70 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index 6674d02..c2ad364 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -8,11 +8,15 @@ import tempfile from abc import ABC from abc import abstractmethod from dataclasses import dataclass +from inspect import getfullargspec from pathlib import Path +from statistics import fmean from typing import Any from typing import Callable +from typing import Generator import numpy as np +import tensorflow as tf import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from tensorflow_model_optimization.python.core.sparsity.keras.pruning_utils import ( # pylint: disable=no-name-in-module @@ -31,12 +35,14 @@ from mlia.nn.rewrite.library.clustering import conv2d_clustering_rewrite from mlia.nn.rewrite.library.clustering import fc_clustering_rewrite from mlia.nn.rewrite.library.fc_layer import fc_rewrite from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_rewrite +from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_unstructured_rewrite from mlia.nn.rewrite.library.sparsity import fc_sparsity_rewrite +from mlia.nn.rewrite.library.sparsity import fc_sparsity_unstructured_rewrite from mlia.nn.tensorflow.config import TFLiteModel from mlia.utils.registry import Registry logger = logging.getLogger(__name__) -RewriteCallable = Callable[[Any, Any], keras.Model] +RewriteCallable = Callable[..., keras.Model] class Rewrite(ABC): @@ -47,10 +53,23 @@ class Rewrite(ABC): self.name = name self.function = rewrite_fn - def __call__(self, input_shape: Any, output_shape: Any) -> keras.Model: + def __call__( + self, input_shape: Any, output_shape: Any, **kwargs: Any + ) -> keras.Model: """Perform the rewrite operation using the configured function.""" try: - return self.function(input_shape, output_shape) + return self.function(input_shape, output_shape, **kwargs) + except TypeError as ex: + expected_args = self.return_rewrite_func_args() + if "input_shape" in expected_args: + expected_args.remove("input_shape") + if "output_shape" in expected_args: + expected_args.remove("output_shape") + raise KeyError( + f"Found unexpected parameters for rewrite. Expected (sub)set " + f"of {expected_args} found unexpected parameter(s) " + f"{list(set(list(kwargs.keys())) - set(expected_args))}" + ) from ex except Exception as ex: raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex @@ -58,21 +77,25 @@ class Rewrite(ABC): """Return a quantized model if required.""" return model + def return_rewrite_func_args(self) -> list[str]: + """Return the expected args of the rewrite function.""" + return getfullargspec(self.function).args + @abstractmethod def training_callbacks(self) -> list: - """Return default rewrite callbacks.""" + """Return rewrite callbacks.""" @abstractmethod def post_process(self, model: keras.Model) -> keras.Model: - """Return default post-processing rewrite options.""" + """Return post-processing rewrite option.""" @abstractmethod - def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool: + def check_optimization(self, model: keras.Model) -> bool: """Check if the optimization has produced the correct result.""" class GenericRewrite(Rewrite): - """Graph rewrite logic for fully-connected rewrite.""" + """Rewrite class for generic rewrites e.g. fully-connected.""" def quantize(self, model: keras.Model) -> keras.Model: """Return a quantized model if required.""" @@ -83,10 +106,10 @@ class GenericRewrite(Rewrite): return [] def post_process(self, model: keras.Model) -> keras.Model: - """Return default post-processing rewrite options.""" + """Return default post-processing rewrite option.""" return model - def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + def check_optimization(self, model: keras.Model) -> bool: """Not needed here.""" return True @@ -99,16 +122,27 @@ class QuantizeAwareTrainingRewrite(Rewrite, ABC): """Apply optimization-aware quantization to a given model.""" return model + def check_optimization_generator( + self, model: keras.Model + ) -> Generator[tuple[tf.Tensor, keras.layers.Layer], None, None]: + """Loop for check_optimization function.""" + for layer in model.layers: + for weight in layer.weights: + if "kernel" in weight.name: + if "kernel_min" in weight.name or "kernel_max" in weight.name: + continue + yield weight, layer + -class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): - """Graph rewrite logic for fully-connected-sparsity24 rewrite.""" +class SparsityRewrite(QuantizeAwareTrainingRewrite): + """Base rewrite class for sparsity rewrites.""" pruning_callback = tfmot.sparsity.keras.UpdatePruningStep strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning) def quantize(self, model: keras.Model) -> keras.Model: - """Skip quantization when using pruning rewrite.""" + """Skip quantization when using sparsity rewrite.""" return model def training_callbacks(self) -> list: @@ -116,7 +150,7 @@ class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): return [self.pruning_callback()] def post_process(self, model: keras.Model) -> keras.Model: - """Pruning-specific post-processing rewrite options.""" + """Pruning-specific post-processing rewrite option.""" return self.strip_pruning_wrapper(model) def preserved_quantize( @@ -129,29 +163,78 @@ class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): model, tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(), ) - return model - def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + def check_optimization(self, model: keras.Model) -> bool: + """Not needed here.""" + return True + + +class UnstructuredSparsityRewrite(SparsityRewrite): + """ + Rewrite class for unstructured sparsity rewrite. + + e.g. fully-connected-unstructured-sparsity. + """ + + def check_optimization( + self, model: keras.Model, final_sparsity: float = 0.5, **_: Any + ) -> bool: + """Not needed here.""" + found_sparsity_list = [] + num_dec_places = str(final_sparsity)[::-1].find(".") + for weight, _ in self.check_optimization_generator(model=model): + weight_np = weight.numpy() + found_sparsity_list.append( + round(np.count_nonzero(weight_np) / weight_np.size, num_dec_places) + ) + if len(found_sparsity_list) == 0: + logger.warning( + "\nWARNING: Could not find any layers " + "in rewrite that could be sparsely pruned" + ) + return False + found_sparsity = fmean(found_sparsity_list) + if found_sparsity != final_sparsity: + logger.warning( + "\nWARNING: Found total sparsity of " + "rewrite model: %.2f " + "expected total sparsity to be: " + "%.2f\n", + found_sparsity, + final_sparsity, + ) + return False + return True + + +class StructuredSparsityRewrite(SparsityRewrite): + """Rewrite class for structured sparsity rewrite e.g. fully-connected-sparsity.""" + + def check_optimization( + self, + model: keras.Model, + sparsity_m: int = 2, + sparsity_n: int = 4, + **_: Any, + ) -> bool: """Check if sparity has produced the correct result.""" - for layer in model.layers: - for weight in layer.weights: - if "kernel" in weight.name: - if "kernel_min" in weight.name or "kernel_max" in weight.name: - continue - if not is_pruned_m_by_n(weight, m_by_n=(2, 4)): - logger.warning( - "\nWARNING: Could not find (2,4) sparsity, " - "in layer %s for weight %s \n", - layer.name, - weight.name, - ) - return False + for weight, layer in self.check_optimization_generator(model=model): + if not is_pruned_m_by_n(weight, m_by_n=(sparsity_m, sparsity_n)): + logger.warning( + "\nWARNING: Could not find (%d, %d) sparsity, " + "in layer %s for weight %s \n", + sparsity_m, + sparsity_n, + layer.name, + weight.name, + ) + return False return True class ClusteringRewrite(QuantizeAwareTrainingRewrite): - """Graph clustering rewrite logic to be used by RewritingOptimizer.""" + """Rewrite class for clustering rewrite e.g. fully-connected-clustering.""" _strip_clustering_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering) @@ -164,32 +247,22 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite): ) return cqat_model - def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + def check_optimization( + self, model: keras.Model, num_clusters: int = 2, **_: Any + ) -> bool: """Check if clustering has produced the correct result.""" - number_of_clusters = kwargs.get("number_of_clusters") - if not number_of_clusters: - raise ValueError( - """ - Expected check_preserved_quantize to have argument number_of_clusters. - """ - ) - - for layer in model.layers: - for weight in layer.weights: - if "kernel" in weight.name: - if "kernel_min" in weight.name or "kernel_max" in weight.name: - continue - number_of_found_clusters = len(np.unique(weight)) - if number_of_found_clusters != number_of_clusters: - logger.warning( - "\nWARNING: Expected %d cluster(s), found %d " - "cluster(s) in layer %s for weight %s \n", - number_of_clusters, - number_of_found_clusters, - layer.name, - weight.name, - ) - return False + for weight, layer in self.check_optimization_generator(model=model): + number_of_found_clusters = len(np.unique(weight)) + if number_of_found_clusters != num_clusters: + logger.warning( + "\nWARNING: Expected %d cluster(s), found %d " + "cluster(s) in layer %s for weight %s \n", + num_clusters, + number_of_found_clusters, + layer.name, + weight.name, + ) + return False return True def training_callbacks(self) -> list: @@ -197,7 +270,7 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite): return [] def post_process(self, model: keras.Model) -> keras.Model: - """Return the clustering stripped model.""" + """Clustering-specific post-processing rewrite option.""" return self._strip_clustering_wrapper(model) @@ -228,6 +301,7 @@ class RewriteConfiguration(OptimizerConfiguration): layers_to_optimize: list[str] | None = None dataset: Path | None = None train_params: TrainingParameters = TrainingParameters() + rewrite_specific_params: dict | None = None def __str__(self) -> str: """Return string representation of the configuration.""" @@ -240,10 +314,17 @@ class RewritingOptimizer(Optimizer): registry = RewriteRegistry( [ GenericRewrite("fully-connected", fc_rewrite), - Sparsity24Rewrite("fully-connected-sparsity24", fc_sparsity_rewrite), + StructuredSparsityRewrite("fully-connected-sparsity", fc_sparsity_rewrite), ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite), ClusteringRewrite("conv2d-clustering", conv2d_clustering_rewrite), - Sparsity24Rewrite("conv2d-sparsity24", conv2d_sparsity_rewrite), + StructuredSparsityRewrite("conv2d-sparsity", conv2d_sparsity_rewrite), + UnstructuredSparsityRewrite( + "conv2d-unstructured-sparsity", conv2d_sparsity_unstructured_rewrite + ), + UnstructuredSparsityRewrite( + "fully-connected-unstructured-sparsity", + fc_sparsity_unstructured_rewrite, + ), ] ) @@ -265,7 +346,6 @@ class RewritingOptimizer(Optimizer): rewrite = RewritingOptimizer.registry.items[ self.optimizer_configuration.optimization_target ] - use_unmodified_model = True tflite_model = self.model.model_path tfrecord = str(self.optimizer_configuration.dataset) @@ -287,6 +367,10 @@ class RewritingOptimizer(Optimizer): input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], train_params=self.optimizer_configuration.train_params, + rewrite_specific_params=self.optimizer_configuration.rewrite_specific_params, # pylint: disable=line-too-long + detect_activation_function=( + "activation" in rewrite.return_rewrite_func_args() + ), ) if orig_vs_repl_stats: diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index 4204978..570968a 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -34,13 +34,13 @@ from mlia.nn.rewrite.core.graph_edit.record import record_model from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel +from mlia.nn.rewrite.library.helper_functions import ACTIVATION_FUNCTION_LIST from mlia.nn.tensorflow.config import TFLiteModel from mlia.nn.tensorflow.tflite_convert import convert_to_tflite from mlia.nn.tensorflow.tflite_graph import load_fb from mlia.nn.tensorflow.tflite_graph import save_fb from mlia.utils.logging import log_action - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) logger = logging.getLogger(__name__) @@ -83,6 +83,8 @@ def train( # pylint: disable=too-many-arguments input_tensors: list, output_tensors: list, train_params: TrainingParameters = TrainingParameters(), + rewrite_specific_params: dict | None = None, + detect_activation_function: bool = False, ) -> Any: """Extract and train a model, and return the results.""" if unmodified_model: @@ -122,6 +124,8 @@ def train( # pylint: disable=too-many-arguments rewrite=rewrite, is_qat=is_qat, train_params=train_params, + rewrite_specific_params=rewrite_specific_params, + detect_activation_function=detect_activation_function, ) for i, filename in enumerate(tflite_filenames): @@ -349,6 +353,41 @@ def set_up_data_pipeline( return dataset, steps_per_epoch +def detect_activation_from_rewrite_function(model_path: str) -> str: + """Given a rewrite model, choose the most common activation function.""" + interpreter = tf.lite.Interpreter(model_path=model_path) + interpreter.allocate_tensors() + act_func_match_list = [] + for tensor_details in interpreter.get_tensor_details(): + for act_func in ACTIVATION_FUNCTION_LIST: + tensor_name = tensor_details["name"].lower() + if act_func in tensor_name: + act_func_idx = tensor_name.index(act_func) + if ( + len(tensor_name) == act_func_idx + len(act_func) + or tensor_name[act_func_idx + len(act_func)] == ";" + ): + act_func_match_list.append( + tensor_name[ + act_func_idx : act_func_idx + len(act_func) # noqa: E203 + ] + ) + act_func_match = "relu" + if len(act_func_match_list) == 0: + logger.info( + "No activation function specified, setting activation function to ReLU" + ) + else: + act_func_match = max(set(act_func_match_list), key=act_func_match.count) + logger.info( + "No activation function specified, " + "setting activation function to most " + "common activation detected in rewrite graph: %s", + act_func_match, + ) + return act_func_match + + def train_in_dir( train_dir: str, baseline_dir: Any, @@ -356,6 +395,8 @@ def train_in_dir( rewrite: Callable, is_qat: bool, train_params: TrainingParameters = TrainingParameters(), + rewrite_specific_params: dict | None = None, + detect_activation_function: bool = False, ) -> list[str]: """Train a replacement for replace.tflite using the input.tfrec \ and output.tfrec in train_dir. @@ -372,6 +413,18 @@ def train_in_dir( ) replace = TFLiteModel(ExtractPaths.tflite.replace(train_dir)) + if detect_activation_function and ( + rewrite_specific_params is None + or "activation" not in list(rewrite_specific_params.keys()) + ): + detected_activation_function = detect_activation_from_rewrite_function( + ExtractPaths.tflite.replace(train_dir).as_posix() + ) + if rewrite_specific_params: + rewrite_specific_params["activation"] = detected_activation_function + else: + rewrite_specific_params = {"activation": detected_activation_function} + input_name, output_name = _get_io_tensors(teacher) model_is_quantized = replace.is_tensor_quantized(name=input_name) @@ -396,7 +449,13 @@ def train_in_dir( loss_fn = keras.losses.MeanSquaredError() model = create_model( - rewrite, input_shape, output_shape, optimizer, loss_fn, model_is_quantized + rewrite, + input_shape, + output_shape, + optimizer, + loss_fn, + model_is_quantized, + rewrite_specific_params=rewrite_specific_params, ) logger.info(model.summary()) @@ -462,11 +521,9 @@ def train_in_dir( steps_per_epoch, post_process=True, ) - - # Placeholder for now, will be parametrized later (MLIA-1114) - # rewrite.check_optimization( # type: ignore[attr-defined] - # model, number_of_clusters=32 - # ) + rewrite.check_optimization( # type: ignore[attr-defined] + model, **rewrite_specific_params if rewrite_specific_params else {} + ) if model_is_quantized and is_qat: model = rewrite.preserved_quantize(model) # type: ignore[attr-defined] checkpoints = ( @@ -501,11 +558,10 @@ def train_in_dir( loss_fn, steps_per_epoch, ) - # Placeholder for now, will be parametrized later (MLIA-1114) - # rewrite.check_optimization( # type: ignore[attr-defined] - # model, number_of_clusters=32 - # ) + rewrite.check_optimization( # type: ignore[attr-defined] + model, **rewrite_specific_params if rewrite_specific_params else {} + ) teacher.close() return output_filenames @@ -528,9 +584,13 @@ def create_model( # pylint: disable=too-many-arguments loss_fn: Callable, model_is_quantized: bool, model_to_load_from: keras.model | None = None, + rewrite_specific_params: dict | None = None, ) -> keras.Model: """Create a model, optionally from another.""" - model = rewrite(input_shape, output_shape) + if rewrite_specific_params: + model = rewrite(input_shape, output_shape, **rewrite_specific_params) + else: + model = rewrite(input_shape, output_shape) if model_is_quantized: model = rewrite.quantize(model) # type: ignore[attr-defined] model = model_compile(model, optimizer=optimizer, loss_fn=loss_fn) @@ -558,6 +618,7 @@ def model_fit( # pylint: disable=too-many-arguments loss_fn: Callable, steps_per_epoch: int, post_process: bool = False, + rewrite_specific_params: dict | None = None, ) -> keras.Model: """Train a tflite model.""" steps_so_far = 0 @@ -597,6 +658,7 @@ def model_fit( # pylint: disable=too-many-arguments loss_fn, model_is_quantized, model_to_load_from=model, + rewrite_specific_params=rewrite_specific_params, ) else: model_to_save = model |