From 17813ba5be09f0e11fc0748afa4ccf2da02881b6 Mon Sep 17 00:00:00 2001 From: Madeleine Dunn Date: Mon, 13 Nov 2023 15:40:21 +0000 Subject: feat: Implement fp32 sparsity 2:4 rewrite - Update the existing placeholder with code to prune the given model Resolves: MLIA-1002 Signed-off-by: Madeleine Dunn Change-Id: I76b0e0bfe81be5e57d518cd7bb588eef76a11641 --- src/mlia/nn/rewrite/core/rewrite.py | 55 ++++++++++++++++++++++++++++++------- 1 file changed, 45 insertions(+), 10 deletions(-) (limited to 'src/mlia/nn/rewrite/core/rewrite.py') diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index a4d47c4..2a7b432 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -12,6 +12,7 @@ from typing import Any from typing import Callable from typing import cast +import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from mlia.core.errors import ConfigurationError @@ -22,6 +23,10 @@ from mlia.nn.common import Optimizer from mlia.nn.common import OptimizerConfiguration from mlia.nn.rewrite.core.train import train from mlia.nn.rewrite.core.train import TrainingParameters +from mlia.nn.rewrite.library.fc_layer import get_keras_model as fc_rewrite +from mlia.nn.rewrite.library.fc_sparsity24_layer import ( + get_keras_model as fc_rewrite_sparsity24, +) from mlia.nn.tensorflow.config import TFLiteModel from mlia.utils.registry import Registry @@ -45,6 +50,43 @@ class Rewrite: except Exception as ex: raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex + def quantize(self, model: keras.Model, model_is_quantized: bool) -> keras.Model: + """Return a quantized model if required.""" + if model_is_quantized: + model = tfmot.quantization.keras.quantize_model(model) + return model + + def training_callbacks(self) -> list: + """Return default rewrite callbacks.""" + return [] + + def post_process(self, model: keras.Model) -> keras.Model: + """Return default post-processing rewrite options.""" + return model + + +class PruningRewrite(Rewrite): + """Derived Rewrite class with pruning-specific logic.""" + + pruning_callback = tfmot.sparsity.keras.UpdatePruningStep + + strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning) + + def quantize(self, model: keras.Model, model_is_quantized: bool) -> keras.Model: + """Return a quantized model if required.""" + if model_is_quantized: + # placeholder for PQAT + pass + return model + + def training_callbacks(self) -> list: + """Return pruning-specific rewrite callback.""" + return [self.pruning_callback()] + + def post_process(self, model: keras.Model) -> keras.Model: + """Pruning-specific post-processing rewrite options.""" + return self.strip_pruning_wrapper(model) + @dataclass class DynamicallyLoadedRewrite(Rewrite): @@ -113,14 +155,8 @@ class RewritingOptimizer(Optimizer): registry = RewriteRegistry( [ - DynamicallyLoadedRewrite( - "fully-connected", - "mlia.nn.rewrite.library.fc_layer.get_keras_model", - ), - DynamicallyLoadedRewrite( - "fully-connected-sparsity24", - "mlia.nn.rewrite.library.fc_sparsity24_layer.get_keras_model24", - ), + Rewrite("fully-connected", fc_rewrite), + PruningRewrite("fully-connected-sparsity24", fc_rewrite_sparsity24), ] ) @@ -142,7 +178,6 @@ class RewritingOptimizer(Optimizer): rewrite = RewritingOptimizer.registry.items[ self.optimizer_configuration.optimization_target ] - use_unmodified_model = True tflite_model = self.model.model_path tfrecord = str(self.optimizer_configuration.dataset) @@ -161,7 +196,7 @@ class RewritingOptimizer(Optimizer): unmodified_model=tflite_model if use_unmodified_model else None, output_model=str(tmp_output), input_tfrec=str(tfrecord), - replace_fn=rewrite, + rewrite=rewrite, input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], train_params=self.optimizer_configuration.train_params, -- cgit v1.2.1