diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-05-15 08:12:30 +0100 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-05-21 16:51:15 +0100 |
commit | 856111bcaef76c60303bdf2ae7cbf718d93d1df4 (patch) | |
tree | d955901817194e48e478f751140bd3c1741d1834 /src/mlia/nn/rewrite/core | |
parent | 0d3cc76284f9311c99169b568570d767f5b0aeb6 (diff) | |
download | mlia-856111bcaef76c60303bdf2ae7cbf718d93d1df4.tar.gz |
feat: Implement the conv2D rewrites for int8 and fp32 models
Enable clustering and fully connected rewrites for conv2D layers.
Resolves: MLIA-1159 and MLIA-1160
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: I640b8a7e79e455b12fb68d02ac1c33213b8de9c6
Diffstat (limited to 'src/mlia/nn/rewrite/core')
-rw-r--r-- | src/mlia/nn/rewrite/core/rewrite.py | 25 |
1 file changed, 17 insertions, 8 deletions
```diff
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index 8fba806..6674d02 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -15,6 +15,9 @@ from typing import Callable
 import numpy as np
 import tensorflow_model_optimization as tfmot
 from keras.api._v2 import keras  # Temporary workaround for now: MLIA-1107
+from tensorflow_model_optimization.python.core.sparsity.keras.pruning_utils import (  # pylint: disable=no-name-in-module
+    is_pruned_m_by_n,
+)
 
 from mlia.core.errors import ConfigurationError
 from mlia.core.reporting import Column
@@ -32,7 +35,6 @@ from mlia.nn.rewrite.library.sparsity import fc_sparsity_rewrite
 from mlia.nn.tensorflow.config import TFLiteModel
 from mlia.utils.registry import Registry
 
-
 logger = logging.getLogger(__name__)
 RewriteCallable = Callable[[Any, Any], keras.Model]
 
@@ -131,7 +133,20 @@ class Sparsity24Rewrite(QuantizeAwareTrainingRewrite):
         return model
 
     def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
-        """Not needed here."""
+        """Check if sparity has produced the correct result."""
+        for layer in model.layers:
+            for weight in layer.weights:
+                if "kernel" in weight.name:
+                    if "kernel_min" in weight.name or "kernel_max" in weight.name:
+                        continue
+                    if not is_pruned_m_by_n(weight, m_by_n=(2, 4)):
+                        logger.warning(
+                            "\nWARNING: Could not find (2,4) sparsity, "
+                            "in layer %s for weight %s \n",
+                            layer.name,
+                            weight.name,
+                        )
+                        return False
         return True
 
 
@@ -251,12 +266,6 @@ class RewritingOptimizer(Optimizer):
             self.optimizer_configuration.optimization_target
         ]
 
-        if self.optimizer_configuration.optimization_target in [
-            "conv2d-clustering",
-            "conv2d-sparsity24",
-        ]:
-            raise NotImplementedError
-
         use_unmodified_model = True
         tflite_model = self.model.model_path
         tfrecord = str(self.optimizer_configuration.dataset)
```