diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-05-17 09:05:03 +0100 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-05-21 16:51:15 +0100 |
commit | 3002baa6b1fd226d38bcfabfe3dc15556833be6a (patch) | |
tree | a7158e696f1b61cf98cef1de24f056bf9a71c6cd /src/mlia/nn/rewrite/core | |
parent | 856111bcaef76c60303bdf2ae7cbf718d93d1df4 (diff) | |
download | mlia-main.tar.gz |
Rework doctrings in rewrite functions based on recent changes
Resolves MLIA-944
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: I31a37e17a296f8a16d0db408d48c6de65c05300e
Diffstat (limited to 'src/mlia/nn/rewrite/core')
-rw-r--r-- | src/mlia/nn/rewrite/core/rewrite.py | 22 |
1 files changed, 11 insertions, 11 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index 6674d02..6d915c6 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -48,7 +48,7 @@ class Rewrite(ABC): self.function = rewrite_fn def __call__(self, input_shape: Any, output_shape: Any) -> keras.Model: - """Perform the rewrite operation using the configured function.""" + """Return an instance of the rewrite model.""" try: return self.function(input_shape, output_shape) except Exception as ex: @@ -60,11 +60,11 @@ class Rewrite(ABC): @abstractmethod def training_callbacks(self) -> list: - """Return default rewrite callbacks.""" + """Return rewrite callbacks.""" @abstractmethod def post_process(self, model: keras.Model) -> keras.Model: - """Return default post-processing rewrite options.""" + """Return post-processing rewrite option.""" @abstractmethod def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool: @@ -72,7 +72,7 @@ class Rewrite(ABC): class GenericRewrite(Rewrite): - """Graph rewrite logic for fully-connected rewrite.""" + """Rewrite class for generic rewrites e.g. fully-connected.""" def quantize(self, model: keras.Model) -> keras.Model: """Return a quantized model if required.""" @@ -83,7 +83,7 @@ class GenericRewrite(Rewrite): return [] def post_process(self, model: keras.Model) -> keras.Model: - """Return default post-processing rewrite options.""" + """Return default post-processing rewrite option.""" return model def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: @@ -101,14 +101,14 @@ class QuantizeAwareTrainingRewrite(Rewrite, ABC): class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): - """Graph rewrite logic for fully-connected-sparsity24 rewrite.""" + """Rewrite class for sparsity rewrite e.g. fully-connected-sparsity24.""" pruning_callback = tfmot.sparsity.keras.UpdatePruningStep strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning) def quantize(self, model: keras.Model) -> keras.Model: - """Skip quantization when using pruning rewrite.""" + """Skip quantization when using sparsity rewrite.""" return model def training_callbacks(self) -> list: @@ -116,7 +116,7 @@ class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): return [self.pruning_callback()] def post_process(self, model: keras.Model) -> keras.Model: - """Pruning-specific post-processing rewrite options.""" + """Pruning-specific post-processing rewrite option.""" return self.strip_pruning_wrapper(model) def preserved_quantize( @@ -151,7 +151,7 @@ class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): class ClusteringRewrite(QuantizeAwareTrainingRewrite): - """Graph clustering rewrite logic to be used by RewritingOptimizer.""" + """Rewrite class for clustering rewrite e.g. fully-connected-clustering.""" _strip_clustering_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering) @@ -170,7 +170,7 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite): if not number_of_clusters: raise ValueError( """ - Expected check_preserved_quantize to have argument number_of_clusters. + Expected check_optimization to have argument number_of_clusters. """ ) @@ -197,7 +197,7 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite): return [] def post_process(self, model: keras.Model) -> keras.Model: - """Return the clustering stripped model.""" + """Clustering-specific post-processing rewrite option.""" return self._strip_clustering_wrapper(model) |