diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-05-17 09:05:03 +0100 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-05-21 16:51:15 +0100 |
commit | 3002baa6b1fd226d38bcfabfe3dc15556833be6a (patch) | |
tree | a7158e696f1b61cf98cef1de24f056bf9a71c6cd /src/mlia/nn/rewrite | |
parent | 856111bcaef76c60303bdf2ae7cbf718d93d1df4 (diff) | |
download | mlia-main.tar.gz |
Rework doctrings in rewrite functions based on recent changes
Resolves MLIA-944
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: I31a37e17a296f8a16d0db408d48c6de65c05300e
Diffstat (limited to 'src/mlia/nn/rewrite')
-rw-r--r-- | src/mlia/nn/rewrite/core/rewrite.py | 22 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/library/clustering.py | 6 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/library/fc_layer.py | 4 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/library/sparsity.py | 6 |
4 files changed, 19 insertions, 19 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index 6674d02..6d915c6 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -48,7 +48,7 @@ class Rewrite(ABC): self.function = rewrite_fn def __call__(self, input_shape: Any, output_shape: Any) -> keras.Model: - """Perform the rewrite operation using the configured function.""" + """Return an instance of the rewrite model.""" try: return self.function(input_shape, output_shape) except Exception as ex: @@ -60,11 +60,11 @@ class Rewrite(ABC): @abstractmethod def training_callbacks(self) -> list: - """Return default rewrite callbacks.""" + """Return rewrite callbacks.""" @abstractmethod def post_process(self, model: keras.Model) -> keras.Model: - """Return default post-processing rewrite options.""" + """Return post-processing rewrite option.""" @abstractmethod def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool: @@ -72,7 +72,7 @@ class Rewrite(ABC): class GenericRewrite(Rewrite): - """Graph rewrite logic for fully-connected rewrite.""" + """Rewrite class for generic rewrites e.g. fully-connected.""" def quantize(self, model: keras.Model) -> keras.Model: """Return a quantized model if required.""" @@ -83,7 +83,7 @@ class GenericRewrite(Rewrite): return [] def post_process(self, model: keras.Model) -> keras.Model: - """Return default post-processing rewrite options.""" + """Return default post-processing rewrite option.""" return model def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: @@ -101,14 +101,14 @@ class QuantizeAwareTrainingRewrite(Rewrite, ABC): class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): - """Graph rewrite logic for fully-connected-sparsity24 rewrite.""" + """Rewrite class for sparsity rewrite e.g. fully-connected-sparsity24.""" pruning_callback = tfmot.sparsity.keras.UpdatePruningStep strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning) def quantize(self, model: keras.Model) -> keras.Model: - """Skip quantization when using pruning rewrite.""" + """Skip quantization when using sparsity rewrite.""" return model def training_callbacks(self) -> list: @@ -116,7 +116,7 @@ class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): return [self.pruning_callback()] def post_process(self, model: keras.Model) -> keras.Model: - """Pruning-specific post-processing rewrite options.""" + """Pruning-specific post-processing rewrite option.""" return self.strip_pruning_wrapper(model) def preserved_quantize( @@ -151,7 +151,7 @@ class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): class ClusteringRewrite(QuantizeAwareTrainingRewrite): - """Graph clustering rewrite logic to be used by RewritingOptimizer.""" + """Rewrite class for clustering rewrite e.g. fully-connected-clustering.""" _strip_clustering_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering) @@ -170,7 +170,7 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite): if not number_of_clusters: raise ValueError( """ - Expected check_preserved_quantize to have argument number_of_clusters. + Expected check_optimization to have argument number_of_clusters. """ ) @@ -197,7 +197,7 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite): return [] def post_process(self, model: keras.Model) -> keras.Model: - """Return the clustering stripped model.""" + """Clustering-specific post-processing rewrite option.""" return self._strip_clustering_wrapper(model) diff --git a/src/mlia/nn/rewrite/library/clustering.py b/src/mlia/nn/rewrite/library/clustering.py index 6f06c48..81bfd90 100644 --- a/src/mlia/nn/rewrite/library/clustering.py +++ b/src/mlia/nn/rewrite/library/clustering.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 -"""Example rewrite with one fully connected clustered layer.""" +"""Rewrite functions used to return layers ready for clustering.""" from typing import Any import tensorflow_model_optimization as tfmot @@ -10,7 +10,7 @@ from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters def fc_clustering_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: - """Generate TensorFlow Lite model for clustering rewrite.""" + """Fully connected TensorFlow Lite model ready for clustering.""" rewrite_params = { "number_of_clusters": 4, "cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization.LINEAR, @@ -29,7 +29,7 @@ def fc_clustering_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: def conv2d_clustering_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: - """Generate TensorFlow Lite model for clustering rewrite.""" + """Conv2d TensorFlow Lite model ready for clustering.""" rewrite_params = { "number_of_clusters": 4, "cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization.LINEAR, diff --git a/src/mlia/nn/rewrite/library/fc_layer.py b/src/mlia/nn/rewrite/library/fc_layer.py index cb98cb9..92195d1 100644 --- a/src/mlia/nn/rewrite/library/fc_layer.py +++ b/src/mlia/nn/rewrite/library/fc_layer.py @@ -1,13 +1,13 @@ # SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 -"""Example rewrite with one fully connected layer.""" +"""Rewrite function used to return regular layers.""" from typing import Any from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 def fc_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: - """Generate TensorFlow Lite model for rewrite.""" + """Fully connected TensorFlow Lite model for rewrite.""" model = keras.Sequential( ( keras.layers.InputLayer(input_shape=input_shape), diff --git a/src/mlia/nn/rewrite/library/sparsity.py b/src/mlia/nn/rewrite/library/sparsity.py index 709593a..745fa8b 100644 --- a/src/mlia/nn/rewrite/library/sparsity.py +++ b/src/mlia/nn/rewrite/library/sparsity.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 -"""Example rewrite with one fully connected clustered layer.""" +"""Rewrite functions used to return layers ready for sparse pruning.""" from typing import Any import tensorflow_model_optimization as tfmot @@ -10,7 +10,7 @@ from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters def fc_sparsity_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: - """Generate TensorFlow Lite model for rewrite.""" + """Fully connected TensorFlow Lite model ready for sparse pruning.""" model = tfmot.sparsity.keras.prune_low_magnitude( to_prune=keras.Sequential( [ @@ -26,7 +26,7 @@ def fc_sparsity_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: def conv2d_sparsity_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: - """Generate TensorFlow Lite model for rewrite.""" + """Conv2d TensorFlow Lite model ready for sparse pruning.""" conv2d_parameters = compute_conv2d_parameters( input_shape=input_shape, output_shape=output_shape ) |