aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/rewrite.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/rewrite.py')
-rw-r--r--src/mlia/nn/rewrite/core/rewrite.py55
1 files changed, 45 insertions, 10 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index a4d47c4..2a7b432 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -12,6 +12,7 @@ from typing import Any
from typing import Callable
from typing import cast
+import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from mlia.core.errors import ConfigurationError
@@ -22,6 +23,10 @@ from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
from mlia.nn.rewrite.core.train import train
from mlia.nn.rewrite.core.train import TrainingParameters
+from mlia.nn.rewrite.library.fc_layer import get_keras_model as fc_rewrite
+from mlia.nn.rewrite.library.fc_sparsity24_layer import (
+ get_keras_model as fc_rewrite_sparsity24,
+)
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.utils.registry import Registry
@@ -45,6 +50,43 @@ class Rewrite:
except Exception as ex:
raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex
+ def quantize(self, model: keras.Model, model_is_quantized: bool) -> keras.Model:
+ """Return a quantized model if required."""
+ if model_is_quantized:
+ model = tfmot.quantization.keras.quantize_model(model)
+ return model
+
+ def training_callbacks(self) -> list:
+ """Return default rewrite callbacks."""
+ return []
+
+ def post_process(self, model: keras.Model) -> keras.Model:
+ """Return default post-processing rewrite options."""
+ return model
+
+
+class PruningRewrite(Rewrite):
+ """Derived Rewrite class with pruning-specific logic."""
+
+ pruning_callback = tfmot.sparsity.keras.UpdatePruningStep
+
+ strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning)
+
+ def quantize(self, model: keras.Model, model_is_quantized: bool) -> keras.Model:
+ """Return a quantized model if required."""
+ if model_is_quantized:
+ # placeholder for PQAT
+ pass
+ return model
+
+ def training_callbacks(self) -> list:
+ """Return pruning-specific rewrite callback."""
+ return [self.pruning_callback()]
+
+ def post_process(self, model: keras.Model) -> keras.Model:
+ """Pruning-specific post-processing rewrite options."""
+ return self.strip_pruning_wrapper(model)
+
@dataclass
class DynamicallyLoadedRewrite(Rewrite):
@@ -113,14 +155,8 @@ class RewritingOptimizer(Optimizer):
registry = RewriteRegistry(
[
- DynamicallyLoadedRewrite(
- "fully-connected",
- "mlia.nn.rewrite.library.fc_layer.get_keras_model",
- ),
- DynamicallyLoadedRewrite(
- "fully-connected-sparsity24",
- "mlia.nn.rewrite.library.fc_sparsity24_layer.get_keras_model24",
- ),
+ Rewrite("fully-connected", fc_rewrite),
+ PruningRewrite("fully-connected-sparsity24", fc_rewrite_sparsity24),
]
)
@@ -142,7 +178,6 @@ class RewritingOptimizer(Optimizer):
rewrite = RewritingOptimizer.registry.items[
self.optimizer_configuration.optimization_target
]
-
use_unmodified_model = True
tflite_model = self.model.model_path
tfrecord = str(self.optimizer_configuration.dataset)
@@ -161,7 +196,7 @@ class RewritingOptimizer(Optimizer):
unmodified_model=tflite_model if use_unmodified_model else None,
output_model=str(tmp_output),
input_tfrec=str(tfrecord),
- replace_fn=rewrite,
+ rewrite=rewrite,
input_tensors=[self.optimizer_configuration.layers_to_optimize[0]],
output_tensors=[self.optimizer_configuration.layers_to_optimize[1]],
train_params=self.optimizer_configuration.train_params,