aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core
diff options
context:
space:
mode:
authorNathan Bailey <nathan.bailey@arm.com>2024-05-17 09:05:03 +0100
committerNathan Bailey <nathan.bailey@arm.com>2024-05-21 16:51:15 +0100
commit3002baa6b1fd226d38bcfabfe3dc15556833be6a (patch)
treea7158e696f1b61cf98cef1de24f056bf9a71c6cd /src/mlia/nn/rewrite/core
parent856111bcaef76c60303bdf2ae7cbf718d93d1df4 (diff)
downloadmlia-main.tar.gz
fix: Extend docstrings in the rewrite moduleHEADmain
Rework doctrings in rewrite functions based on recent changes Resolves MLIA-944 Signed-off-by: Nathan Bailey <nathan.bailey@arm.com> Change-Id: I31a37e17a296f8a16d0db408d48c6de65c05300e
Diffstat (limited to 'src/mlia/nn/rewrite/core')
-rw-r--r--src/mlia/nn/rewrite/core/rewrite.py22
1 files changed, 11 insertions, 11 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index 6674d02..6d915c6 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -48,7 +48,7 @@ class Rewrite(ABC):
self.function = rewrite_fn
def __call__(self, input_shape: Any, output_shape: Any) -> keras.Model:
- """Perform the rewrite operation using the configured function."""
+ """Return an instance of the rewrite model."""
try:
return self.function(input_shape, output_shape)
except Exception as ex:
@@ -60,11 +60,11 @@ class Rewrite(ABC):
@abstractmethod
def training_callbacks(self) -> list:
- """Return default rewrite callbacks."""
+ """Return rewrite callbacks."""
@abstractmethod
def post_process(self, model: keras.Model) -> keras.Model:
- """Return default post-processing rewrite options."""
+ """Return post-processing rewrite option."""
@abstractmethod
def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool:
@@ -72,7 +72,7 @@ class Rewrite(ABC):
class GenericRewrite(Rewrite):
- """Graph rewrite logic for fully-connected rewrite."""
+ """Rewrite class for generic rewrites e.g. fully-connected."""
def quantize(self, model: keras.Model) -> keras.Model:
"""Return a quantized model if required."""
@@ -83,7 +83,7 @@ class GenericRewrite(Rewrite):
return []
def post_process(self, model: keras.Model) -> keras.Model:
- """Return default post-processing rewrite options."""
+ """Return default post-processing rewrite option."""
return model
def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
@@ -101,14 +101,14 @@ class QuantizeAwareTrainingRewrite(Rewrite, ABC):
class Sparsity24Rewrite(QuantizeAwareTrainingRewrite):
- """Graph rewrite logic for fully-connected-sparsity24 rewrite."""
+ """Rewrite class for sparsity rewrite e.g. fully-connected-sparsity24."""
pruning_callback = tfmot.sparsity.keras.UpdatePruningStep
strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning)
def quantize(self, model: keras.Model) -> keras.Model:
- """Skip quantization when using pruning rewrite."""
+ """Skip quantization when using sparsity rewrite."""
return model
def training_callbacks(self) -> list:
@@ -116,7 +116,7 @@ class Sparsity24Rewrite(QuantizeAwareTrainingRewrite):
return [self.pruning_callback()]
def post_process(self, model: keras.Model) -> keras.Model:
- """Pruning-specific post-processing rewrite options."""
+ """Pruning-specific post-processing rewrite option."""
return self.strip_pruning_wrapper(model)
def preserved_quantize(
@@ -151,7 +151,7 @@ class Sparsity24Rewrite(QuantizeAwareTrainingRewrite):
class ClusteringRewrite(QuantizeAwareTrainingRewrite):
- """Graph clustering rewrite logic to be used by RewritingOptimizer."""
+ """Rewrite class for clustering rewrite e.g. fully-connected-clustering."""
_strip_clustering_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering)
@@ -170,7 +170,7 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite):
if not number_of_clusters:
raise ValueError(
"""
- Expected check_preserved_quantize to have argument number_of_clusters.
+ Expected check_optimization to have argument number_of_clusters.
"""
)
@@ -197,7 +197,7 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite):
return []
def post_process(self, model: keras.Model) -> keras.Model:
- """Return the clustering stripped model."""
+ """Clustering-specific post-processing rewrite option."""
return self._strip_clustering_wrapper(model)