aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn
diff options
context:
space:
mode:
authorNathan Bailey <nathan.bailey@arm.com>2024-05-17 09:05:03 +0100
committerNathan Bailey <nathan.bailey@arm.com>2024-05-21 16:51:15 +0100
commit3002baa6b1fd226d38bcfabfe3dc15556833be6a (patch)
treea7158e696f1b61cf98cef1de24f056bf9a71c6cd /src/mlia/nn
parent856111bcaef76c60303bdf2ae7cbf718d93d1df4 (diff)
downloadmlia-main.tar.gz
fix: Extend docstrings in the rewrite moduleHEADmain
Rework doctrings in rewrite functions based on recent changes Resolves MLIA-944 Signed-off-by: Nathan Bailey <nathan.bailey@arm.com> Change-Id: I31a37e17a296f8a16d0db408d48c6de65c05300e
Diffstat (limited to 'src/mlia/nn')
-rw-r--r--src/mlia/nn/rewrite/core/rewrite.py22
-rw-r--r--src/mlia/nn/rewrite/library/clustering.py6
-rw-r--r--src/mlia/nn/rewrite/library/fc_layer.py4
-rw-r--r--src/mlia/nn/rewrite/library/sparsity.py6
4 files changed, 19 insertions, 19 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index 6674d02..6d915c6 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -48,7 +48,7 @@ class Rewrite(ABC):
self.function = rewrite_fn
def __call__(self, input_shape: Any, output_shape: Any) -> keras.Model:
- """Perform the rewrite operation using the configured function."""
+ """Return an instance of the rewrite model."""
try:
return self.function(input_shape, output_shape)
except Exception as ex:
@@ -60,11 +60,11 @@ class Rewrite(ABC):
@abstractmethod
def training_callbacks(self) -> list:
- """Return default rewrite callbacks."""
+ """Return rewrite callbacks."""
@abstractmethod
def post_process(self, model: keras.Model) -> keras.Model:
- """Return default post-processing rewrite options."""
+ """Return post-processing rewrite option."""
@abstractmethod
def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool:
@@ -72,7 +72,7 @@ class Rewrite(ABC):
class GenericRewrite(Rewrite):
- """Graph rewrite logic for fully-connected rewrite."""
+ """Rewrite class for generic rewrites e.g. fully-connected."""
def quantize(self, model: keras.Model) -> keras.Model:
"""Return a quantized model if required."""
@@ -83,7 +83,7 @@ class GenericRewrite(Rewrite):
return []
def post_process(self, model: keras.Model) -> keras.Model:
- """Return default post-processing rewrite options."""
+ """Return default post-processing rewrite option."""
return model
def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
@@ -101,14 +101,14 @@ class QuantizeAwareTrainingRewrite(Rewrite, ABC):
class Sparsity24Rewrite(QuantizeAwareTrainingRewrite):
- """Graph rewrite logic for fully-connected-sparsity24 rewrite."""
+ """Rewrite class for sparsity rewrite e.g. fully-connected-sparsity24."""
pruning_callback = tfmot.sparsity.keras.UpdatePruningStep
strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning)
def quantize(self, model: keras.Model) -> keras.Model:
- """Skip quantization when using pruning rewrite."""
+ """Skip quantization when using sparsity rewrite."""
return model
def training_callbacks(self) -> list:
@@ -116,7 +116,7 @@ class Sparsity24Rewrite(QuantizeAwareTrainingRewrite):
return [self.pruning_callback()]
def post_process(self, model: keras.Model) -> keras.Model:
- """Pruning-specific post-processing rewrite options."""
+ """Pruning-specific post-processing rewrite option."""
return self.strip_pruning_wrapper(model)
def preserved_quantize(
@@ -151,7 +151,7 @@ class Sparsity24Rewrite(QuantizeAwareTrainingRewrite):
class ClusteringRewrite(QuantizeAwareTrainingRewrite):
- """Graph clustering rewrite logic to be used by RewritingOptimizer."""
+ """Rewrite class for clustering rewrite e.g. fully-connected-clustering."""
_strip_clustering_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering)
@@ -170,7 +170,7 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite):
if not number_of_clusters:
raise ValueError(
"""
- Expected check_preserved_quantize to have argument number_of_clusters.
+ Expected check_optimization to have argument number_of_clusters.
"""
)
@@ -197,7 +197,7 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite):
return []
def post_process(self, model: keras.Model) -> keras.Model:
- """Return the clustering stripped model."""
+ """Clustering-specific post-processing rewrite option."""
return self._strip_clustering_wrapper(model)
diff --git a/src/mlia/nn/rewrite/library/clustering.py b/src/mlia/nn/rewrite/library/clustering.py
index 6f06c48..81bfd90 100644
--- a/src/mlia/nn/rewrite/library/clustering.py
+++ b/src/mlia/nn/rewrite/library/clustering.py
@@ -1,6 +1,6 @@
# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
-"""Example rewrite with one fully connected clustered layer."""
+"""Rewrite functions used to return layers ready for clustering."""
from typing import Any
import tensorflow_model_optimization as tfmot
@@ -10,7 +10,7 @@ from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters
def fc_clustering_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
- """Generate TensorFlow Lite model for clustering rewrite."""
+ """Fully connected TensorFlow Lite model ready for clustering."""
rewrite_params = {
"number_of_clusters": 4,
"cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization.LINEAR,
@@ -29,7 +29,7 @@ def fc_clustering_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
def conv2d_clustering_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
- """Generate TensorFlow Lite model for clustering rewrite."""
+ """Conv2d TensorFlow Lite model ready for clustering."""
rewrite_params = {
"number_of_clusters": 4,
"cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization.LINEAR,
diff --git a/src/mlia/nn/rewrite/library/fc_layer.py b/src/mlia/nn/rewrite/library/fc_layer.py
index cb98cb9..92195d1 100644
--- a/src/mlia/nn/rewrite/library/fc_layer.py
+++ b/src/mlia/nn/rewrite/library/fc_layer.py
@@ -1,13 +1,13 @@
# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
-"""Example rewrite with one fully connected layer."""
+"""Rewrite function used to return regular layers."""
from typing import Any
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
def fc_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
- """Generate TensorFlow Lite model for rewrite."""
+ """Fully connected TensorFlow Lite model for rewrite."""
model = keras.Sequential(
(
keras.layers.InputLayer(input_shape=input_shape),
diff --git a/src/mlia/nn/rewrite/library/sparsity.py b/src/mlia/nn/rewrite/library/sparsity.py
index 709593a..745fa8b 100644
--- a/src/mlia/nn/rewrite/library/sparsity.py
+++ b/src/mlia/nn/rewrite/library/sparsity.py
@@ -1,6 +1,6 @@
# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
-"""Example rewrite with one fully connected clustered layer."""
+"""Rewrite functions used to return layers ready for sparse pruning."""
from typing import Any
import tensorflow_model_optimization as tfmot
@@ -10,7 +10,7 @@ from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters
def fc_sparsity_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
- """Generate TensorFlow Lite model for rewrite."""
+ """Fully connected TensorFlow Lite model ready for sparse pruning."""
model = tfmot.sparsity.keras.prune_low_magnitude(
to_prune=keras.Sequential(
[
@@ -26,7 +26,7 @@ def fc_sparsity_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
def conv2d_sparsity_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
- """Generate TensorFlow Lite model for rewrite."""
+ """Conv2d TensorFlow Lite model ready for sparse pruning."""
conv2d_parameters = compute_conv2d_parameters(
input_shape=input_shape, output_shape=output_shape
)