From 3002baa6b1fd226d38bcfabfe3dc15556833be6a Mon Sep 17 00:00:00 2001 From: Nathan Bailey Date: Fri, 17 May 2024 09:05:03 +0100 Subject: fix: Extend docstrings in the rewrite module Rework doctrings in rewrite functions based on recent changes Resolves MLIA-944 Signed-off-by: Nathan Bailey Change-Id: I31a37e17a296f8a16d0db408d48c6de65c05300e --- src/mlia/nn/rewrite/core/rewrite.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) (limited to 'src/mlia/nn/rewrite/core') diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index 6674d02..6d915c6 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -48,7 +48,7 @@ class Rewrite(ABC): self.function = rewrite_fn def __call__(self, input_shape: Any, output_shape: Any) -> keras.Model: - """Perform the rewrite operation using the configured function.""" + """Return an instance of the rewrite model.""" try: return self.function(input_shape, output_shape) except Exception as ex: @@ -60,11 +60,11 @@ class Rewrite(ABC): @abstractmethod def training_callbacks(self) -> list: - """Return default rewrite callbacks.""" + """Return rewrite callbacks.""" @abstractmethod def post_process(self, model: keras.Model) -> keras.Model: - """Return default post-processing rewrite options.""" + """Return post-processing rewrite option.""" @abstractmethod def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool: @@ -72,7 +72,7 @@ class Rewrite(ABC): class GenericRewrite(Rewrite): - """Graph rewrite logic for fully-connected rewrite.""" + """Rewrite class for generic rewrites e.g. fully-connected.""" def quantize(self, model: keras.Model) -> keras.Model: """Return a quantized model if required.""" @@ -83,7 +83,7 @@ class GenericRewrite(Rewrite): return [] def post_process(self, model: keras.Model) -> keras.Model: - """Return default post-processing rewrite options.""" + """Return default post-processing rewrite option.""" return model def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: @@ -101,14 +101,14 @@ class QuantizeAwareTrainingRewrite(Rewrite, ABC): class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): - """Graph rewrite logic for fully-connected-sparsity24 rewrite.""" + """Rewrite class for sparsity rewrite e.g. fully-connected-sparsity24.""" pruning_callback = tfmot.sparsity.keras.UpdatePruningStep strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning) def quantize(self, model: keras.Model) -> keras.Model: - """Skip quantization when using pruning rewrite.""" + """Skip quantization when using sparsity rewrite.""" return model def training_callbacks(self) -> list: @@ -116,7 +116,7 @@ class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): return [self.pruning_callback()] def post_process(self, model: keras.Model) -> keras.Model: - """Pruning-specific post-processing rewrite options.""" + """Pruning-specific post-processing rewrite option.""" return self.strip_pruning_wrapper(model) def preserved_quantize( @@ -151,7 +151,7 @@ class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): class ClusteringRewrite(QuantizeAwareTrainingRewrite): - """Graph clustering rewrite logic to be used by RewritingOptimizer.""" + """Rewrite class for clustering rewrite e.g. fully-connected-clustering.""" _strip_clustering_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering) @@ -170,7 +170,7 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite): if not number_of_clusters: raise ValueError( """ - Expected check_preserved_quantize to have argument number_of_clusters. + Expected check_optimization to have argument number_of_clusters. """ ) @@ -197,7 +197,7 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite): return [] def post_process(self, model: keras.Model) -> keras.Model: - """Return the clustering stripped model.""" + """Clustering-specific post-processing rewrite option.""" return self._strip_clustering_wrapper(model) -- cgit v1.2.1