 README.md                                                                                       |  95
 src/mlia/nn/rewrite/core/rewrite.py                                                             | 203
 src/mlia/nn/rewrite/core/train.py                                                               |  86
 src/mlia/nn/rewrite/library/clustering.py                                                       |  78
 src/mlia/nn/rewrite/library/fc_clustering_layer.py                                              |  26
 src/mlia/nn/rewrite/library/fc_layer.py                                                         |   6
 src/mlia/nn/rewrite/library/fc_sparsity24_layer.py                                              |  23
 src/mlia/nn/rewrite/library/helper_functions.py                                                 |  61
 src/mlia/nn/rewrite/library/sparsity.py                                                         | 144
 src/mlia/nn/select.py                                                                           |  24
 src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml                    |  19
 src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml                       |  19
 src/mlia/resources/optimization_profiles/optimization-conv2d-unstructured-pruning.toml          |  20
 src/mlia/resources/optimization_profiles/optimization-custom-augmentation.toml                  |  13
 src/mlia/resources/optimization_profiles/optimization-fully-connected-clustering.toml           |  17
 src/mlia/resources/optimization_profiles/optimization-fully-connected-pruning.toml              |  17
 src/mlia/resources/optimization_profiles/optimization-fully-connected-unstructured-pruning.toml |  18
 src/mlia/resources/optimization_profiles/optimization.toml                                      |   3
 src/mlia/target/common/optimization.py                                                          |  87
 src/mlia/target/config.py                                                                       |  11
 tests/conftest.py                                                                               |  42
 tests/test_cli_commands.py                                                                      |  60
 tests/test_cli_helpers.py                                                                       |   2
 tests/test_common_optimization.py                                                               | 149
 tests/test_nn_rewrite_core_rewrite.py                                                           | 331
 tests/test_nn_rewrite_core_train.py                                                             |  44
 tests/test_nn_rewrite_library_helper_functions.py                                               | 103
 tests/test_nn_select.py                                                                         |  57
 tests/test_target_cortex_a_advisor.py                                                           |  21
 tests/test_target_tosa_advisor.py                                                               |  21
 tests_e2e/optimization_e2e_test.toml                                                            |   4
 31 files changed, 1521 insertions(+), 283 deletions(-)
diff --git a/README.md b/README.md
index 7d08a16..a684342 100644
--- a/README.md
+++ b/README.md
@@ -42,7 +42,9 @@ Information on reporting security issues can be found in
## License
-ML Inference Advisor is licensed under [Apache License 2.0](LICENSES/Apache-2.0.txt).
+ML Inference Advisor is licensed under [Apache License 2.0](LICENSES/Apache-2.0.txt)
+unless otherwise indicated. This project contains software under a range of
+permissive licenses; see [LICENSES](LICENSES/).
## Trademarks and copyrights
@@ -181,6 +183,16 @@ documentation, e.g. in the
candidates from the rewrite library, with or without training using a
small portion of the training data, to achieve local performance gains.
+The following rewrites are supported:
+
+* fully-connected - replaces a subgraph with a fully connected layer
+* fully-connected-sparsity - replaces a subgraph with a pruned 2:4 sparse fully connected layer
+* fully-connected-unstructured-sparsity - replaces a subgraph with an unstructured pruned fully connected layer
+* fully-connected-clustering - replaces a subgraph with a clustered fully connected layer
+* conv2d-sparsity - replaces a subgraph with a pruned 2:4 sparse conv2d layer
+* conv2d-unstructured-sparsity - replaces a subgraph with an unstructured pruned conv2d layer
+* conv2d-clustering - replaces a subgraph with a clustered conv2d layer
+
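+For example, to apply one of these rewrites, pass its name via
+`--rewrite-target` (a sketch; the start and end node names are placeholders
+that depend on your model):
+
+```bash
+mlia optimize ~/models/ds_cnn_large_fp32.tflite \
+    --dataset input.tfrec \
+    --rewrite-target conv2d-sparsity \
+    --rewrite-start <start-node-name> \
+    --rewrite-end <end-node-name>
+```
+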
**Note:** A ***Keras model*** (.h5 or SavedModel) is required as input to
perform pruning and clustering. A ***TensorFlow Lite model*** is required as input
to perform a rewrite.
@@ -209,15 +221,72 @@ mlia optimize ~/models/ds_cnn_large_fp32.tflite \
--rewrite-end MobileNet/fc1/BiasAdd
```
-### optimization Profiles
+### Optimization Profiles
Training parameters for rewrites can be specified.
-There are a number of predefined profiles:
+There are a number of predefined profiles for rewrites shown below:
+
+| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Augmentations |
+| :----------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :-------------: |
+| optimization | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | "gaussian" |
+
+| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Num Clusters | Cluster Centroids Init |
+| :-------------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :----------: | :--------------------------------: |
+| optimization-fully-connected-clustering | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 16 | "CentroidInitialization.LINEAR" |
+
+| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Sparsity M | Sparsity N |
+| :-----------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :--------: | :--------: |
+| optimization-fully-connected-pruning | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 2 | 4 |
+
+| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Initial Sparsity | Final Sparsity | End Step |
+| :-----------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :--------: | :--------: | :--------: |
+| optimization-fully-connected-unstructured-pruning | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 0.25 | 0.5 | 48000 |
+
+| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Num Clusters | Cluster Centroids Init | Activation | Kernel Size |
+| :-------------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :----------: | :--------------------------------: | :--------: | :---------: |
+| optimization-conv2d-clustering | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 16 | "CentroidInitialization.LINEAR" | "relu" | 3x3 |
+
+| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Sparsity M | Sparsity N | Activation | Kernel Size |
+| :-----------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :--------: | :--------: | :--------: | :---------: |
+| optimization-conv2d-pruning | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 2 | 4 | "relu" | 3x3 |
+
+| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Initial Sparsity | Final Sparsity | End Step | Activation | Kernel Size |
+| :-----------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :--------: | :--------: | :--------: | :--------:| :---------: |
+| optimization-conv2d-unstructured-pruning | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 0.25 | 0.5 | 48000 | "relu" | 3x3 |
+
+These are summarized below:
+
+* optimization - Provides training parameters for rewrites
+* optimization-fully-connected-clustering - Provides training parameters for rewrites and clustering-specific parameters for the fully-connected-clustering rewrite
+* optimization-fully-connected-pruning - Provides training parameters for rewrites and pruning-specific parameters for the fully-connected-sparsity rewrite
+* optimization-fully-connected-unstructured-pruning - Provides training parameters for rewrites and pruning-specific parameters for the fully-connected-unstructured-sparsity rewrite
+* optimization-conv2d-clustering - Provides training parameters for rewrites and clustering-specific parameters for the conv2d-clustering rewrite
+* optimization-conv2d-pruning - Provides training parameters for rewrites and pruning-specific parameters for the conv2d-sparsity rewrite
+* optimization-conv2d-unstructured-pruning - Provides training parameters for rewrites and pruning-specific parameters for the conv2d-unstructured-sparsity rewrite
+
+Note: for convolutional rewrites (e.g. optimization-conv2d-pruning), the activation function for the rewrite can be selected in the optimization profile from the following list:
-| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints |
-| :----------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: |
-| optimization | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None |
+* "relu" - Standard ReLU activation function
+* "relu6" - ReLU6 activation function i.e. ReLU activation function capped at 6
+* "none" - No activation function
+
+The user can also specify custom augmentations as part of the training parameters. An example of this can be found in the following optimization profile:
+
+| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Augmentations - gaussian_strength | Augmentations - mixup_strength |
+| :------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :-------------------------------: | :----------------------------: |
+| optimization-custom-augmentation | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 0.1 | 0.1 |
+
+The augmentations consist of two parameters: mixup strength and gaussian strength.
+
+Augmentations can be selected from a number of pre-defined profiles (see the table below) or each individual parameter can be set (see optimization-custom-augmentation above for an example):
+
+| Name | MixUp Strength | Gaussian Strength |
+| :------------------: | :------------: | :---------------: |
+| "none" | None | None |
+| "gaussian" | None | 1.0 |
+| "mixup" | 1.0 | None |
+| "mixout" | 1.6 | None |
+| "mix_gaussian_large" | 2.0 | 1.0 |
+| "mix_gaussian_small" | 1.6 | 0.3 |
```bash
##### An example of using optimization profiles
@@ -228,7 +297,7 @@ mlia optimize ~/models/ds_cnn_large_fp32.tflite \
--dataset input.tfrec \
--rewrite-target fully-connected \
--rewrite-start MobileNet/avg_pool/AvgPool \
- --rewrite-end MobileNet/fc1/BiasAdd_
+ --rewrite-end MobileNet/fc1/BiasAdd
```
#### Custom optimization Profiles
@@ -244,7 +313,17 @@ apply for each optimization.
``` bash
# for custom profiles
-mlia ops --optimization-profile ~/my_custom_optimization_profile.toml
+mlia optimize --optimization-profile ~/my_custom_optimization_profile.toml
+```
+
+When providing rewrite-specific parameters, e.g. for clustering, the rewrite
+name should be specified in the TOML file.
+
+For example, the following provides rewrite-specific parameters for the
+fully-connected-clustering rewrite:
+
+``` toml
+[rewrite.fully-connected-clustering]
+num_clusters = 16
+cluster_centroids_init = "CentroidInitialization.LINEAR"
```
# Target profiles
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index e2c097c..c2ad364 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -8,13 +8,20 @@ import tempfile
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
+from inspect import getfullargspec
from pathlib import Path
+from statistics import fmean
from typing import Any
from typing import Callable
+from typing import Generator
import numpy as np
+import tensorflow as tf
import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
+from tensorflow_model_optimization.python.core.sparsity.keras.pruning_utils import ( # pylint: disable=no-name-in-module
+ is_pruned_m_by_n,
+)
from mlia.core.errors import ConfigurationError
from mlia.core.reporting import Column
@@ -24,19 +31,18 @@ from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
from mlia.nn.rewrite.core.train import train
from mlia.nn.rewrite.core.train import TrainingParameters
-from mlia.nn.rewrite.library.fc_clustering_layer import (
- get_keras_model_clus as fc_clustering_rewrite,
-)
-from mlia.nn.rewrite.library.fc_layer import get_keras_model as fc_rewrite
-from mlia.nn.rewrite.library.fc_sparsity24_layer import (
- get_keras_model as fc_rewrite_sparsity24,
-)
+from mlia.nn.rewrite.library.clustering import conv2d_clustering_rewrite
+from mlia.nn.rewrite.library.clustering import fc_clustering_rewrite
+from mlia.nn.rewrite.library.fc_layer import fc_rewrite
+from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_rewrite
+from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_unstructured_rewrite
+from mlia.nn.rewrite.library.sparsity import fc_sparsity_rewrite
+from mlia.nn.rewrite.library.sparsity import fc_sparsity_unstructured_rewrite
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.utils.registry import Registry
-
logger = logging.getLogger(__name__)
-RewriteCallable = Callable[[Any, Any], keras.Model]
+RewriteCallable = Callable[..., keras.Model]
class Rewrite(ABC):
@@ -47,10 +53,23 @@ class Rewrite(ABC):
self.name = name
self.function = rewrite_fn
- def __call__(self, input_shape: Any, output_shape: Any) -> keras.Model:
+ def __call__(
+ self, input_shape: Any, output_shape: Any, **kwargs: Any
+ ) -> keras.Model:
"""Perform the rewrite operation using the configured function."""
try:
- return self.function(input_shape, output_shape)
+ return self.function(input_shape, output_shape, **kwargs)
+ except TypeError as ex:
+ expected_args = self.return_rewrite_func_args()
+ if "input_shape" in expected_args:
+ expected_args.remove("input_shape")
+ if "output_shape" in expected_args:
+ expected_args.remove("output_shape")
+            raise KeyError(
+                f"Found unexpected parameters for rewrite. Expected a (sub)set "
+                f"of {expected_args}, found unexpected parameter(s): "
+                f"{list(set(kwargs.keys()) - set(expected_args))}"
+            ) from ex
except Exception as ex:
raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex
@@ -58,21 +77,25 @@ class Rewrite(ABC):
"""Return a quantized model if required."""
return model
+ def return_rewrite_func_args(self) -> list[str]:
+ """Return the expected args of the rewrite function."""
+ return getfullargspec(self.function).args
+
@abstractmethod
def training_callbacks(self) -> list:
- """Return default rewrite callbacks."""
+ """Return rewrite callbacks."""
@abstractmethod
def post_process(self, model: keras.Model) -> keras.Model:
- """Return default post-processing rewrite options."""
+ """Return post-processing rewrite option."""
@abstractmethod
- def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool:
+ def check_optimization(self, model: keras.Model) -> bool:
"""Check if the optimization has produced the correct result."""
class GenericRewrite(Rewrite):
- """Graph rewrite logic for fully-connected rewrite."""
+ """Rewrite class for generic rewrites e.g. fully-connected."""
def quantize(self, model: keras.Model) -> keras.Model:
"""Return a quantized model if required."""
@@ -83,10 +106,10 @@ class GenericRewrite(Rewrite):
return []
def post_process(self, model: keras.Model) -> keras.Model:
- """Return default post-processing rewrite options."""
+ """Return default post-processing rewrite option."""
return model
- def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+ def check_optimization(self, model: keras.Model) -> bool:
"""Not needed here."""
return True
@@ -99,16 +122,27 @@ class QuantizeAwareTrainingRewrite(Rewrite, ABC):
"""Apply optimization-aware quantization to a given model."""
return model
+ def check_optimization_generator(
+ self, model: keras.Model
+ ) -> Generator[tuple[tf.Tensor, keras.layers.Layer], None, None]:
+ """Loop for check_optimization function."""
+ for layer in model.layers:
+ for weight in layer.weights:
+ if "kernel" in weight.name:
+ if "kernel_min" in weight.name or "kernel_max" in weight.name:
+ continue
+ yield weight, layer
+
-class Sparsity24Rewrite(QuantizeAwareTrainingRewrite):
- """Graph rewrite logic for fully-connected-sparsity24 rewrite."""
+class SparsityRewrite(QuantizeAwareTrainingRewrite):
+ """Base rewrite class for sparsity rewrites."""
pruning_callback = tfmot.sparsity.keras.UpdatePruningStep
strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning)
def quantize(self, model: keras.Model) -> keras.Model:
- """Skip quantization when using pruning rewrite."""
+ """Skip quantization when using sparsity rewrite."""
return model
def training_callbacks(self) -> list:
@@ -116,7 +150,7 @@ class Sparsity24Rewrite(QuantizeAwareTrainingRewrite):
return [self.pruning_callback()]
def post_process(self, model: keras.Model) -> keras.Model:
- """Pruning-specific post-processing rewrite options."""
+ """Pruning-specific post-processing rewrite option."""
return self.strip_pruning_wrapper(model)
def preserved_quantize(
@@ -129,16 +163,78 @@ class Sparsity24Rewrite(QuantizeAwareTrainingRewrite):
model,
tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(),
)
-
return model
- def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+ def check_optimization(self, model: keras.Model) -> bool:
"""Not needed here."""
return True
+class UnstructuredSparsityRewrite(SparsityRewrite):
+ """
+ Rewrite class for unstructured sparsity rewrite.
+
+ e.g. fully-connected-unstructured-sparsity.
+ """
+
+ def check_optimization(
+ self, model: keras.Model, final_sparsity: float = 0.5, **_: Any
+ ) -> bool:
+ """Not needed here."""
+ found_sparsity_list = []
+ num_dec_places = str(final_sparsity)[::-1].find(".")
+ for weight, _ in self.check_optimization_generator(model=model):
+ weight_np = weight.numpy()
+            found_sparsity_list.append(
+                # Sparsity is the fraction of zero-valued weights.
+                round(1 - np.count_nonzero(weight_np) / weight_np.size, num_dec_places)
+            )
+ if len(found_sparsity_list) == 0:
+ logger.warning(
+ "\nWARNING: Could not find any layers "
+ "in rewrite that could be sparsely pruned"
+ )
+ return False
+ found_sparsity = fmean(found_sparsity_list)
+ if found_sparsity != final_sparsity:
+ logger.warning(
+ "\nWARNING: Found total sparsity of "
+ "rewrite model: %.2f "
+ "expected total sparsity to be: "
+ "%.2f\n",
+ found_sparsity,
+ final_sparsity,
+ )
+ return False
+ return True
+
+
+class StructuredSparsityRewrite(SparsityRewrite):
+ """Rewrite class for structured sparsity rewrite e.g. fully-connected-sparsity."""
+
+ def check_optimization(
+ self,
+ model: keras.Model,
+ sparsity_m: int = 2,
+ sparsity_n: int = 4,
+ **_: Any,
+ ) -> bool:
+ """Check if sparity has produced the correct result."""
+ for weight, layer in self.check_optimization_generator(model=model):
+ if not is_pruned_m_by_n(weight, m_by_n=(sparsity_m, sparsity_n)):
+ logger.warning(
+ "\nWARNING: Could not find (%d, %d) sparsity, "
+ "in layer %s for weight %s \n",
+ sparsity_m,
+ sparsity_n,
+ layer.name,
+ weight.name,
+ )
+ return False
+ return True
+
+
class ClusteringRewrite(QuantizeAwareTrainingRewrite):
- """Graph clustering rewrite logic to be used by RewritingOptimizer."""
+ """Rewrite class for clustering rewrite e.g. fully-connected-clustering."""
_strip_clustering_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering)
@@ -151,32 +247,22 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite):
)
return cqat_model
- def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+ def check_optimization(
+ self, model: keras.Model, num_clusters: int = 2, **_: Any
+ ) -> bool:
"""Check if clustering has produced the correct result."""
- number_of_clusters = kwargs.get("number_of_clusters")
- if not number_of_clusters:
- raise ValueError(
- """
- Expected check_preserved_quantize to have argument number_of_clusters.
- """
- )
-
- for layer in model.layers:
- for weight in layer.weights:
- if "kernel" in weight.name:
- if "kernel_min" in weight.name or "kernel_max" in weight.name:
- continue
- number_of_found_clusters = len(np.unique(weight))
- if number_of_found_clusters != number_of_clusters:
- logger.warning(
- "\nWARNING: Expected %d cluster(s), found %d "
- "cluster(s) in layer %s for weight %s \n",
- number_of_clusters,
- number_of_found_clusters,
- layer.name,
- weight.name,
- )
- return False
+ for weight, layer in self.check_optimization_generator(model=model):
+ number_of_found_clusters = len(np.unique(weight))
+ if number_of_found_clusters != num_clusters:
+ logger.warning(
+ "\nWARNING: Expected %d cluster(s), found %d "
+ "cluster(s) in layer %s for weight %s \n",
+ num_clusters,
+ number_of_found_clusters,
+ layer.name,
+ weight.name,
+ )
+ return False
return True
def training_callbacks(self) -> list:
@@ -184,7 +270,7 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite):
return []
def post_process(self, model: keras.Model) -> keras.Model:
- """Return the clustering stripped model."""
+ """Clustering-specific post-processing rewrite option."""
return self._strip_clustering_wrapper(model)
@@ -215,6 +301,7 @@ class RewriteConfiguration(OptimizerConfiguration):
layers_to_optimize: list[str] | None = None
dataset: Path | None = None
train_params: TrainingParameters = TrainingParameters()
+ rewrite_specific_params: dict | None = None
def __str__(self) -> str:
"""Return string representation of the configuration."""
@@ -227,8 +314,17 @@ class RewritingOptimizer(Optimizer):
registry = RewriteRegistry(
[
GenericRewrite("fully-connected", fc_rewrite),
- Sparsity24Rewrite("fully-connected-sparsity24", fc_rewrite_sparsity24),
+ StructuredSparsityRewrite("fully-connected-sparsity", fc_sparsity_rewrite),
ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite),
+ ClusteringRewrite("conv2d-clustering", conv2d_clustering_rewrite),
+ StructuredSparsityRewrite("conv2d-sparsity", conv2d_sparsity_rewrite),
+ UnstructuredSparsityRewrite(
+ "conv2d-unstructured-sparsity", conv2d_sparsity_unstructured_rewrite
+ ),
+ UnstructuredSparsityRewrite(
+ "fully-connected-unstructured-sparsity",
+ fc_sparsity_unstructured_rewrite,
+ ),
]
)
@@ -250,7 +346,6 @@ class RewritingOptimizer(Optimizer):
rewrite = RewritingOptimizer.registry.items[
self.optimizer_configuration.optimization_target
]
-
use_unmodified_model = True
tflite_model = self.model.model_path
tfrecord = str(self.optimizer_configuration.dataset)
@@ -272,6 +367,10 @@ class RewritingOptimizer(Optimizer):
input_tensors=[self.optimizer_configuration.layers_to_optimize[0]],
output_tensors=[self.optimizer_configuration.layers_to_optimize[1]],
train_params=self.optimizer_configuration.train_params,
+ rewrite_specific_params=self.optimizer_configuration.rewrite_specific_params, # pylint: disable=line-too-long
+ detect_activation_function=(
+ "activation" in rewrite.return_rewrite_func_args()
+ ),
)
if orig_vs_repl_stats:
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 4204978..570968a 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -34,13 +34,13 @@ from mlia.nn.rewrite.core.graph_edit.record import record_model
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
+from mlia.nn.rewrite.library.helper_functions import ACTIVATION_FUNCTION_LIST
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
from mlia.nn.tensorflow.tflite_graph import load_fb
from mlia.nn.tensorflow.tflite_graph import save_fb
from mlia.utils.logging import log_action
-
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
logger = logging.getLogger(__name__)
@@ -83,6 +83,8 @@ def train( # pylint: disable=too-many-arguments
input_tensors: list,
output_tensors: list,
train_params: TrainingParameters = TrainingParameters(),
+ rewrite_specific_params: dict | None = None,
+ detect_activation_function: bool = False,
) -> Any:
"""Extract and train a model, and return the results."""
if unmodified_model:
@@ -122,6 +124,8 @@ def train( # pylint: disable=too-many-arguments
rewrite=rewrite,
is_qat=is_qat,
train_params=train_params,
+ rewrite_specific_params=rewrite_specific_params,
+ detect_activation_function=detect_activation_function,
)
for i, filename in enumerate(tflite_filenames):
@@ -349,6 +353,41 @@ def set_up_data_pipeline(
return dataset, steps_per_epoch
+def detect_activation_from_rewrite_function(model_path: str) -> str:
+ """Given a rewrite model, choose the most common activation function."""
+ interpreter = tf.lite.Interpreter(model_path=model_path)
+ interpreter.allocate_tensors()
+ act_func_match_list = []
+ for tensor_details in interpreter.get_tensor_details():
+ for act_func in ACTIVATION_FUNCTION_LIST:
+ tensor_name = tensor_details["name"].lower()
+ if act_func in tensor_name:
+ act_func_idx = tensor_name.index(act_func)
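+                # TFLite fuses op names into tensor names separated by ";",
+                # so only accept a match that ends the tensor name or a
+                # ";"-delimited token (e.g. ".../Relu;.../BiasAdd").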
+ if (
+ len(tensor_name) == act_func_idx + len(act_func)
+ or tensor_name[act_func_idx + len(act_func)] == ";"
+ ):
+ act_func_match_list.append(
+ tensor_name[
+ act_func_idx : act_func_idx + len(act_func) # noqa: E203
+ ]
+ )
+ act_func_match = "relu"
+ if len(act_func_match_list) == 0:
+ logger.info(
+ "No activation function specified, setting activation function to ReLU"
+ )
+ else:
+        act_func_match = max(set(act_func_match_list), key=act_func_match_list.count)
+ logger.info(
+ "No activation function specified, "
+ "setting activation function to most "
+ "common activation detected in rewrite graph: %s",
+ act_func_match,
+ )
+ return act_func_match
+
+
def train_in_dir(
train_dir: str,
baseline_dir: Any,
@@ -356,6 +395,8 @@ def train_in_dir(
rewrite: Callable,
is_qat: bool,
train_params: TrainingParameters = TrainingParameters(),
+ rewrite_specific_params: dict | None = None,
+ detect_activation_function: bool = False,
) -> list[str]:
"""Train a replacement for replace.tflite using the input.tfrec \
and output.tfrec in train_dir.
@@ -372,6 +413,18 @@ def train_in_dir(
)
replace = TFLiteModel(ExtractPaths.tflite.replace(train_dir))
+    if detect_activation_function and (
+        rewrite_specific_params is None
+        or "activation" not in rewrite_specific_params
+    ):
+ detected_activation_function = detect_activation_from_rewrite_function(
+ ExtractPaths.tflite.replace(train_dir).as_posix()
+ )
+ if rewrite_specific_params:
+ rewrite_specific_params["activation"] = detected_activation_function
+ else:
+ rewrite_specific_params = {"activation": detected_activation_function}
+
input_name, output_name = _get_io_tensors(teacher)
model_is_quantized = replace.is_tensor_quantized(name=input_name)
@@ -396,7 +449,13 @@ def train_in_dir(
loss_fn = keras.losses.MeanSquaredError()
model = create_model(
- rewrite, input_shape, output_shape, optimizer, loss_fn, model_is_quantized
+ rewrite,
+ input_shape,
+ output_shape,
+ optimizer,
+ loss_fn,
+ model_is_quantized,
+ rewrite_specific_params=rewrite_specific_params,
)
logger.info(model.summary())
@@ -462,11 +521,9 @@ def train_in_dir(
steps_per_epoch,
post_process=True,
)
-
- # Placeholder for now, will be parametrized later (MLIA-1114)
- # rewrite.check_optimization( # type: ignore[attr-defined]
- # model, number_of_clusters=32
- # )
+ rewrite.check_optimization( # type: ignore[attr-defined]
+        model, **(rewrite_specific_params or {})
+ )
if model_is_quantized and is_qat:
model = rewrite.preserved_quantize(model) # type: ignore[attr-defined]
checkpoints = (
@@ -501,11 +558,10 @@ def train_in_dir(
loss_fn,
steps_per_epoch,
)
- # Placeholder for now, will be parametrized later (MLIA-1114)
- # rewrite.check_optimization( # type: ignore[attr-defined]
- # model, number_of_clusters=32
- # )
+ rewrite.check_optimization( # type: ignore[attr-defined]
+        model, **(rewrite_specific_params or {})
+ )
teacher.close()
return output_filenames
@@ -528,9 +584,13 @@ def create_model( # pylint: disable=too-many-arguments
loss_fn: Callable,
model_is_quantized: bool,
model_to_load_from: keras.Model | None = None,
+ rewrite_specific_params: dict | None = None,
) -> keras.Model:
"""Create a model, optionally from another."""
- model = rewrite(input_shape, output_shape)
+ if rewrite_specific_params:
+ model = rewrite(input_shape, output_shape, **rewrite_specific_params)
+ else:
+ model = rewrite(input_shape, output_shape)
if model_is_quantized:
model = rewrite.quantize(model) # type: ignore[attr-defined]
model = model_compile(model, optimizer=optimizer, loss_fn=loss_fn)
@@ -558,6 +618,7 @@ def model_fit( # pylint: disable=too-many-arguments
loss_fn: Callable,
steps_per_epoch: int,
post_process: bool = False,
+ rewrite_specific_params: dict | None = None,
) -> keras.Model:
"""Train a tflite model."""
steps_so_far = 0
@@ -597,6 +658,7 @@ def model_fit( # pylint: disable=too-many-arguments
loss_fn,
model_is_quantized,
model_to_load_from=model,
+ rewrite_specific_params=rewrite_specific_params,
)
else:
model_to_save = model
diff --git a/src/mlia/nn/rewrite/library/clustering.py b/src/mlia/nn/rewrite/library/clustering.py
new file mode 100644
index 0000000..48914dc
--- /dev/null
+++ b/src/mlia/nn/rewrite/library/clustering.py
@@ -0,0 +1,78 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Rewrite functions used to return layers ready for clustering."""
+from typing import Any
+
+import tensorflow_model_optimization as tfmot
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
+
+from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters
+from mlia.nn.rewrite.library.helper_functions import get_activation_function
+
+
+def fc_clustering_rewrite(
+ input_shape: Any,
+ output_shape: Any,
+ num_clusters: int = 2,
+ cluster_centroids_init: tfmot.clustering.keras.CentroidInitialization = tfmot.clustering.keras.CentroidInitialization( # pylint: disable=line-too-long
+ "CentroidInitialization.LINEAR"
+ ),
+) -> keras.Model:
+ """Fully connected TensorFlow Lite model ready for clustering."""
+ rewrite_params = {
+ "number_of_clusters": num_clusters,
+ "cluster_centroids_init": cluster_centroids_init,
+ }
+ model = tfmot.clustering.keras.cluster_weights(
+ to_cluster=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Flatten(),
+ keras.layers.Dense(units=output_shape),
+ ]
+ ),
+ **rewrite_params
+ )
+ return model
+
+
+def conv2d_clustering_rewrite( # pylint: disable=dangerous-default-value
+ input_shape: Any,
+ output_shape: Any,
+ num_clusters: int = 2,
+ cluster_centroids_init: tfmot.clustering.keras.CentroidInitialization = tfmot.clustering.keras.CentroidInitialization( # pylint: disable=line-too-long
+ "CentroidInitialization.LINEAR"
+ ),
+ activation: str = "relu",
+ kernel_size: list[int] = [3, 3],
+) -> keras.Model:
+ """Conv2d TensorFlow Lite model ready for clustering."""
+ rewrite_params = {
+ "number_of_clusters": num_clusters,
+ "cluster_centroids_init": cluster_centroids_init,
+ }
+ conv2d_parameters = compute_conv2d_parameters(
+ input_shape=input_shape,
+ output_shape=output_shape,
+ kernel_size_input=kernel_size,
+ )
+ activation_function, activation_function_extra_args = get_activation_function(
+ activation
+ )
+ activation_func_found = (
+ [activation_function(**activation_function_extra_args)]
+ if activation_function
+ else []
+ )
+ model = tfmot.clustering.keras.cluster_weights(
+ to_cluster=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Conv2D(**conv2d_parameters),
+ keras.layers.BatchNormalization(),
+ *activation_func_found,
+ ]
+ ),
+ **rewrite_params
+ )
+ return model
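
A minimal usage sketch for the clustering rewrites above (not part of the diff;
the shapes are arbitrary): after stripping the clustering wrappers, each kernel
should contain at most `num_clusters` unique values, which is what
`ClusteringRewrite.check_optimization` verifies.

```python
import numpy as np
import tensorflow_model_optimization as tfmot

from mlia.nn.rewrite.library.clustering import fc_clustering_rewrite

model = fc_clustering_rewrite(input_shape=(4, 4, 3), output_shape=10, num_clusters=4)
stripped = tfmot.clustering.keras.strip_clustering(model)
for layer in stripped.layers:
    for weight in layer.weights:
        if "kernel" in weight.name:
            # Clustered kernels hold at most num_clusters distinct values.
            assert len(np.unique(weight.numpy())) <= 4
```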
diff --git a/src/mlia/nn/rewrite/library/fc_clustering_layer.py b/src/mlia/nn/rewrite/library/fc_clustering_layer.py
deleted file mode 100644
index 7cc383e..0000000
--- a/src/mlia/nn/rewrite/library/fc_clustering_layer.py
+++ /dev/null
@@ -1,26 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Example rewrite with one fully connected clustered layer."""
-from typing import Any
-
-import tensorflow_model_optimization as tfmot
-from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
-
-
-def get_keras_model_clus(input_shape: Any, output_shape: Any) -> keras.Model:
- """Generate TensorFlow Lite model for clustering rewrite."""
- rewrite_params = {
- "number_of_clusters": 32,
- "cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization.LINEAR,
- }
- model = tfmot.clustering.keras.cluster_weights(
- to_cluster=keras.Sequential(
- [
- keras.layers.InputLayer(input_shape=input_shape),
- keras.layers.Flatten(),
- keras.layers.Dense(units=output_shape),
- ]
- ),
- **rewrite_params
- )
- return model
diff --git a/src/mlia/nn/rewrite/library/fc_layer.py b/src/mlia/nn/rewrite/library/fc_layer.py
index 041ce85..92195d1 100644
--- a/src/mlia/nn/rewrite/library/fc_layer.py
+++ b/src/mlia/nn/rewrite/library/fc_layer.py
@@ -1,13 +1,13 @@
# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
-"""Example rewrite with one fully connected layer."""
+"""Rewrite function used to return regular layers."""
from typing import Any
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
-def get_keras_model(input_shape: Any, output_shape: Any) -> keras.Model:
- """Generate TensorFlow Lite model for rewrite."""
+def fc_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
+ """Fully connected TensorFlow Lite model for rewrite."""
model = keras.Sequential(
(
keras.layers.InputLayer(input_shape=input_shape),
diff --git a/src/mlia/nn/rewrite/library/fc_sparsity24_layer.py b/src/mlia/nn/rewrite/library/fc_sparsity24_layer.py
deleted file mode 100644
index 531b34a..0000000
--- a/src/mlia/nn/rewrite/library/fc_sparsity24_layer.py
+++ /dev/null
@@ -1,23 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Example rewrite with one fully connected 2:4 sparsity layer."""
-from typing import Any
-
-import tensorflow_model_optimization as tfmot
-from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
-
-
-def get_keras_model(input_shape: Any, output_shape: Any) -> keras.Model:
- """Generate TensorFlow Lite model for rewrite."""
- model = tfmot.sparsity.keras.prune_low_magnitude(
- to_prune=keras.Sequential(
- [
- keras.layers.InputLayer(input_shape=input_shape),
- keras.layers.Reshape([-1]),
- keras.layers.Dense(output_shape),
- ]
- ),
- sparsity_m_by_n=(2, 4),
- )
-
- return model
diff --git a/src/mlia/nn/rewrite/library/helper_functions.py b/src/mlia/nn/rewrite/library/helper_functions.py
new file mode 100644
index 0000000..1237c17
--- /dev/null
+++ b/src/mlia/nn/rewrite/library/helper_functions.py
@@ -0,0 +1,61 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Helper functions for the rewrite library."""
+import math
+from typing import Any
+
+import numpy as np
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
+
+ACTIVATION_FUNCTION_PRESETS = {
+ "relu": {"layer_func": keras.layers.ReLU, "extra_args": {}},
+ "relu6": {"layer_func": keras.layers.ReLU, "extra_args": {"max_value": 6}},
+ "none": {"layer_func": None, "extra_args": {}},
+}
+ACTIVATION_FUNCTION_LIST = list(ACTIVATION_FUNCTION_PRESETS.keys())
+
+
+def get_activation_function(
+ activation: str = "relu",
+) -> tuple[type, dict]:
+ """Get the activation function from a key."""
+ if activation not in ACTIVATION_FUNCTION_LIST:
+ raise KeyError(
+ "Expected activation function to be "
+ f"in {ACTIVATION_FUNCTION_LIST}, found {activation}"
+ )
+ activation_function = ACTIVATION_FUNCTION_PRESETS[activation]["layer_func"]
+ activation_function_extra_args = ACTIVATION_FUNCTION_PRESETS[activation][
+ "extra_args"
+ ]
+ return activation_function, activation_function_extra_args
+
+
+def compute_conv2d_parameters( # pylint: disable=dangerous-default-value
+ input_shape: np.ndarray,
+ output_shape: np.ndarray,
+ kernel_size_input: list[int] = [3, 3],
+) -> dict[str, Any]:
+ """Compute needed kernel size and strides for a given input and output_shape."""
+ input_shape = input_shape.tolist()
+ output_shape = output_shape.tolist()
+ assert len(kernel_size_input) == 2, "Kernel size should have 2 entries"
+ assert len(input_shape) == 3
+ assert len(output_shape) == 3
+ kernel_size = tuple(kernel_size_input)
+    num_filters = output_shape[-1]  # filters must match the output channel count
+ padding = "valid"
+ stride_h = round(input_shape[0] / output_shape[0])
+ check_output_size_h = math.floor((input_shape[0] - kernel_size[0]) / stride_h) + 1
+ stride_w = round(input_shape[1] / output_shape[1])
+ check_output_size_w = math.floor((input_shape[1] - kernel_size[1]) / stride_w) + 1
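+    # Fall back to "same" padding when "valid" cannot reproduce the requested
+    # output size with the rounded strides.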
+ if check_output_size_h != output_shape[0] or check_output_size_w != output_shape[1]:
+ padding = "same"
+ return {
+ "filters": num_filters,
+ "kernel_size": kernel_size,
+ "padding": padding,
+ "strides": (stride_h, stride_w),
+ }
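
A worked example of the stride/padding arithmetic above (a sketch, not part of
the diff): a 32x32x3 input mapped to a 16x16x8 output yields stride 2; "valid"
padding would give floor((32 - 3) / 2) + 1 = 15 rows, not 16, so the helper
falls back to "same".

```python
import numpy as np

from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters

params = compute_conv2d_parameters(
    input_shape=np.array([32, 32, 3]),
    output_shape=np.array([16, 16, 8]),
    kernel_size_input=[3, 3],
)
# {'filters': 8, 'kernel_size': (3, 3), 'padding': 'same', 'strides': (2, 2)}
print(params)
```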
diff --git a/src/mlia/nn/rewrite/library/sparsity.py b/src/mlia/nn/rewrite/library/sparsity.py
new file mode 100644
index 0000000..1e53254
--- /dev/null
+++ b/src/mlia/nn/rewrite/library/sparsity.py
@@ -0,0 +1,144 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Rewrite functions used to return layers ready for sparse pruning."""
+from __future__ import annotations
+
+from typing import Any
+
+import tensorflow_model_optimization as tfmot
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
+
+from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters
+from mlia.nn.rewrite.library.helper_functions import get_activation_function
+
+
+def fc_sparsity_unstructured_rewrite(
+ input_shape: Any,
+ output_shape: Any,
+ initial_sparsity: float = 0.5,
+ final_sparsity: float = 0.5,
+ begin_step: int = 0,
+ end_step: int = 48000,
+) -> keras.Model:
+ """Fully connected TensorFlow Lite model ready for unstructured sparse pruning."""
+ model = tfmot.sparsity.keras.prune_low_magnitude(
+ to_prune=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Reshape([-1]),
+ keras.layers.Dense(output_shape),
+ ]
+ ),
+ pruning_schedule=tfmot.sparsity.keras.PolynomialDecay(
+ initial_sparsity=initial_sparsity,
+ final_sparsity=final_sparsity,
+ begin_step=begin_step,
+ end_step=end_step,
+ ),
+ )
+
+ return model
+
+
+def conv2d_sparsity_unstructured_rewrite( # pylint: disable=dangerous-default-value
+ input_shape: Any,
+ output_shape: Any,
+ initial_sparsity: float = 0.5,
+ final_sparsity: float = 0.5,
+ begin_step: int = 0,
+ end_step: int = 48000,
+ activation: str = "relu",
+ kernel_size: list[int] = [3, 3],
+) -> keras.Model:
+ """Conv2d TensorFlow Lite model ready for unstructured sparse pruning."""
+ conv2d_parameters = compute_conv2d_parameters(
+ input_shape=input_shape,
+ output_shape=output_shape,
+ kernel_size_input=kernel_size,
+ )
+ activation_function, activation_function_extra_args = get_activation_function(
+ activation
+ )
+ activation_func_found = (
+ [activation_function(**activation_function_extra_args)]
+ if activation_function
+ else []
+ )
+ model = tfmot.sparsity.keras.prune_low_magnitude(
+ to_prune=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Conv2D(**conv2d_parameters),
+ keras.layers.BatchNormalization(),
+ *activation_func_found,
+ ]
+ ),
+ pruning_schedule=tfmot.sparsity.keras.PolynomialDecay(
+ initial_sparsity=initial_sparsity,
+ final_sparsity=final_sparsity,
+ begin_step=begin_step,
+ end_step=end_step,
+ ),
+ )
+
+ return model
+
+
+def fc_sparsity_rewrite(
+ input_shape: Any, output_shape: Any, sparsity_m: int = 2, sparsity_n: int = 4
+) -> keras.Model:
+ """Fully connected TensorFlow Lite model ready for sparse pruning."""
+ model = tfmot.sparsity.keras.prune_low_magnitude(
+ to_prune=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Reshape([-1]),
+ keras.layers.Dense(output_shape),
+ ]
+ ),
+ sparsity_m_by_n=(
+ sparsity_m,
+ sparsity_n,
+ ),
+ )
+
+ return model
+
+
+def conv2d_sparsity_rewrite( # pylint: disable=dangerous-default-value
+ input_shape: Any,
+ output_shape: Any,
+ sparsity_m: int = 2,
+ sparsity_n: int = 4,
+ activation: str = "relu",
+ kernel_size: list[int] = [3, 3],
+) -> keras.Model:
+ """Conv2d TensorFlow Lite model ready for sparse pruning."""
+ conv2d_parameters = compute_conv2d_parameters(
+ input_shape=input_shape,
+ output_shape=output_shape,
+ kernel_size_input=kernel_size,
+ )
+ activation_function, activation_function_extra_args = get_activation_function(
+ activation
+ )
+ activation_func_found = (
+ [activation_function(**activation_function_extra_args)]
+ if activation_function
+ else []
+ )
+ model = tfmot.sparsity.keras.prune_low_magnitude(
+ to_prune=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Conv2D(**conv2d_parameters),
+ keras.layers.BatchNormalization(),
+ *activation_func_found,
+ ]
+ ),
+ sparsity_m_by_n=(
+ sparsity_m,
+ sparsity_n,
+ ),
+ )
+ return model
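
A construction sketch for the new sparsity rewrites (not part of the diff): the
returned model is wrapped in `prune_low_magnitude`, and the (2, 4) pattern is
only enforced once training runs with the `UpdatePruningStep` callback that
`SparsityRewrite.training_callbacks` supplies.

```python
from mlia.nn.rewrite.library.sparsity import fc_sparsity_rewrite

# Reshape + Dense wrapped for 2:4 structured pruning during training.
model = fc_sparsity_rewrite(input_shape=(1, 1, 64), output_shape=10, sparsity_m=2, sparsity_n=4)
model.summary()
```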
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py
index b61e713..d5470d1 100644
--- a/src/mlia/nn/select.py
+++ b/src/mlia/nn/select.py
@@ -17,7 +17,7 @@ from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
-from mlia.nn.rewrite.core.rewrite import TrainingParameters
+from mlia.nn.rewrite.core.train import TrainingParameters
from mlia.nn.tensorflow.config import KerasModel
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.nn.tensorflow.optimizations.clustering import Clusterer
@@ -109,7 +109,7 @@ class MultiStageOptimizer(Optimizer):
def apply_optimization(self) -> None:
"""Apply optimization to the model."""
for config in self.optimizations:
- optimizer = get_optimizer(self.model, config)
+ optimizer = get_optimizer(self.model, config, {})
optimizer.apply_optimization()
self.model = optimizer.get_model()
@@ -117,7 +117,7 @@ class MultiStageOptimizer(Optimizer):
def get_optimizer(
model: keras.Model | KerasModel | TFLiteModel,
config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings],
- training_parameters: dict | None = None,
+ rewrite_parameters: dict,
) -> Optimizer:
"""Get optimizer for provided configuration."""
if isinstance(model, KerasModel):
@@ -137,12 +137,12 @@ def get_optimizer(
if isinstance(config, OptimizationSettings):
return _get_optimizer(
- model, cast(OptimizationSettings, config), training_parameters
+ model, cast(OptimizationSettings, config), rewrite_parameters
)
if is_list_of(config, OptimizationSettings):
return _get_optimizer(
- model, cast(List[OptimizationSettings], config), training_parameters
+ model, cast(List[OptimizationSettings], config), rewrite_parameters
)
raise ConfigurationError(f"Unknown optimization configuration {config}")
@@ -151,7 +151,7 @@ def get_optimizer(
def _get_optimizer(
model: keras.Model | Path,
optimization_settings: OptimizationSettings | list[OptimizationSettings],
- training_parameters: dict | None = None,
+ rewrite_parameters: dict,
) -> Optimizer:
if isinstance(optimization_settings, OptimizationSettings):
optimization_settings = [optimization_settings]
@@ -162,12 +162,12 @@ def _get_optimizer(
_check_optimizer_params(opt_type, opt_target)
opt_config = _get_optimizer_configuration(
- opt_type, opt_target, layers_to_optimize, dataset, training_parameters
+ opt_type, opt_target, rewrite_parameters, layers_to_optimize, dataset
)
optimizer_configs.append(opt_config)
if len(optimizer_configs) == 1:
- return get_optimizer(model, optimizer_configs[0])
+ return get_optimizer(model, optimizer_configs[0], {})
return MultiStageOptimizer(model, optimizer_configs)
@@ -189,9 +189,9 @@ def _get_rewrite_params(
def _get_optimizer_configuration(
optimization_type: str,
optimization_target: int | float | str,
+ rewrite_parameters: dict,
layers_to_optimize: list[str] | None = None,
dataset: Path | None = None,
- training_parameters: dict | None = None,
) -> OptimizerConfiguration:
"""Get optimizer configuration for provided parameters."""
_check_optimizer_params(optimization_type, optimization_target)
@@ -212,12 +212,14 @@ def _get_optimizer_configuration(
if opt_type == "rewrite":
if isinstance(optimization_target, str):
- rewrite_params = _get_rewrite_params(training_parameters)
return RewriteConfiguration(
optimization_target=str(optimization_target),
layers_to_optimize=layers_to_optimize,
dataset=dataset,
- train_params=rewrite_params,
+ train_params=_get_rewrite_params(rewrite_parameters["train_params"]),
+ rewrite_specific_params=rewrite_parameters.get(
+ "rewrite_specific_params"
+ ),
)
raise ConfigurationError(
diff --git a/src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml b/src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml
new file mode 100644
index 0000000..3d8adfa
--- /dev/null
+++ b/src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml
@@ -0,0 +1,19 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+[rewrite.training_parameters]
+batch_size = 32
+learning_rate = 1e-3
+show_progress = true
+steps = 48000
+learning_rate_schedule = "cosine"
+num_procs = 1
+num_threads = 0
+augmentations.gaussian_strength = 0.0
+augmentations.mixup_strength = 0.0
+
+[rewrite.conv2d-clustering]
+num_clusters = 16
+cluster_centroids_init = "CentroidInitialization.LINEAR"
+activation = "relu"
+kernel_size = [3, 3]
diff --git a/src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml b/src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml
new file mode 100644
index 0000000..aa7f982
--- /dev/null
+++ b/src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml
@@ -0,0 +1,19 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+[rewrite.training_parameters]
+batch_size = 32
+learning_rate = 1e-3
+show_progress = true
+steps = 48000
+learning_rate_schedule = "cosine"
+num_procs = 1
+num_threads = 0
+augmentations.gaussian_strength = 0.0
+augmentations.mixup_strength = 0.0
+
+[rewrite.conv2d-sparsity]
+sparsity_m = 2
+sparsity_n = 4
+activation = "relu"
+kernel_size = [3, 3]
diff --git a/src/mlia/resources/optimization_profiles/optimization-conv2d-unstructured-pruning.toml b/src/mlia/resources/optimization_profiles/optimization-conv2d-unstructured-pruning.toml
new file mode 100644
index 0000000..67740ca
--- /dev/null
+++ b/src/mlia/resources/optimization_profiles/optimization-conv2d-unstructured-pruning.toml
@@ -0,0 +1,20 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+[rewrite.training_parameters]
+batch_size = 32
+learning_rate = 1e-3
+show_progress = true
+steps = 48000
+learning_rate_schedule = "cosine"
+num_procs = 1
+num_threads = 0
+augmentations.gaussian_strength = 0.0
+augmentations.mixup_strength = 0.0
+
+[rewrite.conv2d-unstructured-sparsity]
+initial_sparsity = 0.25
+final_sparsity = 0.5
+end_step = 48000
+activation = "relu"
+kernel_size = [3, 3]
diff --git a/src/mlia/resources/optimization_profiles/optimization-custom-augmentation.toml b/src/mlia/resources/optimization_profiles/optimization-custom-augmentation.toml
new file mode 100644
index 0000000..96d9742
--- /dev/null
+++ b/src/mlia/resources/optimization_profiles/optimization-custom-augmentation.toml
@@ -0,0 +1,13 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+[rewrite.training_parameters]
+batch_size = 32
+learning_rate = 1e-3
+show_progress = true
+steps = 48000
+learning_rate_schedule = "cosine"
+num_procs = 1
+num_threads = 0
+augmentations.gaussian_strength = 0.1
+augmentations.mixup_strength = 0.1
diff --git a/src/mlia/resources/optimization_profiles/optimization-fully-connected-clustering.toml b/src/mlia/resources/optimization_profiles/optimization-fully-connected-clustering.toml
new file mode 100644
index 0000000..c5d460b
--- /dev/null
+++ b/src/mlia/resources/optimization_profiles/optimization-fully-connected-clustering.toml
@@ -0,0 +1,17 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+[rewrite.training_parameters]
+batch_size = 32
+learning_rate = 1e-3
+show_progress = true
+steps = 48000
+learning_rate_schedule = "cosine"
+num_procs = 1
+num_threads = 0
+augmentations.gaussian_strength = 0.0
+augmentations.mixup_strength = 0.0
+
+[rewrite.fully-connected-clustering]
+num_clusters = 16
+cluster_centroids_init = "CentroidInitialization.LINEAR"
diff --git a/src/mlia/resources/optimization_profiles/optimization-fully-connected-pruning.toml b/src/mlia/resources/optimization_profiles/optimization-fully-connected-pruning.toml
new file mode 100644
index 0000000..f7f91ec
--- /dev/null
+++ b/src/mlia/resources/optimization_profiles/optimization-fully-connected-pruning.toml
@@ -0,0 +1,17 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+[rewrite.training_parameters]
+batch_size = 32
+learning_rate = 1e-3
+show_progress = true
+steps = 48000
+learning_rate_schedule = "cosine"
+num_procs = 1
+num_threads = 0
+augmentations.gaussian_strength = 0.0
+augmentations.mixup_strength = 0.0
+
+[rewrite.fully-connected-sparsity]
+sparsity_m = 2
+sparsity_n = 4
diff --git a/src/mlia/resources/optimization_profiles/optimization-fully-connected-unstructured-pruning.toml b/src/mlia/resources/optimization_profiles/optimization-fully-connected-unstructured-pruning.toml
new file mode 100644
index 0000000..cd5f745
--- /dev/null
+++ b/src/mlia/resources/optimization_profiles/optimization-fully-connected-unstructured-pruning.toml
@@ -0,0 +1,18 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+[rewrite.training_parameters]
+batch_size = 32
+learning_rate = 1e-3
+show_progress = true
+steps = 48000
+learning_rate_schedule = "cosine"
+num_procs = 1
+num_threads = 0
+augmentations.gaussian_strength = 0.0
+augmentations.mixup_strength = 0.0
+
+[rewrite.fully-connected-unstructured-sparsity]
+initial_sparsity = 0.25
+final_sparsity = 0.5
+end_step = 48000
diff --git a/src/mlia/resources/optimization_profiles/optimization.toml b/src/mlia/resources/optimization_profiles/optimization.toml
index 623a763..6f2800e 100644
--- a/src/mlia/resources/optimization_profiles/optimization.toml
+++ b/src/mlia/resources/optimization_profiles/optimization.toml
@@ -1,11 +1,12 @@
# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
-[training]
+[rewrite.training_parameters]
batch_size = 32
learning_rate = 1e-3
show_progress = true
steps = 48000
learning_rate_schedule = "cosine"
+augmentations = "gaussian"
num_procs = 1
num_threads = 0
diff --git a/src/mlia/target/common/optimization.py b/src/mlia/target/common/optimization.py
index 1423189..69d3a24 100644
--- a/src/mlia/target/common/optimization.py
+++ b/src/mlia/target/common/optimization.py
@@ -17,6 +17,7 @@ from mlia.core.errors import FunctionalityNotSupportedError
from mlia.core.performance import estimate_performance
from mlia.core.performance import P
from mlia.core.performance import PerformanceEstimator
+from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS
from mlia.nn.select import get_optimizer
from mlia.nn.select import OptimizationSettings
from mlia.nn.tensorflow.config import get_keras_model
@@ -50,7 +51,7 @@ class OptimizingDataCollector(ContextAwareDataCollector):
optimizations = self._get_optimization_settings(self.context)
- training_parameters = self._get_training_settings(self.context)
+ rewrite_parameters = self._get_rewrite_settings(self.context)
if not optimizations or optimizations == [[]]:
raise FunctionalityNotSupportedError(
@@ -77,7 +78,7 @@ class OptimizingDataCollector(ContextAwareDataCollector):
model = self.model # type: ignore
optimizers: list[Callable] = [
- partial(self.optimize_model, opts, training_parameters)
+ partial(self.optimize_model, opts, rewrite_parameters)
for opts in opt_settings
]
@@ -86,12 +87,12 @@ class OptimizingDataCollector(ContextAwareDataCollector):
def optimize_model(
self,
opt_settings: list[OptimizationSettings],
- training_parameters: dict | None,
+ rewrite_parameters: dict,
model: KerasModel | TFLiteModel,
) -> Any:
"""Run optimization."""
optimizer = get_optimizer(
- model, opt_settings, training_parameters=training_parameters
+ model, opt_settings, rewrite_parameters=rewrite_parameters
)
opts_as_str = ", ".join(str(opt) for opt in opt_settings)
logger.info("Applying model optimizations - [%s]", opts_as_str)
@@ -123,11 +124,11 @@ class OptimizingDataCollector(ContextAwareDataCollector):
context=context,
)
- def _get_training_settings(self, context: Context) -> dict:
+ def _get_rewrite_settings(self, context: Context) -> list[dict]:
"""Get optimization settings."""
return self.get_parameter( # type: ignore
OptimizingDataCollector.name(),
- "training_parameters",
+ "rewrite_parameters",
expected_type=dict,
expected=False,
context=context,
@@ -218,7 +219,53 @@ _DEFAULT_OPTIMIZATION_TARGETS = [
]
-def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) -> None:
+def parse_augmentations(
+ augmentations: dict | str | None,
+) -> tuple[float | None, float | None]:
+ """Parse augmentations from optimization-profile and return a valid tuple."""
+ if isinstance(augmentations, str):
+ match_augmentation = AUGMENTATION_PRESETS.get(augmentations)
+ if not match_augmentation:
+ match_augmentation = AUGMENTATION_PRESETS["none"]
+ return match_augmentation
+ if isinstance(augmentations, dict):
+ augmentation_keys_test_for_valid = list(augmentations.keys())
+ augmentation_keys_test_for_float = list(augmentations.keys())
+ valid_keys = ["mixup_strength", "gaussian_strength"]
+ tuple_to_return = []
+ for valid_key in valid_keys.copy():
+ if augmentations.get(valid_key) is not None:
+ del augmentation_keys_test_for_valid[
+ augmentation_keys_test_for_valid.index(valid_key)
+ ]
+ if isinstance(augmentations.get(valid_key), float):
+ tuple_to_return.append(augmentations[valid_key])
+ del augmentation_keys_test_for_float[
+ augmentation_keys_test_for_float.index(valid_key)
+ ]
+ else:
+ tuple_to_return.append(None)
+ else:
+ tuple_to_return.append(None)
+ if len(augmentation_keys_test_for_valid) > 0:
+ logger.warning(
+ "Warning! Expected augmentation parameters to be 'gaussian_strength' "
+ "and/or 'mixup_strength' got %s. "
+ "Removing invalid augmentations",
+ str(list(augmentations.keys())),
+ )
+ elif len(augmentation_keys_test_for_float) > 0:
+ logger.warning(
+ "Warning! Not all augmentation parameters were floats, "
+ "removing non-float augmentations"
+ )
+ return (tuple_to_return[0], tuple_to_return[1])
+ return AUGMENTATION_PRESETS["none"]
+
+
+def add_common_optimization_params( # pylint: disable=too-many-branches
+ advisor_parameters: dict, extra_args: dict
+) -> None:
"""Add common optimization parameters."""
optimization_targets = extra_args.get("optimization_targets")
if not optimization_targets:
@@ -227,18 +274,32 @@ def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) -> None:
if not is_list_of(optimization_targets, dict):
raise TypeError("Optimization targets value has wrong format.")
- rewrite_parameters = extra_args.get("optimization_profile")
training_parameters = None
- if rewrite_parameters:
- if not isinstance(rewrite_parameters, dict):
- raise TypeError("Training Parameter values has wrong format.")
- training_parameters = extra_args["optimization_profile"].get("training")
+ rewrite_specific_parameters = None
+
+ optimization_parameters = extra_args.get("optimization_profile")
+ if optimization_parameters: # pylint: disable=too-many-nested-blocks
+ if not isinstance(optimization_parameters, dict):
+ raise TypeError("Optimization Parameter values has wrong format.")
+
+ if optimization_parameters.get("rewrite"):
+ rewrite_params = optimization_parameters["rewrite"]
+ training_parameters = rewrite_params.get("training_parameters")
+ if training_parameters:
+ training_parameters["augmentations"] = parse_augmentations(
+ training_parameters.get("augmentations")
+ )
+ optimization_target = optimization_targets[0]["optimization_target"]
+ rewrite_specific_parameters = rewrite_params.get(optimization_target)
advisor_parameters.update(
{
"common_optimizations": {
"optimizations": [optimization_targets],
- "training_parameters": training_parameters,
+ "rewrite_parameters": {
+ "train_params": training_parameters,
+ "rewrite_specific_params": rewrite_specific_parameters,
+ },
},
}
)
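
A usage sketch for `parse_augmentations` above (not part of the diff; assumes
the preset tuples are ordered (mixup_strength, gaussian_strength), matching the
README table):

```python
from mlia.target.common.optimization import parse_augmentations

# Preset names resolve via AUGMENTATION_PRESETS; unknown names fall back to "none".
assert parse_augmentations("gaussian") == (None, 1.0)
# Explicit float parameters pass straight through.
assert parse_augmentations({"mixup_strength": 1.0, "gaussian_strength": 0.1}) == (1.0, 0.1)
# Anything else (including None) yields the "none" preset.
assert parse_augmentations(None) == (None, None)
```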
diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py
index 8492086..8a5b360 100644
--- a/src/mlia/target/config.py
+++ b/src/mlia/target/config.py
@@ -71,7 +71,16 @@ def is_builtin_target_profile(profile_name: str | Path) -> bool:
return profile_name in BUILTIN_SUPPORTED_PROFILE_NAMES
-BUILTIN_SUPPORTED_OPTIMIZATION_NAMES = ["optimization"]
+BUILTIN_SUPPORTED_OPTIMIZATION_NAMES = [
+ "optimization",
+ "optimization-custom-augmentation",
+ "optimization-fully-connected-clustering",
+ "optimization-fully-connected-pruning",
+ "optimization-fully-connected-unstructured-pruning",
+ "optimization-conv2d-clustering",
+ "optimization-conv2d-pruning",
+ "optimization-conv2d-unstructured-pruning",
+]
def is_builtin_optimization_profile(optimization_name: str | Path) -> bool:
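
Assuming is_builtin_optimization_profile() mirrors is_builtin_target_profile()
above (a simple membership test against this list), the expanded set of
built-in profiles can be exercised as in the tests below:

from mlia.target.config import is_builtin_optimization_profile
from mlia.target.config import load_profile

assert is_builtin_optimization_profile("optimization-conv2d-pruning")
profile = load_profile(
    "src/mlia/resources/optimization_profiles/"
    "optimization-fully-connected-clustering.toml"
)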
diff --git a/tests/conftest.py b/tests/conftest.py
index 3d0b832..a64f320 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -126,9 +126,28 @@ def get_test_keras_model() -> keras.Model:
return model
+def get_test_keras_model_no_activation() -> keras.Model:
+    """Return a test Keras model without activation functions."""
+ model = keras.Sequential(
+ [
+ keras.Input(shape=(28, 28, 1), batch_size=1, name="input"),
+ keras.layers.Reshape((28, 28, 1)),
+ keras.layers.Conv2D(filters=12, kernel_size=(3, 3), name="conv1"),
+ keras.layers.Conv2D(filters=12, kernel_size=(3, 3), name="conv2"),
+ keras.layers.MaxPool2D(2, 2),
+ keras.layers.Flatten(),
+ keras.layers.Dense(10, name="output"),
+ ]
+ )
+
+ model.compile(optimizer="sgd", loss="mean_squared_error")
+ return model
+
+
TEST_MODEL_KERAS_FILE = "test_model.h5"
TEST_MODEL_TFLITE_FP32_FILE = "test_model_fp32.tflite"
TEST_MODEL_TFLITE_INT8_FILE = "test_model_int8.tflite"
+TEST_MODEL_TFLITE_NO_ACT_FILE = "test_model_no_act.tflite"
TEST_MODEL_TFLITE_VELA_FILE = "test_model_vela.tflite"
TEST_MODEL_TF_SAVED_MODEL_FILE = "tf_model_test_model"
TEST_MODEL_INVALID_FILE = "invalid.tflite"
@@ -153,6 +172,13 @@ def fixture_test_models_path(
keras_model, quantized=False, output_path=tmp_path / TEST_MODEL_TFLITE_FP32_FILE
)
+    # Un-quantized TensorFlow Lite model without activation functions (fp32)
+ convert_to_tflite(
+ get_test_keras_model_no_activation(),
+ quantized=False,
+ output_path=tmp_path / TEST_MODEL_TFLITE_NO_ACT_FILE,
+ )
+
# Quantized TensorFlow Lite model (int8)
tflite_model_path = tmp_path / TEST_MODEL_TFLITE_INT8_FILE
convert_to_tflite(keras_model, quantized=True, output_path=tflite_model_path)
@@ -195,6 +221,12 @@ def fixture_test_tflite_vela_model(test_models_path: Path) -> Path:
return test_models_path / TEST_MODEL_TFLITE_VELA_FILE
+@pytest.fixture(scope="session", name="test_tflite_no_act_model")
+def fixture_test_tflite_no_act_model(test_models_path: Path) -> Path:
+    """Return a test TensorFlow Lite model without activation functions."""
+ return test_models_path / TEST_MODEL_TFLITE_NO_ACT_FILE
+
+
@pytest.fixture(scope="session", name="test_tf_model")
def fixture_test_tf_model(test_models_path: Path) -> Path:
"""Return test TensorFlow Lite model."""
@@ -257,17 +289,17 @@ def fixture_test_tfrecord_fp32(
yield from create_tfrecord(tmp_path_factory, random_data)
-@pytest.fixture(scope="session", autouse=True)
+@pytest.fixture(scope="function", autouse=True)
def set_training_steps(
request: _pytest.fixtures.SubRequest,
) -> Generator[None, None, None]:
"""Speed up tests by using MockTrainingParameters."""
- if "set_training_steps" == request.fixturename:
- yield
- else:
+ if "skip_set_training_steps" not in request.keywords:
with pytest.MonkeyPatch.context() as monkeypatch:
monkeypatch.setattr(
"mlia.nn.select._get_rewrite_params",
- MagicMock(return_value=[MockTrainingParameters(), None, None]),
+ MagicMock(return_value=MockTrainingParameters()),
)
yield
+ else:
+ yield
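
With the fixture now function-scoped and keyed off a pytest marker rather than
the fixture name, an individual test opts out of the mocked training
parameters like this (the marker is the one applied in tests/test_nn_select.py
below and is assumed to be registered in the pytest configuration):

@pytest.mark.skip_set_training_steps
def test_with_real_training_params() -> None:
    """Run without the MockTrainingParameters monkeypatch."""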
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py
index 93a05bd..5a91cd7 100644
--- a/tests/test_cli_commands.py
+++ b/tests/test_cli_commands.py
@@ -90,7 +90,7 @@ def test_performance_unknown_target(
None,
None,
True,
- "fully-connected-sparsity24",
+ "fully-connected-sparsity",
"sequential/flatten/Reshape",
"StatefulPartitionedCall:0",
does_not_raise(),
@@ -139,8 +139,10 @@ def test_performance_unknown_target(
Exception,
match=re.escape(
"Invalid rewrite target: 'random'. "
- "Supported rewrites: ['fully-connected',"
- " 'fully-connected-clustering', 'fully-connected-sparsity24']"
+ "Supported rewrites: ['conv2d-clustering', 'conv2d-sparsity', "
+ "'conv2d-unstructured-sparsity', 'fully-connected', "
+ "'fully-connected-clustering', 'fully-connected-sparsity', "
+ "'fully-connected-unstructured-sparsity']"
),
),
],
@@ -195,6 +197,58 @@ def test_performance_unknown_target(
"StatefulPartitionedCall:0",
does_not_raise(),
],
+ [
+ "ethos-u55-256",
+ False,
+ False,
+ None,
+ None,
+ None,
+ True,
+ "fully-connected-unstructured-sparsity",
+ "sequential/flatten/Reshape",
+ "StatefulPartitionedCall:0",
+ does_not_raise(),
+ ],
+ [
+ "ethos-u55-256",
+ False,
+ False,
+ None,
+ None,
+ None,
+ True,
+ "conv2d-sparsity",
+ "sequential/conv1/Relu;sequential/conv1/Conv2D",
+ "sequential/conv2/Relu;sequential/conv2/Conv2D",
+ does_not_raise(),
+ ],
+ [
+ "ethos-u55-256",
+ False,
+ False,
+ None,
+ None,
+ None,
+ True,
+ "conv2d-unstructured-sparsity",
+ "sequential/conv1/Relu;sequential/conv1/Conv2D",
+ "sequential/conv2/Relu;sequential/conv2/Conv2D",
+ does_not_raise(),
+ ],
+ [
+ "ethos-u55-256",
+ False,
+ False,
+ None,
+ None,
+ None,
+ True,
+ "conv2d-clustering",
+ "sequential/conv1/Relu;sequential/conv1/Conv2D",
+ "sequential/conv2/Relu;sequential/conv2/Conv2D",
+ does_not_raise(),
+ ],
],
)
def test_opt_valid_optimization_target( # pylint: disable=too-many-locals,too-many-arguments
diff --git a/tests/test_cli_helpers.py b/tests/test_cli_helpers.py
index 0e9f0d6..69e6ffe 100644
--- a/tests/test_cli_helpers.py
+++ b/tests/test_cli_helpers.py
@@ -156,7 +156,7 @@ def test_copy_profile_file_to_output_dir(tmp_path: Path) -> None:
def test_copy_optimization_file_to_output_dir(tmp_path: Path) -> None:
- """Test if the optimization profile file is copied into the output directory."""
+    """Test that the optimization profile file is copied into the output dir."""
test_target_profile_name = "optimization"
test_file_path = Path(f"{tmp_path}/{test_target_profile_name}.toml")
diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py
index 341e0d2..bdcf034 100644
--- a/tests/test_common_optimization.py
+++ b/tests/test_common_optimization.py
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the common optimization module."""
+from __future__ import annotations
+
from contextlib import ExitStack as does_not_raises
from pathlib import Path
from typing import Any
@@ -15,6 +17,7 @@ from mlia.nn.tensorflow.config import TFLiteModel
from mlia.target.common.optimization import _DEFAULT_OPTIMIZATION_TARGETS
from mlia.target.common.optimization import add_common_optimization_params
from mlia.target.common.optimization import OptimizingDataCollector
+from mlia.target.common.optimization import parse_augmentations
from mlia.target.config import load_profile
from mlia.target.config import TargetProfile
@@ -57,7 +60,10 @@ def test_optimizing_data_collector(
config_parameters={
"common_optimizations": {
"optimizations": optimizations,
- "training_parameters": training_parameters,
+ "rewrite_parameters": {
+ "train_params": training_parameters,
+ "rewrite_specific_params": None,
+ },
}
}
)
@@ -94,12 +100,15 @@ def test_optimizing_data_collector(
collector.set_context(context)
collector.collect_data()
assert optimize_model_mock.call_args.args[0] == opt_settings[0]
- assert optimize_model_mock.call_args.args[1] == training_parameters
+ assert optimize_model_mock.call_args.args[1] == {
+ "train_params": training_parameters,
+ "rewrite_specific_params": None,
+ }
assert fake_optimizer.invocation_count == 1
@pytest.mark.parametrize(
- "extra_args, error_to_raise",
+ "extra_args, error_to_raise, rewrite_parameter_type",
[
(
{
@@ -112,14 +121,39 @@ def test_optimizing_data_collector(
],
},
does_not_raises(),
+ type(None),
),
(
{
+ "optimization_targets": [
+ {
+ "optimization_type": "rewrite",
+ "optimization_target": "fully-connected-clustering",
+ }
+ ],
"optimization_profile": load_profile(
- "src/mlia/resources/optimization_profiles/optimization.toml"
- )
+ "src/mlia/resources/optimization_profiles/"
+ "optimization-fully-connected-clustering.toml"
+ ),
},
does_not_raises(),
+ dict,
+ ),
+ (
+ {
+ "optimization_targets": [
+ {
+ "optimization_type": "rewrite",
+ "optimization_target": "fully-connected-sparsity",
+ }
+ ],
+ "optimization_profile": load_profile(
+ "src/mlia/resources/optimization_profiles/"
+ "optimization-fully-connected-pruning.toml"
+ ),
+ },
+ does_not_raises(),
+ dict,
),
(
{
@@ -132,16 +166,22 @@ def test_optimizing_data_collector(
pytest.raises(
TypeError, match="Optimization targets value has wrong format."
),
+ type(None),
),
(
{"optimization_profile": [32, 1e-3, True, 48000, "cosine", 1, 0]},
pytest.raises(
- TypeError, match="Training Parameter values has wrong format."
+                TypeError, match="Optimization profile values have wrong format."
),
+ type(None),
),
],
)
-def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) -> None:
+def test_add_common_optimization_params(
+ extra_args: dict,
+ error_to_raise: Any,
+    rewrite_parameter_type: type,
+) -> None:
"""Test to check that optimization_targets and optimization_profiles are
correctly parsed."""
advisor_parameters: dict = {}
@@ -158,12 +198,93 @@ def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) -
]
if not extra_args.get("optimization_profile"):
- assert (
- advisor_parameters["common_optimizations"]["training_parameters"]
- is None
- )
+ assert advisor_parameters["common_optimizations"]["rewrite_parameters"] == {
+ "train_params": None,
+ "rewrite_specific_params": None,
+ }
else:
- assert (
- advisor_parameters["common_optimizations"]["training_parameters"]
- == extra_args["optimization_profile"]["training"]
+        rewrite_profile = extra_args["optimization_profile"].get("rewrite")
+        train_params = advisor_parameters["common_optimizations"][
+            "rewrite_parameters"
+        ]["train_params"]
+        if not rewrite_profile or not rewrite_profile.get("training_parameters"):
+            assert train_params is None
+        else:
+            assert isinstance(train_params, dict)
+
+ assert isinstance(
+ advisor_parameters["common_optimizations"]["rewrite_parameters"][
+ "rewrite_specific_params"
+ ],
+            rewrite_parameter_type,
)
+
+
+@pytest.mark.parametrize(
+ "augmentations, expected_output",
+ [
+ (
+ {"gaussian_strength": 1.0, "mixup_strength": 1.0},
+ (1.0, 1.0),
+ ),
+ (
+ {"gaussian_strength": 1.0},
+ (None, 1.0),
+ ),
+ (
+ {"Wrong param": 1.0, "mixup_strength": 1.0},
+ (1.0, None),
+ ),
+ (
+ {"Wrong param1": 1.0, "Wrong param2": 1.0},
+ (None, None),
+ ),
+ (
+ "gaussian",
+ (None, 1.0),
+ ),
+ (
+ "mix_gaussian_large",
+ (2.0, 1.0),
+ ),
+ (
+ "not in presets",
+ (None, None),
+ ),
+ (
+ {"gaussian_strength": 1.0, "mixup_strength": 1.0, "mix2": 1.0},
+ (1.0, 1.0),
+ ),
+ (
+ {"gaussian_strength": "not a float", "mixup_strength": 1.0},
+ (1.0, None),
+ ),
+ (
+ None,
+ (None, None),
+ ),
+ ],
+)
+def test_parse_augmentations(
+ augmentations: dict | str | None, expected_output: tuple
+) -> None:
+ """Check that augmentation parameters in optimization_profiles are
+ correctly parsed."""
+
+ augmentation_output = parse_augmentations(augmentations)
+ assert augmentation_output == expected_output
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index e502842..9e3287e 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -3,18 +3,23 @@
"""Tests for module mlia.nn.rewrite.core.rewrite."""
from __future__ import annotations
+import re
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
from typing import cast
from unittest.mock import MagicMock
+import numpy as np
import pytest
import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from tensorflow_model_optimization.python.core.clustering.keras.cluster_wrapper import ( # pylint: disable=no-name-in-module
ClusterWeights,
)
+from tensorflow_model_optimization.python.core.sparsity.keras.pruning_wrapper import ( # pylint: disable=no-name-in-module
+ PruneLowMagnitude,
+)
from mlia.nn.rewrite.core.rewrite import ClusteringRewrite
from mlia.nn.rewrite.core.rewrite import GenericRewrite
@@ -23,40 +28,19 @@ from mlia.nn.rewrite.core.rewrite import RewriteCallable
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import RewriteRegistry
from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
-from mlia.nn.rewrite.core.rewrite import Sparsity24Rewrite
+from mlia.nn.rewrite.core.rewrite import StructuredSparsityRewrite
from mlia.nn.rewrite.core.rewrite import TrainingParameters
+from mlia.nn.rewrite.core.rewrite import UnstructuredSparsityRewrite
from mlia.nn.rewrite.core.train import train_in_dir
-from mlia.nn.rewrite.library.fc_clustering_layer import (
- get_keras_model_clus as fc_clustering_rewrite,
-)
+from mlia.nn.rewrite.library.clustering import fc_clustering_rewrite
+from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_rewrite
+from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_unstructured_rewrite
+from mlia.nn.rewrite.library.sparsity import fc_sparsity_rewrite
+from mlia.nn.rewrite.library.sparsity import fc_sparsity_unstructured_rewrite
from mlia.nn.tensorflow.config import TFLiteModel
from tests.utils.rewrite import MockTrainingParameters
-class TestRewrite(Rewrite):
- """Test rewrite class."""
-
- def quantize(self, model: keras.Model) -> keras.Model:
- """Return a quantized model if required."""
- return tfmot.quantization.keras.quantize_model(model)
-
- def preserved_quantize(self, model: keras.Model) -> keras.Model:
- """Not needed."""
- return model
-
- def training_callbacks(self) -> list:
- """Return default rewrite callbacks."""
- return []
-
- def post_process(self, model: keras.Model) -> keras.Model:
- """Return default post-processing rewrite options."""
- return model
-
- def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool:
- """Not needed here."""
- return True
-
-
def mock_rewrite_function(*_: Any) -> Any:
"""Mock function to test autoloading of rewrite functions."""
@@ -67,10 +51,10 @@ def test_rewrite() -> None:
def bad_rewrite_func() -> Any:
raise NotImplementedError()
- rewrite = TestRewrite(
+ rewrite = GenericRewrite(
"BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func)
)
- with pytest.raises(RuntimeError):
+ with pytest.raises(KeyError):
rewrite((1, 2), (1, 2))
@@ -79,7 +63,9 @@ def test_rewrite() -> None:
[
("fully-connected", 0, GenericRewrite),
("fully-connected-clustering", 0, ClusteringRewrite),
- ("fully-connected-sparsity24", 1, Sparsity24Rewrite),
+ ("fully-connected-sparsity", 1, StructuredSparsityRewrite),
+ ("conv2d-clustering", 0, ClusteringRewrite),
+ ("conv2d-sparsity", 1, StructuredSparsityRewrite),
],
)
def test_rewrite_selection(
@@ -96,8 +82,10 @@ def test_rewrite_selection(
"rewrite_name, expected_error",
[
("fully-connected", does_not_raise()),
- ("fully-connected-sparsity24", does_not_raise()),
+ ("fully-connected-sparsity", does_not_raise()),
("fully-connected-clustering", does_not_raise()),
+ ("conv2d-clustering", does_not_raise()),
+ ("conv2d-sparsity", does_not_raise()),
("random", does_not_raise()),
],
)
@@ -105,7 +93,8 @@ def test_rewrite_configuration(
test_tflite_model_fp32: Path, rewrite_name: str, expected_error: Any
) -> None:
"""Test get_rewrite function only supports rewrite type fully-connected,
- fully-connected-clustering and fully-connected-sparsity24."""
+ fully-connected-clustering, fully-connected-sparsity, conv2d-clustering
+ and conv2d-sparsity."""
with expected_error:
config_obj = RewriteConfiguration(
rewrite_name,
@@ -120,29 +109,195 @@ def test_rewrite_configuration(
assert isinstance(rewriter_obj, RewritingOptimizer)
+def train_rewrite_model(
+ input_shape: tuple | np.ndarray,
+ output_shape: int | np.ndarray,
+ rewrite_model: keras.Model,
+ epochs: int = 1,
+) -> keras.Model:
+ """Helper function to quickly train a rewrite model."""
+ rewrite_model.compile(
+ optimizer=keras.optimizers.Nadam(learning_rate=0.01),
+ loss=keras.losses.MeanSquaredError(),
+ metrics=["mae"],
+ )
+ if isinstance(output_shape, int):
+ output_shape_list = [output_shape]
+ else:
+ output_shape_list = output_shape.tolist()
+ rewrite_model.fit(
+ x=np.random.rand(16, *input_shape),
+ y=np.random.rand(16, *output_shape_list),
+ batch_size=1,
+ epochs=epochs,
+ callbacks=[tfmot.sparsity.keras.UpdatePruningStep()],
+ )
+ return rewrite_model
+
+
def test_rewrite_fully_connected_clustering() -> None:
- """Check that model has the set number of clusters"""
+ """Check that fully connected clustering rewrite model
+ has the set number of clusters."""
+
+ rewrite = ClusteringRewrite(
+ "fully-connected-clustering",
+ fc_clustering_rewrite,
+ )
+
+ model = rewrite(input_shape=(28, 28), output_shape=10, num_clusters=2)
+ model = rewrite.post_process(model)
+ assert rewrite.check_optimization(
+ model,
+ num_clusters=2,
+ )
+
+
+def test_rewrite_fully_connected_sparsity(caplog: pytest.LogCaptureFixture) -> None:
+    """Check that the sparse fully connected rewrite model is correctly sparse."""
+
+ rewrite = StructuredSparsityRewrite("fully-connected-sparsity", fc_sparsity_rewrite)
+ input_shape = (28, 28)
+ output_shape = 10
+ model = rewrite(
+ input_shape=tuple(input_shape),
+ output_shape=output_shape,
+ sparsity_m=2,
+ sparsity_n=4,
+ )
+ model = rewrite.post_process(model)
+ assert not rewrite.check_optimization(model)
+ log_records = caplog.records
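+    # logging levelno 30 corresponds to logging.WARNING.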
+ warning_messages = [x.message for x in log_records if x.levelno == 30]
+ assert (
+ re.search(
+ r"\nWARNING: Could not find \(2, 4\) sparsity, in "
+ r"layer dense_?\d? for weight dense_?\d?\/kernel:0 \n",
+ warning_messages[0],
+ )
+ is not None
+ )
+ model = rewrite(
+ input_shape=input_shape, output_shape=output_shape, sparsity_m=2, sparsity_n=4
+ )
+ train_rewrite_model(
+ input_shape=input_shape, output_shape=output_shape, rewrite_model=model
+ )
+ model = rewrite.post_process(model)
+ assert rewrite.check_optimization(model)
+
+
+def test_rewrite_conv2d_sparsity(caplog: pytest.LogCaptureFixture) -> None:
+ """Check that sparse conv2d rewrite model is correctly sparse."""
+
+ rewrite = StructuredSparsityRewrite("conv2d-sparsity", conv2d_sparsity_rewrite)
+ input_shape = np.array([28, 28, 3])
+ output_shape = np.array([14, 14, 3])
+ model = rewrite(
+ input_shape=input_shape, output_shape=output_shape, sparsity_m=2, sparsity_n=4
+ )
+ model = rewrite.post_process(model)
+ assert not rewrite.check_optimization(model)
+ log_records = caplog.records
+ warning_messages = [x.message for x in log_records if x.levelno == 30]
+ assert (
+ re.search(
+ r"\nWARNING: Could not find \(2, 4\) sparsity, in "
+ r"layer conv2d_?\d? for weight conv2d_?\d?\/kernel:0 \n",
+ warning_messages[0],
+ )
+ is not None
+ )
+ model = rewrite(
+ input_shape=input_shape, output_shape=output_shape, sparsity_m=2, sparsity_n=4
+ )
+ train_rewrite_model(
+ input_shape=input_shape, output_shape=output_shape, rewrite_model=model
+ )
+ model = rewrite.post_process(model)
+ assert rewrite.check_optimization(model)
+
- rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite)
- model = rewrite(input_shape=(28, 28), output_shape=10)
+def test_rewrite_conv2d_unstructured_sparsity(caplog: pytest.LogCaptureFixture) -> None:
+ """Check that an unstructured sparse conv2d rewrite is correctly sparse."""
+
+ rewrite = UnstructuredSparsityRewrite(
+ "conv2d-unstructured-sparsity", conv2d_sparsity_unstructured_rewrite
+ )
+ input_shape = np.array([28, 28, 3])
+ output_shape = np.array([14, 14, 3])
+ model = rewrite(
+ input_shape=input_shape, output_shape=output_shape, final_sparsity=0.50
+ )
+ model = rewrite.post_process(model)
+ assert not rewrite.check_optimization(model)
+ log_records = caplog.records
+ warning_messages = [x.message for x in log_records if x.levelno == 30]
+ assert (
+ re.search(
+            r"\nWARNING: Found total sparsity of rewrite model: \d\.\d\d "
+ r"expected total sparsity to be: 0.50\n",
+ warning_messages[0],
+ )
+ is not None
+ )
+ model = rewrite(
+ input_shape=input_shape,
+ output_shape=output_shape,
+ final_sparsity=0.5,
+ end_step=120,
+ )
+ train_rewrite_model(
+ input_shape=input_shape,
+ output_shape=output_shape,
+ rewrite_model=model,
+ epochs=10,
+ )
model = rewrite.post_process(model)
- assert rewrite.check_optimization(model, number_of_clusters=32)
+ assert rewrite.check_optimization(model)
-def test_rewrite_fully_connected_clustering_error_handling() -> None:
- """Check that model has the set number of clusters
- and that when quantized the number of clusters
- remain."""
+def test_rewrite_fully_connected_unstructured_sparsity(
+ caplog: pytest.LogCaptureFixture,
+) -> None:
+ """Check that an unstructured sparse FC rewrite is correctly sparse."""
- rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite)
- model = rewrite(input_shape=(28, 28), output_shape=10)
- with pytest.raises(
- ValueError,
- match=(
- r"Expected check_preserved_quantize to have argument number_of_clusters"
- ),
- ):
- rewrite.check_optimization(model, bad_arg_name=25)
+ rewrite = UnstructuredSparsityRewrite(
+ "fully-connected-unstructured-sparsity", fc_sparsity_unstructured_rewrite
+ )
+ input_shape = (28, 28)
+ output_shape = 10
+ model = rewrite(
+ input_shape=tuple(input_shape), output_shape=output_shape, final_sparsity=0.5
+ )
+ model = rewrite.post_process(model)
+ assert not rewrite.check_optimization(model)
+ log_records = caplog.records
+ warning_messages = [x.message for x in log_records if x.levelno == 30]
+ assert (
+ re.search(
+            r"\nWARNING: Found total sparsity of rewrite model: \d\.\d\d "
+ r"expected total sparsity to be: 0.50\n",
+ warning_messages[0],
+ )
+ is not None
+ )
+ model = rewrite(
+ input_shape=input_shape,
+ output_shape=output_shape,
+ final_sparsity=0.5,
+ end_step=120,
+ )
+ train_rewrite_model(
+ input_shape=input_shape,
+ output_shape=output_shape,
+ rewrite_model=model,
+ epochs=10,
+ )
+ model = rewrite.post_process(model)
+ assert rewrite.check_optimization(model)
@pytest.mark.parametrize(
@@ -151,6 +306,40 @@ def test_rewrite_fully_connected_clustering_error_handling() -> None:
["fully-connected", [keras.layers.Reshape, keras.layers.Dense], False],
["fully-connected-clustering", [ClusterWeights, ClusterWeights], False],
["fully-connected-clustering", [ClusterWeights, ClusterWeights], True],
+ ["fully-connected-sparsity", [PruneLowMagnitude, PruneLowMagnitude], False],
+ ["fully-connected-sparsity", [PruneLowMagnitude, PruneLowMagnitude], True],
+ ["conv2d-clustering", [ClusterWeights, ClusterWeights, ClusterWeights], False],
+ ["conv2d-clustering", [ClusterWeights, ClusterWeights, ClusterWeights], True],
+ [
+ "conv2d-sparsity",
+ [PruneLowMagnitude, PruneLowMagnitude, PruneLowMagnitude],
+ False,
+ ],
+ [
+ "conv2d-sparsity",
+ [PruneLowMagnitude, PruneLowMagnitude, PruneLowMagnitude],
+ True,
+ ],
+ [
+ "fully-connected-unstructured-sparsity",
+ [PruneLowMagnitude, PruneLowMagnitude],
+ False,
+ ],
+ [
+ "fully-connected-unstructured-sparsity",
+ [PruneLowMagnitude, PruneLowMagnitude],
+ True,
+ ],
+ [
+ "conv2d-unstructured-sparsity",
+ [PruneLowMagnitude, PruneLowMagnitude, PruneLowMagnitude],
+ False,
+ ],
+ [
+ "conv2d-unstructured-sparsity",
+ [PruneLowMagnitude, PruneLowMagnitude, PruneLowMagnitude],
+ True,
+ ],
],
)
def test_rewriting_optimizer( # pylint: disable=too-many-locals
@@ -162,24 +351,32 @@ def test_rewriting_optimizer( # pylint: disable=too-many-locals
expected_layers: list[object],
quant: bool,
) -> None:
- """Test fc_layer rewrite process with rewrite type fully-connected."""
+ """Test the rewrite process with all rewrite types."""
tfrecord = test_tfrecord if quant else test_tfrecord_fp32
tflite_model = test_tflite_model if quant else test_tflite_model_fp32
+ rewrite_function = RewritingOptimizer.registry.items[rewrite_type]
config_obj = RewriteConfiguration(
rewrite_type,
- ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"],
+ ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"]
+ if "fully-connected" in rewrite_type
+ else [
+ "sequential/conv1/Relu;sequential/conv1/Conv2D",
+ "sequential/conv2/Relu;sequential/conv2/Conv2D",
+ ],
tfrecord,
train_params=MockTrainingParameters(),
)
-
test_obj = RewritingOptimizer(tflite_model, config_obj)
- rewrite_function = RewritingOptimizer.registry.items[
- test_obj.optimizer_configuration.optimization_target
- ]
# Input and output shapes do not matter; we just check the layers are as expected
- rewrite_model = rewrite_function(input_shape=(28, 28, 1), output_shape=12)
+ rewrite_model = (
+ rewrite_function(input_shape=(28, 28, 1), output_shape=12)
+ if "fully-connected" in rewrite_type
+ else rewrite_function(
+ input_shape=np.array([28, 28, 3]), output_shape=np.array([14, 14, 3])
+ )
+ )
for idx, layer in enumerate(rewrite_model.layers):
assert isinstance(layer, expected_layers[idx]) # type: ignore
@@ -197,8 +394,14 @@ def test_register_rewrite_function() -> None:
"""Test adding rewrite functions and verify they are reported via the registry."""
registry = RewriteRegistry()
- rewrite1 = TestRewrite("r1", cast(RewriteCallable, lambda: 1))
- rewrite2 = TestRewrite("r2", cast(RewriteCallable, lambda: 2))
+ rewrite1 = GenericRewrite(
+ "r1",
+ cast(RewriteCallable, lambda: 1),
+ )
+ rewrite2 = GenericRewrite(
+ "r2",
+ cast(RewriteCallable, lambda: 2),
+ )
registry.register_rewrite(rewrite1)
registry.register_rewrite(rewrite2)
@@ -207,11 +410,15 @@ def test_register_rewrite_function() -> None:
def test_builtin_rewrite_names() -> None:
"""Test if all builtin rewrites are properly registered and returned."""
- assert RewritingOptimizer.builtin_rewrite_names() == [
+ assert set(RewritingOptimizer.builtin_rewrite_names()) == {
+ "conv2d-clustering",
+ "conv2d-sparsity",
+ "conv2d-unstructured-sparsity",
"fully-connected",
"fully-connected-clustering",
- "fully-connected-sparsity24",
- ]
+ "fully-connected-sparsity",
+ "fully-connected-unstructured-sparsity",
+ }
def test_rewrite_configuration_train_params(
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index 94c99ff..03b230f 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -16,11 +16,12 @@ from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from mlia.nn.rewrite.core.train import augment_fn_twins
from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS
+from mlia.nn.rewrite.core.train import detect_activation_from_rewrite_function
from mlia.nn.rewrite.core.train import LearningRateSchedule
from mlia.nn.rewrite.core.train import mixup
from mlia.nn.rewrite.core.train import train
from mlia.nn.rewrite.core.train import TrainingParameters
-from tests.test_nn_rewrite_core_rewrite import TestRewrite
+from tests.test_nn_rewrite_core_rewrite import GenericRewrite
from tests.utils.rewrite import MockTrainingParameters
@@ -54,7 +55,7 @@ def check_train(
"""Test the train() function."""
with TemporaryDirectory() as tmp_dir:
output_file = Path(tmp_dir, "out.tflite")
- mock_rewrite = TestRewrite("replace", replace_fully_connected_with_conv)
+ mock_rewrite = GenericRewrite("replace", replace_fully_connected_with_conv)
result = train(
source_model=str(tflite_model),
unmodified_model=str(tflite_model) if use_unmodified_model else None,
@@ -65,6 +66,7 @@ def check_train(
input_tensors=["sequential/flatten/Reshape"],
output_tensors=["StatefulPartitionedCall:0"],
train_params=train_params,
+ rewrite_specific_params={},
)
assert len(result[0][0]) == 2
@@ -249,3 +251,41 @@ def test_train_checkpoint(
use_unmodified_model=False,
quantized=True,
)
+
+
+def test_detect_activation_from_rewrite_function_no_activation(
+ caplog: pytest.LogCaptureFixture, test_tflite_no_act_model: Path
+) -> None:
+ """
+    Test detect_activation_from_rewrite_function()
+    on a model with no activation functions.
+ """
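+    # Capture at level 20 (logging.INFO) so the detector's messages are recorded.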
+ caplog.set_level(level=20)
+ activation = detect_activation_from_rewrite_function(
+ test_tflite_no_act_model.as_posix()
+ )
+ log_records = caplog.get_records(when="call")
+ logging_messages = [x.message for x in log_records if x.levelno == 20]
+ assert activation == "relu"
+ assert (
+ "No activation function specified, setting activation function to ReLU"
+ in logging_messages
+ )
+
+
+def test_detect_activation_from_rewrite_function_relu_activation(
+ caplog: pytest.LogCaptureFixture, test_tflite_model: Path
+) -> None:
+ """
+    Test detect_activation_from_rewrite_function()
+    on a model with ReLU activation functions.
+ """
+ caplog.set_level(level=20)
+ activation = detect_activation_from_rewrite_function(test_tflite_model.as_posix())
+ log_records = caplog.get_records(when="call")
+ logging_messages = [x.message for x in log_records if x.levelno == 20]
+ assert activation == "relu"
+ assert (
+ "No activation function specified, setting activation function "
+ "to most common activation detected in rewrite graph: relu" in logging_messages
+ )
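
Between them, the two tests pin down the detector's contract. As a minimal
usage sketch (the path is hypothetical; path handling follows the tests above):

activation = detect_activation_from_rewrite_function("path/to/model.tflite")
# Returns "relu" both for graphs using ReLU and, via the documented fallback,
# for graphs that contain no activation functions at all.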
diff --git a/tests/test_nn_rewrite_library_helper_functions.py b/tests/test_nn_rewrite_library_helper_functions.py
new file mode 100644
index 0000000..a0dd7b9
--- /dev/null
+++ b/tests/test_nn_rewrite_library_helper_functions.py
@@ -0,0 +1,103 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module mlia.nn.rewrite.library.helper_functions."""
+from __future__ import annotations
+
+from contextlib import ExitStack as does_not_raise
+from typing import Any
+
+import numpy as np
+import pytest
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
+
+from mlia.nn.rewrite.library.helper_functions import ACTIVATION_FUNCTION_LIST
+from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters
+from mlia.nn.rewrite.library.helper_functions import get_activation_function
+
+
+def compute_conv_output(
+ input_data: np.ndarray, input_shape: np.ndarray, conv_parameters: dict[str, Any]
+) -> np.ndarray:
+ """Compute the output of a conv layer for testing."""
+ test_model = keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Conv2D(**conv_parameters),
+ ]
+ )
+ output = test_model(input_data)
+ return np.array(output.shape[1:])
+
+
+@pytest.mark.parametrize(
+ "input_shape, output_shape, kernel_size",
+ [
+ (np.array([32, 32, 3]), np.array([16, 16, 3]), [3, 3]),
+ (np.array([32, 32, 3]), np.array([8, 8, 3]), [3, 3]),
+ (np.array([32, 32, 3]), np.array([8, 16, 3]), [3, 3]),
+ (np.array([25, 10, 3]), np.array([13, 5, 3]), [3, 3]),
+ (np.array([25, 10, 3]), np.array([7, 5, 3]), [3, 3]),
+ (np.array([25, 10, 3]), np.array([6, 4, 3]), [3, 3]),
+ (np.array([25, 10, 3]), np.array([5, 5, 3]), [3, 3]),
+ (np.array([32, 32, 3]), np.array([16, 16, 3]), [1, 3]),
+ (np.array([32, 32, 3]), np.array([16, 16, 3]), [1, 1]),
+ (np.array([32, 32, 3]), np.array([16, 16, 3]), [5, 5]),
+ ],
+)
+def test_compute_conv2d_parameters(
+ input_shape: np.ndarray, output_shape: np.ndarray, kernel_size: list[int]
+) -> None:
+ """Test to check compute_conv2d_parameters works as expected."""
+ conv_parameters = compute_conv2d_parameters(
+ input_shape=input_shape,
+ output_shape=output_shape,
+ kernel_size_input=kernel_size,
+ )
+ computed_output_shape = compute_conv_output(
+ np.random.rand(1, *input_shape), input_shape, conv_parameters
+ )
+ assert np.equal(computed_output_shape, output_shape).all()
+
+
+@pytest.mark.parametrize(
+ "activation, expected_function_type, expected_extra_args, expected_error",
+ [
+ ("relu", keras.layers.ReLU, {}, does_not_raise()),
+ ("relu6", keras.layers.ReLU, {"max_value": 6}, does_not_raise()),
+ ("none", None, {}, does_not_raise()),
+ (
+ "wrong_key",
+ keras.layers.Identity,
+ {},
+ pytest.raises(
+ KeyError,
+ match=(
+ "Expected activation function to be "
+ rf"in \{ACTIVATION_FUNCTION_LIST}\, found wrong_key"
+ ),
+ ),
+ ),
+ ],
+)
+def test_get_activation_functions(
+ activation: str,
+ expected_function_type: type,
+ expected_extra_args: dict,
+ expected_error: Any,
+) -> None:
+    """Check that get_activation_function returns the expected layer and args."""
+ with expected_error:
+ activation_function, activation_function_extra_args = get_activation_function(
+ activation
+ )
+ if activation_function:
+ assert isinstance(
+ activation_function(**activation_function_extra_args),
+ expected_function_type,
+ )
+ else:
+ assert activation_function is None
+ assert expected_extra_args == activation_function_extra_args
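
A usage sketch for the first helper under test, assuming, as the
**conv_parameters expansion above implies, that it returns keyword arguments
for keras.layers.Conv2D:

params = compute_conv2d_parameters(
    input_shape=np.array([32, 32, 3]),
    output_shape=np.array([16, 16, 3]),
    kernel_size_input=[3, 3],
)
layer = keras.layers.Conv2D(**params)
# Applied to a (1, 32, 32, 3) input this produces a (1, 16, 16, 3) output,
# which is what test_compute_conv2d_parameters asserts.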
diff --git a/tests/test_nn_select.py b/tests/test_nn_select.py
index 4095076..08752bd 100644
--- a/tests/test_nn_select.py
+++ b/tests/test_nn_select.py
@@ -4,12 +4,12 @@
from __future__ import annotations
from contextlib import ExitStack as does_not_raise
-from dataclasses import asdict
from pathlib import Path
from typing import Any
from typing import cast
import pytest
+import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from mlia.core.errors import ConfigurationError
@@ -176,23 +176,50 @@ def test_get_optimizer(
model = test_tflite_model
else:
model = keras.models.load_model(str(test_keras_model))
- optimizer = get_optimizer(model, config)
+ optimizer = get_optimizer(
+ model, config, {"train_params": None, "rewrite_specific_params": None}
+ )
assert isinstance(optimizer, expected_type)
assert optimizer.optimization_config() == expected_config
+# pylint: disable=line-too-long
@pytest.mark.parametrize(
- "rewrite_parameters",
- [None, {"batch_size": 64, "learning_rate": 0.003}],
+ "rewrite_parameters, optimization_target",
+ [
+ [
+ {"train_params": None, "rewrite_specific_params": None},
+ "fully-connected-clustering",
+ ],
+ [
+ {
+ "train_params": None,
+ "rewrite_specific_params": {
+ "num_clusters": 5,
+ "cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization(
+ "CentroidInitialization.LINEAR"
+ ),
+ },
+ },
+ "fully-connected-clustering",
+ ],
+ [
+ {"train_params": None, "rewrite_specific_params": None},
+ "fully-connected",
+ ],
+ ],
)
+# pylint: enable=line-too-long
@pytest.mark.skip_set_training_steps
def test_get_optimizer_training_parameters(
- rewrite_parameters: dict | None, test_tflite_model: Path
+ rewrite_parameters: dict,
+ optimization_target: str,
+ test_tflite_model: Path,
) -> None:
"""Test function get_optimzer with various combinations of parameters."""
config = OptimizationSettings(
optimization_type="rewrite",
- optimization_target="fully-connected", # type: ignore
+ optimization_target=optimization_target, # type: ignore
layers_to_optimize=None,
dataset=None,
)
@@ -200,18 +227,20 @@ def test_get_optimizer_training_parameters(
RewritingOptimizer,
get_optimizer(test_tflite_model, config, rewrite_parameters),
)
+    assert len(rewrite_parameters) == 2
+ if rewrite_parameters.get("rewrite_specific_params"):
+ assert isinstance(
+ rewrite_parameters["rewrite_specific_params"],
+ type(optimizer.optimizer_configuration.rewrite_specific_params),
+ )
+ assert (
+ optimizer.optimizer_configuration.rewrite_specific_params
+ == rewrite_parameters["rewrite_specific_params"]
+ )
assert isinstance(
optimizer.optimizer_configuration.train_params, TrainingParameters
)
- if not rewrite_parameters:
- assert asdict(TrainingParameters()) == asdict(
- optimizer.optimizer_configuration.train_params
- )
- else:
- assert asdict(TrainingParameters()) | rewrite_parameters == asdict(
- optimizer.optimizer_configuration.train_params
- )
@pytest.mark.parametrize(
diff --git a/tests/test_target_cortex_a_advisor.py b/tests/test_target_cortex_a_advisor.py
index 7bb57c3..2f06f54 100644
--- a/tests/test_target_cortex_a_advisor.py
+++ b/tests/test_target_cortex_a_advisor.py
@@ -8,6 +8,7 @@ import pytest
from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext
from mlia.core.workflow import DefaultWorkflowExecutor
+from mlia.target.common.optimization import _DEFAULT_OPTIMIZATION_TARGETS
from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor
from mlia.target.cortex_a.advisor import CortexAInferenceAdvisor
@@ -33,21 +34,11 @@ def test_configure_and_get_cortex_a_advisor(test_tflite_model: Path) -> None:
"target_profile": "cortex-a",
},
"common_optimizations": {
- "optimizations": [
- [
- {
- "layers_to_optimize": None,
- "optimization_target": 0.5,
- "optimization_type": "pruning",
- },
- {
- "layers_to_optimize": None,
- "optimization_target": 32,
- "optimization_type": "clustering",
- },
- ]
- ],
- "training_parameters": None,
+ "optimizations": [_DEFAULT_OPTIMIZATION_TARGETS],
+ "rewrite_parameters": {
+ "train_params": None,
+ "rewrite_specific_params": None,
+ },
},
}
diff --git a/tests/test_target_tosa_advisor.py b/tests/test_target_tosa_advisor.py
index 020acc5..d0b42b9 100644
--- a/tests/test_target_tosa_advisor.py
+++ b/tests/test_target_tosa_advisor.py
@@ -9,6 +9,7 @@ import pytest
from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext
from mlia.core.workflow import DefaultWorkflowExecutor
+from mlia.target.common.optimization import _DEFAULT_OPTIMIZATION_TARGETS
from mlia.target.tosa.advisor import configure_and_get_tosa_advisor
from mlia.target.tosa.advisor import TOSAInferenceAdvisor
@@ -33,21 +34,11 @@ def test_configure_and_get_tosa_advisor(
assert ctx.event_handlers is not None
assert ctx.config_parameters == {
"common_optimizations": {
- "optimizations": [
- [
- {
- "layers_to_optimize": None,
- "optimization_target": 0.5,
- "optimization_type": "pruning",
- },
- {
- "layers_to_optimize": None,
- "optimization_target": 32,
- "optimization_type": "clustering",
- },
- ]
- ],
- "training_parameters": None,
+ "optimizations": [_DEFAULT_OPTIMIZATION_TARGETS],
+ "rewrite_parameters": {
+ "train_params": None,
+ "rewrite_specific_params": None,
+ },
},
"tosa_inference_advisor": {
"model": str(test_tflite_model),
diff --git a/tests_e2e/optimization_e2e_test.toml b/tests_e2e/optimization_e2e_test.toml
index 099247c..f075ec4 100644
--- a/tests_e2e/optimization_e2e_test.toml
+++ b/tests_e2e/optimization_e2e_test.toml
@@ -1,5 +1,5 @@
# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
-[training]
-steps = 1000
+[rewrite.training_parameters]
+steps = 100
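
For reference, a hypothetical optimization profile of the new shape, written
out as the dict that load_profile() would produce. The section and key names
are those exercised by the tests above; the values are illustrative:

profile = {
    "rewrite": {
        "training_parameters": {
            "steps": 100,                 # as in the e2e profile above
            "augmentations": "gaussian",  # preset name or per-strength dict
        },
        # Optional per-target section, surfaced as rewrite_specific_params:
        "fully-connected-clustering": {"num_clusters": 5},
    }
}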