Diffstat (limited to 'src/mlia/nn')
-rw-r--r--  src/mlia/nn/rewrite/core/rewrite.py              241
-rw-r--r--  src/mlia/nn/rewrite/core/train.py                 86
-rw-r--r--  src/mlia/nn/rewrite/library/clustering.py         52
-rw-r--r--  src/mlia/nn/rewrite/library/fc_layer.py           18
-rw-r--r--  src/mlia/nn/rewrite/library/helper_functions.py   35
-rw-r--r--  src/mlia/nn/rewrite/library/layers.py             53
-rw-r--r--  src/mlia/nn/rewrite/library/sparsity.py          121
-rw-r--r--  src/mlia/nn/select.py                             24
8 files changed, 504 insertions, 126 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index 6674d02..d7ffec1 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -8,11 +8,15 @@ import tempfile
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
+from inspect import getfullargspec
from pathlib import Path
+from statistics import fmean
from typing import Any
from typing import Callable
+from typing import Generator
import numpy as np
+import tensorflow as tf
import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from tensorflow_model_optimization.python.core.sparsity.keras.pruning_utils import ( # pylint: disable=no-name-in-module
@@ -29,28 +33,55 @@ from mlia.nn.rewrite.core.train import train
from mlia.nn.rewrite.core.train import TrainingParameters
from mlia.nn.rewrite.library.clustering import conv2d_clustering_rewrite
from mlia.nn.rewrite.library.clustering import fc_clustering_rewrite
-from mlia.nn.rewrite.library.fc_layer import fc_rewrite
+from mlia.nn.rewrite.library.layers import conv2d_rewrite
+from mlia.nn.rewrite.library.layers import fc_rewrite
from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_rewrite
+from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_unstructured_rewrite
from mlia.nn.rewrite.library.sparsity import fc_sparsity_rewrite
+from mlia.nn.rewrite.library.sparsity import fc_sparsity_unstructured_rewrite
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.utils.registry import Registry
logger = logging.getLogger(__name__)
-RewriteCallable = Callable[[Any, Any], keras.Model]
+RewriteCallable = Callable[..., keras.Model]
class Rewrite(ABC):
"""Abstract class for rewrite logic to be used by RewritingOptimizer."""
- def __init__(self, name: str, rewrite_fn: RewriteCallable):
+ def __init__(
+ self,
+ name: str,
+ rewrite_fn: RewriteCallable,
+ rewrite_fn_extra_args: dict[str, Any] | None = None,
+ ):
"""Initialize a Rewrite instance with a given name and an optional function."""
self.name = name
self.function = rewrite_fn
+ self.function_extra_args = rewrite_fn_extra_args
- def __call__(self, input_shape: Any, output_shape: Any) -> keras.Model:
+ def __call__(
+ self, input_shape: Any, output_shape: Any, **kwargs: Any
+ ) -> keras.Model:
"""Perform the rewrite operation using the configured function."""
try:
- return self.function(input_shape, output_shape)
+ if self.function_extra_args:
+ return self.function(
+ input_shape, output_shape, **kwargs, **self.function_extra_args
+ )
+
+ return self.function(input_shape, output_shape, **kwargs)
+ except TypeError as ex:
+ expected_args = self.return_rewrite_func_args()
+ if "input_shape" in expected_args:
+ expected_args.remove("input_shape")
+ if "output_shape" in expected_args:
+ expected_args.remove("output_shape")
+ raise KeyError(
+ f"Found unexpected parameters for rewrite. Expected (sub)set "
+ f"of {expected_args} found unexpected parameter(s) "
+ f"{list(set(list(kwargs.keys())) - set(expected_args))}"
+ ) from ex
except Exception as ex:
raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex
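For orientation, a minimal sketch of the extra-args plumbing above: a Rewrite constructed with rewrite_fn_extra_args merges those arguments into every call, and unexpected keyword arguments surface through the KeyError branch. The toy_rewrite function and its units parameter are hypothetical, not part of this commit.

from typing import Any

from keras.api._v2 import keras  # Temporary workaround for now: MLIA-1107


def toy_rewrite(input_shape: Any, output_shape: Any, units: int = 8) -> keras.Model:
    """Hypothetical rewrite function with one tunable extra argument."""
    return keras.Sequential(
        [
            keras.layers.InputLayer(input_shape=input_shape),
            keras.layers.Reshape([-1]),
            keras.layers.Dense(units, activation="relu"),
            keras.layers.Dense(output_shape),
        ]
    )


rewrite = GenericRewrite("toy", toy_rewrite, rewrite_fn_extra_args={"units": 16})
model = rewrite((28, 28, 1), 10)  # "units" is merged in by __call__
# rewrite((28, 28, 1), 10, bogus=1) would raise KeyError naming "bogus"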
@@ -58,21 +89,25 @@ class Rewrite(ABC):
"""Return a quantized model if required."""
return model
+ def return_rewrite_func_args(self) -> list[str]:
+ """Return the expected args of the rewrite function."""
+ return getfullargspec(self.function).args
+
@abstractmethod
def training_callbacks(self) -> list:
- """Return default rewrite callbacks."""
+ """Return rewrite callbacks."""
@abstractmethod
def post_process(self, model: keras.Model) -> keras.Model:
- """Return default post-processing rewrite options."""
+ """Return post-processing rewrite option."""
@abstractmethod
- def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool:
+ def check_optimization(self, model: keras.Model) -> bool:
"""Check if the optimization has produced the correct result."""
class GenericRewrite(Rewrite):
- """Graph rewrite logic for fully-connected rewrite."""
+ """Rewrite class for generic rewrites e.g. fully-connected."""
def quantize(self, model: keras.Model) -> keras.Model:
"""Return a quantized model if required."""
@@ -83,10 +118,10 @@ class GenericRewrite(Rewrite):
return []
def post_process(self, model: keras.Model) -> keras.Model:
- """Return default post-processing rewrite options."""
+ """Return default post-processing rewrite option."""
return model
- def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+ def check_optimization(self, model: keras.Model, **_: Any) -> bool:
"""Not needed here."""
return True
@@ -99,16 +134,31 @@ class QuantizeAwareTrainingRewrite(Rewrite, ABC):
"""Apply optimization-aware quantization to a given model."""
return model
+ def check_optimization_generator(
+ self, model: keras.Model
+ ) -> Generator[tuple[tf.Tensor, keras.layers.Layer], None, None]:
+ """Loop for check_optimization function."""
+ for layer in model.layers:
+ for weight in layer.weights:
+ if "kernel" in weight.name:
+ if (
+ "kernel_min" in weight.name
+ or "kernel_max" in weight.name
+ or "depthwise" in weight.name
+ ):
+ continue
+ yield weight, layer
+
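A sketch of what check_optimization_generator yields, assuming the rewrite classes defined in this file: each "kernel" weight tensor paired with its owning layer, skipping quantizer kernel_min/kernel_max variables and depthwise kernels. Shapes are illustrative.

rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite)
model = rewrite((28, 28, 1), 10)
for weight, layer in rewrite.check_optimization_generator(model=model):
    print(layer.name, weight.name, tuple(weight.shape))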
-class Sparsity24Rewrite(QuantizeAwareTrainingRewrite):
- """Graph rewrite logic for fully-connected-sparsity24 rewrite."""
+class SparsityRewrite(QuantizeAwareTrainingRewrite):
+ """Base rewrite class for sparsity rewrites."""
pruning_callback = tfmot.sparsity.keras.UpdatePruningStep
strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning)
def quantize(self, model: keras.Model) -> keras.Model:
- """Skip quantization when using pruning rewrite."""
+ """Skip quantization when using sparsity rewrite."""
return model
def training_callbacks(self) -> list:
@@ -116,7 +166,7 @@ class Sparsity24Rewrite(QuantizeAwareTrainingRewrite):
return [self.pruning_callback()]
def post_process(self, model: keras.Model) -> keras.Model:
- """Pruning-specific post-processing rewrite options."""
+ """Pruning-specific post-processing rewrite option."""
return self.strip_pruning_wrapper(model)
def preserved_quantize(
@@ -129,29 +179,78 @@ class Sparsity24Rewrite(QuantizeAwareTrainingRewrite):
model,
tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(),
)
-
return model
- def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+ def check_optimization(self, model: keras.Model) -> bool:
+ """Not needed here."""
+ return True
+
+
+class UnstructuredSparsityRewrite(SparsityRewrite):
+ """
+ Rewrite class for unstructured sparsity rewrite.
+
+ e.g. fully-connected-unstructured-sparsity.
+ """
+
+ def check_optimization(
+ self, model: keras.Model, final_sparsity: float = 0.5, **_: Any
+ ) -> bool:
+ """Not needed here."""
+ found_sparsity_list = []
+ num_dec_places = str(final_sparsity)[::-1].find(".")
+ for weight, _ in self.check_optimization_generator(model=model):
+ weight_np = weight.numpy()
+ found_sparsity_list.append(
+ round(1 - np.count_nonzero(weight_np) / weight_np.size, num_dec_places)
+ )
+ if len(found_sparsity_list) == 0:
+ logger.warning(
+ "\nWARNING: Could not find any layers "
+ "in rewrite that could be sparsely pruned"
+ )
+ return False
+ found_sparsity = fmean(found_sparsity_list)
+ if found_sparsity != final_sparsity:
+ logger.warning(
+ "\nWARNING: Found total sparsity of "
+ "rewrite model: %.2f "
+ "expected total sparsity to be: "
+ "%.2f\n",
+ found_sparsity,
+ final_sparsity,
+ )
+ return False
+ return True
+
+
+class StructuredSparsityRewrite(SparsityRewrite):
+ """Rewrite class for structured sparsity rewrite e.g. fully-connected-sparsity."""
+
+ def check_optimization(
+ self,
+ model: keras.Model,
+ sparsity_m: int = 2,
+ sparsity_n: int = 4,
+ **_: Any,
+ ) -> bool:
"""Check if sparity has produced the correct result."""
- for layer in model.layers:
- for weight in layer.weights:
- if "kernel" in weight.name:
- if "kernel_min" in weight.name or "kernel_max" in weight.name:
- continue
- if not is_pruned_m_by_n(weight, m_by_n=(2, 4)):
- logger.warning(
- "\nWARNING: Could not find (2,4) sparsity, "
- "in layer %s for weight %s \n",
- layer.name,
- weight.name,
- )
- return False
+ for weight, layer in self.check_optimization_generator(model=model):
+ if not is_pruned_m_by_n(weight, m_by_n=(sparsity_m, sparsity_n)):
+ logger.warning(
+ "\nWARNING: Could not find (%d, %d) sparsity, "
+ "in layer %s for weight %s \n",
+ sparsity_m,
+ sparsity_n,
+ layer.name,
+ weight.name,
+ )
+ return False
return True
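The check above delegates to tfmot's is_pruned_m_by_n; as a standalone illustration of the property it enforces, (2, 4) sparsity means at least two zeros in every group of four consecutive weights. A toy numpy version, not the tfmot implementation:

import numpy as np


def is_two_of_four_sparse(flat: np.ndarray) -> bool:
    """Group weights into blocks of four and require >= 2 zeros per block."""
    blocks = flat[: flat.size // 4 * 4].reshape(-1, 4)
    return bool(np.all((blocks == 0).sum(axis=1) >= 2))


print(is_two_of_four_sparse(np.array([0.0, 1.2, 0.0, -0.4] * 8)))  # True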
class ClusteringRewrite(QuantizeAwareTrainingRewrite):
- """Graph clustering rewrite logic to be used by RewritingOptimizer."""
+ """Rewrite class for clustering rewrite e.g. fully-connected-clustering."""
_strip_clustering_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering)
@@ -164,32 +263,22 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite):
)
return cqat_model
- def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+ def check_optimization(
+ self, model: keras.Model, num_clusters: int = 2, **_: Any
+ ) -> bool:
"""Check if clustering has produced the correct result."""
- number_of_clusters = kwargs.get("number_of_clusters")
- if not number_of_clusters:
- raise ValueError(
- """
- Expected check_preserved_quantize to have argument number_of_clusters.
- """
- )
-
- for layer in model.layers:
- for weight in layer.weights:
- if "kernel" in weight.name:
- if "kernel_min" in weight.name or "kernel_max" in weight.name:
- continue
- number_of_found_clusters = len(np.unique(weight))
- if number_of_found_clusters != number_of_clusters:
- logger.warning(
- "\nWARNING: Expected %d cluster(s), found %d "
- "cluster(s) in layer %s for weight %s \n",
- number_of_clusters,
- number_of_found_clusters,
- layer.name,
- weight.name,
- )
- return False
+ for weight, layer in self.check_optimization_generator(model=model):
+ number_of_found_clusters = len(np.unique(weight))
+ if number_of_found_clusters != num_clusters:
+ logger.warning(
+ "\nWARNING: Expected %d cluster(s), found %d "
+ "cluster(s) in layer %s for weight %s \n",
+ num_clusters,
+ number_of_found_clusters,
+ layer.name,
+ weight.name,
+ )
+ return False
return True
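The rewritten clustering check reduces to counting distinct values per kernel; a standalone illustration of the len(np.unique(...)) test with made-up weights:

import numpy as np

weights = np.array([0.25, -0.75, 0.25, 0.25, -0.75])
print(len(np.unique(weights)))  # 2, so check_optimization(num_clusters=2) passes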
def training_callbacks(self) -> list:
@@ -197,7 +286,7 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite):
return []
def post_process(self, model: keras.Model) -> keras.Model:
- """Return the clustering stripped model."""
+ """Clustering-specific post-processing rewrite option."""
return self._strip_clustering_wrapper(model)
@@ -228,6 +317,7 @@ class RewriteConfiguration(OptimizerConfiguration):
layers_to_optimize: list[str] | None = None
dataset: Path | None = None
train_params: TrainingParameters = TrainingParameters()
+ rewrite_specific_params: dict | None = None
def __str__(self) -> str:
"""Return string representation of the configuration."""
@@ -240,10 +330,38 @@ class RewritingOptimizer(Optimizer):
registry = RewriteRegistry(
[
GenericRewrite("fully-connected", fc_rewrite),
- Sparsity24Rewrite("fully-connected-sparsity24", fc_sparsity_rewrite),
+ GenericRewrite("conv2d", conv2d_rewrite),
+ GenericRewrite(
+ "depthwise-separable-conv2d",
+ conv2d_rewrite,
+ {"layer_type": keras.layers.SeparableConv2D},
+ ),
+ StructuredSparsityRewrite("fully-connected-sparsity", fc_sparsity_rewrite),
ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite),
ClusteringRewrite("conv2d-clustering", conv2d_clustering_rewrite),
- Sparsity24Rewrite("conv2d-sparsity24", conv2d_sparsity_rewrite),
+ StructuredSparsityRewrite("conv2d-sparsity", conv2d_sparsity_rewrite),
+ UnstructuredSparsityRewrite(
+ "conv2d-unstructured-sparsity", conv2d_sparsity_unstructured_rewrite
+ ),
+ UnstructuredSparsityRewrite(
+ "fully-connected-unstructured-sparsity",
+ fc_sparsity_unstructured_rewrite,
+ ),
+ ClusteringRewrite(
+ "depthwise-separable-conv2d-clustering",
+ conv2d_clustering_rewrite,
+ {"layer_type": keras.layers.SeparableConv2D},
+ ),
+ StructuredSparsityRewrite(
+ "depthwise-separable-conv2d-sparsity",
+ conv2d_sparsity_rewrite,
+ {"layer_type": keras.layers.SeparableConv2D},
+ ),
+ UnstructuredSparsityRewrite(
+ "depthwise-separable-conv2d-unstructured-sparsity",
+ conv2d_sparsity_unstructured_rewrite,
+ {"layer_type": keras.layers.SeparableConv2D},
+ ),
]
)
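Usage sketch for the expanded registry: rewrites are looked up by the new names and accept per-rewrite keyword arguments via Rewrite.__call__. The shapes below are illustrative stand-ins for the extracted subgraph's I/O shapes.

import numpy as np

rewrite = RewritingOptimizer.registry.items["conv2d-sparsity"]
print(rewrite.return_rewrite_func_args())  # includes "activation", "kernel_size", ...
model = rewrite(np.array([32, 32, 3]), np.array([16, 16, 8]), activation="relu6")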
@@ -265,7 +383,6 @@ class RewritingOptimizer(Optimizer):
rewrite = RewritingOptimizer.registry.items[
self.optimizer_configuration.optimization_target
]
-
use_unmodified_model = True
tflite_model = self.model.model_path
tfrecord = str(self.optimizer_configuration.dataset)
@@ -287,6 +404,10 @@ class RewritingOptimizer(Optimizer):
input_tensors=[self.optimizer_configuration.layers_to_optimize[0]],
output_tensors=[self.optimizer_configuration.layers_to_optimize[1]],
train_params=self.optimizer_configuration.train_params,
+ rewrite_specific_params=self.optimizer_configuration.rewrite_specific_params, # pylint: disable=line-too-long
+ detect_activation_function=(
+ "activation" in rewrite.return_rewrite_func_args()
+ ),
)
if orig_vs_repl_stats:
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 4204978..570968a 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -34,13 +34,13 @@ from mlia.nn.rewrite.core.graph_edit.record import record_model
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
+from mlia.nn.rewrite.library.helper_functions import ACTIVATION_FUNCTION_LIST
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
from mlia.nn.tensorflow.tflite_graph import load_fb
from mlia.nn.tensorflow.tflite_graph import save_fb
from mlia.utils.logging import log_action
-
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
logger = logging.getLogger(__name__)
@@ -83,6 +83,8 @@ def train( # pylint: disable=too-many-arguments
input_tensors: list,
output_tensors: list,
train_params: TrainingParameters = TrainingParameters(),
+ rewrite_specific_params: dict | None = None,
+ detect_activation_function: bool = False,
) -> Any:
"""Extract and train a model, and return the results."""
if unmodified_model:
@@ -122,6 +124,8 @@ def train( # pylint: disable=too-many-arguments
rewrite=rewrite,
is_qat=is_qat,
train_params=train_params,
+ rewrite_specific_params=rewrite_specific_params,
+ detect_activation_function=detect_activation_function,
)
for i, filename in enumerate(tflite_filenames):
@@ -349,6 +353,41 @@ def set_up_data_pipeline(
return dataset, steps_per_epoch
+def detect_activation_from_rewrite_function(model_path: str) -> str:
+ """Given a rewrite model, choose the most common activation function."""
+ interpreter = tf.lite.Interpreter(model_path=model_path)
+ interpreter.allocate_tensors()
+ act_func_match_list = []
+ for tensor_details in interpreter.get_tensor_details():
+ for act_func in ACTIVATION_FUNCTION_LIST:
+ tensor_name = tensor_details["name"].lower()
+ if act_func in tensor_name:
+ act_func_idx = tensor_name.index(act_func)
+ if (
+ len(tensor_name) == act_func_idx + len(act_func)
+ or tensor_name[act_func_idx + len(act_func)] == ";"
+ ):
+ act_func_match_list.append(
+ tensor_name[
+ act_func_idx : act_func_idx + len(act_func) # noqa: E203
+ ]
+ )
+ act_func_match = "relu"
+ if len(act_func_match_list) == 0:
+ logger.info(
+ "No activation function specified, setting activation function to ReLU"
+ )
+ else:
+ act_func_match = max(set(act_func_match_list), key=act_func_match_list.count)
+ logger.info(
+ "No activation function specified, "
+ "setting activation function to most "
+ "common activation detected in rewrite graph: %s",
+ act_func_match,
+ )
+ return act_func_match
+
+
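The selection in detect_activation_from_rewrite_function is the idiomatic most-frequent-element pattern; a standalone illustration with made-up matches:

matches = ["relu", "relu6", "relu"]
print(max(set(matches), key=matches.count))  # relu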
def train_in_dir(
train_dir: str,
baseline_dir: Any,
@@ -356,6 +395,8 @@ def train_in_dir(
rewrite: Callable,
is_qat: bool,
train_params: TrainingParameters = TrainingParameters(),
+ rewrite_specific_params: dict | None = None,
+ detect_activation_function: bool = False,
) -> list[str]:
"""Train a replacement for replace.tflite using the input.tfrec \
and output.tfrec in train_dir.
@@ -372,6 +413,18 @@ def train_in_dir(
)
replace = TFLiteModel(ExtractPaths.tflite.replace(train_dir))
+ if detect_activation_function and (
+ rewrite_specific_params is None
+ or "activation" not in list(rewrite_specific_params.keys())
+ ):
+ detected_activation_function = detect_activation_from_rewrite_function(
+ ExtractPaths.tflite.replace(train_dir).as_posix()
+ )
+ if rewrite_specific_params:
+ rewrite_specific_params["activation"] = detected_activation_function
+ else:
+ rewrite_specific_params = {"activation": detected_activation_function}
+
input_name, output_name = _get_io_tensors(teacher)
model_is_quantized = replace.is_tensor_quantized(name=input_name)
@@ -396,7 +449,13 @@ def train_in_dir(
loss_fn = keras.losses.MeanSquaredError()
model = create_model(
- rewrite, input_shape, output_shape, optimizer, loss_fn, model_is_quantized
+ rewrite,
+ input_shape,
+ output_shape,
+ optimizer,
+ loss_fn,
+ model_is_quantized,
+ rewrite_specific_params=rewrite_specific_params,
)
logger.info(model.summary())
@@ -462,11 +521,9 @@ def train_in_dir(
steps_per_epoch,
post_process=True,
)
-
- # Placeholder for now, will be parametrized later (MLIA-1114)
- # rewrite.check_optimization( # type: ignore[attr-defined]
- # model, number_of_clusters=32
- # )
+ rewrite.check_optimization( # type: ignore[attr-defined]
+ model, **rewrite_specific_params if rewrite_specific_params else {}
+ )
if model_is_quantized and is_qat:
model = rewrite.preserved_quantize(model) # type: ignore[attr-defined]
checkpoints = (
@@ -501,11 +558,10 @@ def train_in_dir(
loss_fn,
steps_per_epoch,
)
- # Placeholder for now, will be parametrized later (MLIA-1114)
- # rewrite.check_optimization( # type: ignore[attr-defined]
- # model, number_of_clusters=32
- # )
+ rewrite.check_optimization( # type: ignore[attr-defined]
+ model, **rewrite_specific_params if rewrite_specific_params else {}
+ )
teacher.close()
return output_filenames
@@ -528,9 +584,13 @@ def create_model( # pylint: disable=too-many-arguments
loss_fn: Callable,
model_is_quantized: bool,
model_to_load_from: keras.Model | None = None,
+ rewrite_specific_params: dict | None = None,
) -> keras.Model:
"""Create a model, optionally from another."""
- model = rewrite(input_shape, output_shape)
+ if rewrite_specific_params:
+ model = rewrite(input_shape, output_shape, **rewrite_specific_params)
+ else:
+ model = rewrite(input_shape, output_shape)
if model_is_quantized:
model = rewrite.quantize(model) # type: ignore[attr-defined]
model = model_compile(model, optimizer=optimizer, loss_fn=loss_fn)
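A sketch of the extended create_model call, assuming a conv2d-style Rewrite instance (e.g. from RewritingOptimizer.registry); the optimizer settings, shapes, and rewrite_specific_params values are illustrative.

import numpy as np

model = create_model(
    rewrite,  # e.g. RewritingOptimizer.registry.items["conv2d"]
    input_shape=np.array([32, 32, 3]),
    output_shape=np.array([16, 16, 8]),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss_fn=keras.losses.MeanSquaredError(),
    model_is_quantized=False,
    rewrite_specific_params={"activation": "relu6"},  # forwarded to the rewrite fn
)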
@@ -558,6 +618,7 @@ def model_fit( # pylint: disable=too-many-arguments
loss_fn: Callable,
steps_per_epoch: int,
post_process: bool = False,
+ rewrite_specific_params: dict | None = None,
) -> keras.Model:
"""Train a tflite model."""
steps_so_far = 0
@@ -597,6 +658,7 @@ def model_fit( # pylint: disable=too-many-arguments
loss_fn,
model_is_quantized,
model_to_load_from=model,
+ rewrite_specific_params=rewrite_specific_params,
)
else:
model_to_save = model
diff --git a/src/mlia/nn/rewrite/library/clustering.py b/src/mlia/nn/rewrite/library/clustering.py
index 6f06c48..9247457 100644
--- a/src/mlia/nn/rewrite/library/clustering.py
+++ b/src/mlia/nn/rewrite/library/clustering.py
@@ -1,19 +1,27 @@
# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
-"""Example rewrite with one fully connected clustered layer."""
+"""Rewrite functions used to return layers ready for clustering."""
from typing import Any
import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters
+from mlia.nn.rewrite.library.helper_functions import get_activation_function
-def fc_clustering_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
- """Generate TensorFlow Lite model for clustering rewrite."""
+def fc_clustering_rewrite(
+ input_shape: Any,
+ output_shape: Any,
+ num_clusters: int = 2,
+ cluster_centroids_init: tfmot.clustering.keras.CentroidInitialization = tfmot.clustering.keras.CentroidInitialization( # pylint: disable=line-too-long
+ "CentroidInitialization.LINEAR"
+ ),
+) -> keras.Model:
+ """Fully connected TensorFlow Lite model ready for clustering."""
rewrite_params = {
- "number_of_clusters": 4,
- "cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization.LINEAR,
+ "number_of_clusters": num_clusters,
+ "cluster_centroids_init": cluster_centroids_init,
}
model = tfmot.clustering.keras.cluster_weights(
to_cluster=keras.Sequential(
@@ -28,22 +36,42 @@ def fc_clustering_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
return model
-def conv2d_clustering_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
- """Generate TensorFlow Lite model for clustering rewrite."""
+def conv2d_clustering_rewrite( # pylint: disable=dangerous-default-value
+ input_shape: Any,
+ output_shape: Any,
+ num_clusters: int = 2,
+ cluster_centroids_init: tfmot.clustering.keras.CentroidInitialization = tfmot.clustering.keras.CentroidInitialization( # pylint: disable=line-too-long
+ "CentroidInitialization.LINEAR"
+ ),
+ activation: str = "relu",
+ kernel_size: list[int] = [3, 3],
+ layer_type: type[keras.layers.Layer] = keras.layers.Conv2D,
+) -> keras.Model:
+ """Conv2d TensorFlow Lite model ready for clustering."""
rewrite_params = {
- "number_of_clusters": 4,
- "cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization.LINEAR,
+ "number_of_clusters": num_clusters,
+ "cluster_centroids_init": cluster_centroids_init,
}
conv2d_parameters = compute_conv2d_parameters(
- input_shape=input_shape, output_shape=output_shape
+ input_shape=input_shape,
+ output_shape=output_shape,
+ kernel_size_input=kernel_size,
+ )
+ activation_function, activation_function_extra_args = get_activation_function(
+ activation
+ )
+ activation_func_found = ( # pylint: disable=duplicate-code
+ [activation_function(**activation_function_extra_args)]
+ if activation_function
+ else []
)
model = tfmot.clustering.keras.cluster_weights(
to_cluster=keras.Sequential(
[
keras.layers.InputLayer(input_shape=input_shape),
- keras.layers.Conv2D(**conv2d_parameters),
+ layer_type(**conv2d_parameters),
keras.layers.BatchNormalization(),
- keras.layers.ReLU(),
+ *activation_func_found,
]
),
**rewrite_params
diff --git a/src/mlia/nn/rewrite/library/fc_layer.py b/src/mlia/nn/rewrite/library/fc_layer.py
deleted file mode 100644
index cb98cb9..0000000
--- a/src/mlia/nn/rewrite/library/fc_layer.py
+++ /dev/null
@@ -1,18 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Example rewrite with one fully connected layer."""
-from typing import Any
-
-from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
-
-
-def fc_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
- """Generate TensorFlow Lite model for rewrite."""
- model = keras.Sequential(
- (
- keras.layers.InputLayer(input_shape=input_shape),
- keras.layers.Reshape([-1]),
- keras.layers.Dense(output_shape),
- )
- )
- return model
diff --git a/src/mlia/nn/rewrite/library/helper_functions.py b/src/mlia/nn/rewrite/library/helper_functions.py
index 4f08170..1237c17 100644
--- a/src/mlia/nn/rewrite/library/helper_functions.py
+++ b/src/mlia/nn/rewrite/library/helper_functions.py
@@ -5,19 +5,48 @@ import math
from typing import Any
import numpy as np
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
+ACTIVATION_FUNCTION_PRESETS = {
+ "relu": {"layer_func": keras.layers.ReLU, "extra_args": {}},
+ "relu6": {"layer_func": keras.layers.ReLU, "extra_args": {"max_value": 6}},
+ "none": {"layer_func": None, "extra_args": {}},
+}
+ACTIVATION_FUNCTION_LIST = [
+ act_func for act_func, _ in ACTIVATION_FUNCTION_PRESETS.items()
+]
-def compute_conv2d_parameters(
- input_shape: np.ndarray, output_shape: np.ndarray
+
+def get_activation_function(
+ activation: str = "relu",
+) -> tuple[type, dict]:
+ """Get the activation function from a key."""
+ if activation not in ACTIVATION_FUNCTION_LIST:
+ raise KeyError(
+ "Expected activation function to be "
+ f"in {ACTIVATION_FUNCTION_LIST}, found {activation}"
+ )
+ activation_function = ACTIVATION_FUNCTION_PRESETS[activation]["layer_func"]
+ activation_function_extra_args = ACTIVATION_FUNCTION_PRESETS[activation][
+ "extra_args"
+ ]
+ return activation_function, activation_function_extra_args
+
+
+def compute_conv2d_parameters( # pylint: disable=dangerous-default-value
+ input_shape: np.ndarray,
+ output_shape: np.ndarray,
+ kernel_size_input: list[int] = [3, 3],
) -> dict[str, Any]:
"""Compute needed kernel size and strides for a given input and output_shape."""
input_shape = input_shape.tolist()
output_shape = output_shape.tolist()
+ assert len(kernel_size_input) == 2, "Kernel size should have 2 entries"
assert len(input_shape) == 3
assert len(output_shape) == 3
+ kernel_size = tuple(kernel_size_input)
num_filters = (output_shape[-1] - input_shape[-1]) + input_shape[-1]
padding = "valid"
- kernel_size = (3, 3)
stride_h = round(input_shape[0] / output_shape[0])
check_output_size_h = math.floor((input_shape[0] - kernel_size[0]) / stride_h) + 1
stride_w = round(input_shape[1] / output_shape[1])
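A worked example of the arithmetic above with illustrative shapes; the rest of the function (outside this hunk) reconciles any mismatch between the checked and requested output sizes.

import math

input_shape, output_shape = [32, 32, 3], [16, 16, 8]
kernel_size = (3, 3)
num_filters = (output_shape[-1] - input_shape[-1]) + input_shape[-1]    # 8
stride_h = round(input_shape[0] / output_shape[0])                      # 2
check_h = math.floor((input_shape[0] - kernel_size[0]) / stride_h) + 1  # 15
print(num_filters, stride_h, check_h)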
diff --git a/src/mlia/nn/rewrite/library/layers.py b/src/mlia/nn/rewrite/library/layers.py
new file mode 100644
index 0000000..abf0a4c
--- /dev/null
+++ b/src/mlia/nn/rewrite/library/layers.py
@@ -0,0 +1,53 @@
+# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Rewrite function used to return regular layers."""
+from typing import Any
+
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
+
+from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters
+from mlia.nn.rewrite.library.helper_functions import get_activation_function
+
+
+def fc_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
+ """Fully connected TensorFlow Lite model for rewrite."""
+ model = keras.Sequential(
+ (
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Reshape([-1]),
+ keras.layers.Dense(output_shape),
+ )
+ )
+ return model
+
+
+def conv2d_rewrite( # pylint: disable=dangerous-default-value
+ input_shape: Any,
+ output_shape: Any,
+ activation: str = "relu",
+ kernel_size: list[int] = [3, 3],
+ layer_type: type[keras.layers.Layer] = keras.layers.Conv2D,
+) -> keras.Model:
+ """Fully connected TensorFlow Lite model for rewrite."""
+ conv2d_parameters = compute_conv2d_parameters(
+ input_shape=input_shape,
+ output_shape=output_shape,
+ kernel_size_input=kernel_size,
+ )
+ activation_function, activation_function_extra_args = get_activation_function(
+ activation
+ )
+ activation_func_found = ( # pylint: disable=duplicate-code
+ [activation_function(**activation_function_extra_args)]
+ if activation_function
+ else []
+ )
+ model = keras.Sequential(
+ (
+ keras.layers.InputLayer(input_shape=input_shape),
+ layer_type(**conv2d_parameters),
+ keras.layers.BatchNormalization(),
+ *activation_func_found,
+ )
+ )
+ return model
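Usage sketch for the new module, mirroring the depthwise-separable registry entry in rewrite.py; the shapes are illustrative and passed as numpy arrays because compute_conv2d_parameters calls .tolist() on them.

import numpy as np

model = conv2d_rewrite(
    input_shape=np.array([32, 32, 3]),
    output_shape=np.array([16, 16, 8]),
    activation="relu6",
    layer_type=keras.layers.SeparableConv2D,
)
model.summary()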
diff --git a/src/mlia/nn/rewrite/library/sparsity.py b/src/mlia/nn/rewrite/library/sparsity.py
index 709593a..5102094 100644
--- a/src/mlia/nn/rewrite/library/sparsity.py
+++ b/src/mlia/nn/rewrite/library/sparsity.py
@@ -1,16 +1,26 @@
# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
-"""Example rewrite with one fully connected clustered layer."""
+"""Rewrite functions used to return layers ready for sparse pruning."""
+from __future__ import annotations
+
from typing import Any
import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters
+from mlia.nn.rewrite.library.helper_functions import get_activation_function
-def fc_sparsity_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
- """Generate TensorFlow Lite model for rewrite."""
+def fc_sparsity_unstructured_rewrite(
+ input_shape: Any,
+ output_shape: Any,
+ initial_sparsity: float = 0.5,
+ final_sparsity: float = 0.5,
+ begin_step: int = 0,
+ end_step: int = 48000,
+) -> keras.Model:
+ """Fully connected TensorFlow Lite model ready for unstructured sparse pruning."""
model = tfmot.sparsity.keras.prune_low_magnitude(
to_prune=keras.Sequential(
[
@@ -19,27 +29,118 @@ def fc_sparsity_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
keras.layers.Dense(output_shape),
]
),
- sparsity_m_by_n=(2, 4),
+ pruning_schedule=tfmot.sparsity.keras.PolynomialDecay(
+ initial_sparsity=initial_sparsity,
+ final_sparsity=final_sparsity,
+ begin_step=begin_step,
+ end_step=end_step,
+ ),
)
return model
-def conv2d_sparsity_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
- """Generate TensorFlow Lite model for rewrite."""
+def conv2d_sparsity_unstructured_rewrite( # pylint: disable=dangerous-default-value, too-many-arguments
+ input_shape: Any,
+ output_shape: Any,
+ initial_sparsity: float = 0.5,
+ final_sparsity: float = 0.5,
+ begin_step: int = 0,
+ end_step: int = 48000,
+ activation: str = "relu",
+ kernel_size: list[int] = [3, 3],
+ layer_type: type[keras.layers.Layer] = keras.layers.Conv2D,
+) -> keras.Model:
+ """Conv2d TensorFlow Lite model ready for unstructured sparse pruning."""
conv2d_parameters = compute_conv2d_parameters(
- input_shape=input_shape, output_shape=output_shape
+ input_shape=input_shape,
+ output_shape=output_shape,
+ kernel_size_input=kernel_size,
+ )
+ activation_function, activation_function_extra_args = get_activation_function(
+ activation
+ )
+ activation_func_found = (
+ [activation_function(**activation_function_extra_args)]
+ if activation_function
+ else []
)
model = tfmot.sparsity.keras.prune_low_magnitude(
to_prune=keras.Sequential(
[
keras.layers.InputLayer(input_shape=input_shape),
- keras.layers.Conv2D(**conv2d_parameters),
+ layer_type(**conv2d_parameters),
keras.layers.BatchNormalization(),
- keras.layers.ReLU(),
+ *activation_func_found,
]
),
- sparsity_m_by_n=(2, 4),
+ pruning_schedule=tfmot.sparsity.keras.PolynomialDecay(
+ initial_sparsity=initial_sparsity,
+ final_sparsity=final_sparsity,
+ begin_step=begin_step,
+ end_step=end_step,
+ ),
)
return model
+
+
+def fc_sparsity_rewrite(
+ input_shape: Any, output_shape: Any, sparsity_m: int = 2, sparsity_n: int = 4
+) -> keras.Model:
+ """Fully connected TensorFlow Lite model ready for sparse pruning."""
+ model = tfmot.sparsity.keras.prune_low_magnitude(
+ to_prune=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Reshape([-1]),
+ keras.layers.Dense(output_shape),
+ ]
+ ),
+ sparsity_m_by_n=(
+ sparsity_m,
+ sparsity_n,
+ ),
+ )
+
+ return model
+
+
+def conv2d_sparsity_rewrite( # pylint: disable=dangerous-default-value
+ input_shape: Any,
+ output_shape: Any,
+ sparsity_m: int = 2,
+ sparsity_n: int = 4,
+ activation: str = "relu",
+ kernel_size: list[int] = [3, 3],
+ layer_type: type[keras.layers.Layer] = keras.layers.Conv2D,
+) -> keras.Model:
+ """Conv2d TensorFlow Lite model ready for sparse pruning."""
+ conv2d_parameters = compute_conv2d_parameters(
+ input_shape=input_shape,
+ output_shape=output_shape,
+ kernel_size_input=kernel_size,
+ )
+ activation_function, activation_function_extra_args = get_activation_function(
+ activation
+ )
+ activation_func_found = ( # pylint: disable=duplicate-code
+ [activation_function(**activation_function_extra_args)]
+ if activation_function
+ else []
+ )
+ model = tfmot.sparsity.keras.prune_low_magnitude(
+ to_prune=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ layer_type(**conv2d_parameters),
+ keras.layers.BatchNormalization(),
+ *activation_func_found,
+ ]
+ ),
+ sparsity_m_by_n=(
+ sparsity_m,
+ sparsity_n,
+ ),
+ )
+ return model
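Usage sketch for the unstructured variants: PolynomialDecay ramps the sparsity target from initial_sparsity to final_sparsity between begin_step and end_step, so overriding only the final values (illustrative numbers below) gives a 0.5 to 0.8 ramp over 12000 steps.

model = fc_sparsity_unstructured_rewrite(
    input_shape=(28, 28, 1),
    output_shape=10,
    final_sparsity=0.8,
    end_step=12000,
)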
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py
index b61e713..d5470d1 100644
--- a/src/mlia/nn/select.py
+++ b/src/mlia/nn/select.py
@@ -17,7 +17,7 @@ from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
-from mlia.nn.rewrite.core.rewrite import TrainingParameters
+from mlia.nn.rewrite.core.train import TrainingParameters
from mlia.nn.tensorflow.config import KerasModel
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.nn.tensorflow.optimizations.clustering import Clusterer
@@ -109,7 +109,7 @@ class MultiStageOptimizer(Optimizer):
def apply_optimization(self) -> None:
"""Apply optimization to the model."""
for config in self.optimizations:
- optimizer = get_optimizer(self.model, config)
+ optimizer = get_optimizer(self.model, config, {})
optimizer.apply_optimization()
self.model = optimizer.get_model()
@@ -117,7 +117,7 @@ class MultiStageOptimizer(Optimizer):
def get_optimizer(
model: keras.Model | KerasModel | TFLiteModel,
config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings],
- training_parameters: dict | None = None,
+ rewrite_parameters: dict,
) -> Optimizer:
"""Get optimizer for provided configuration."""
if isinstance(model, KerasModel):
@@ -137,12 +137,12 @@ def get_optimizer(
if isinstance(config, OptimizationSettings):
return _get_optimizer(
- model, cast(OptimizationSettings, config), training_parameters
+ model, cast(OptimizationSettings, config), rewrite_parameters
)
if is_list_of(config, OptimizationSettings):
return _get_optimizer(
- model, cast(List[OptimizationSettings], config), training_parameters
+ model, cast(List[OptimizationSettings], config), rewrite_parameters
)
raise ConfigurationError(f"Unknown optimization configuration {config}")
@@ -151,7 +151,7 @@ def get_optimizer(
def _get_optimizer(
model: keras.Model | Path,
optimization_settings: OptimizationSettings | list[OptimizationSettings],
- training_parameters: dict | None = None,
+ rewrite_parameters: dict,
) -> Optimizer:
if isinstance(optimization_settings, OptimizationSettings):
optimization_settings = [optimization_settings]
@@ -162,12 +162,12 @@ def _get_optimizer(
_check_optimizer_params(opt_type, opt_target)
opt_config = _get_optimizer_configuration(
- opt_type, opt_target, layers_to_optimize, dataset, training_parameters
+ opt_type, opt_target, rewrite_parameters, layers_to_optimize, dataset
)
optimizer_configs.append(opt_config)
if len(optimizer_configs) == 1:
- return get_optimizer(model, optimizer_configs[0])
+ return get_optimizer(model, optimizer_configs[0], {})
return MultiStageOptimizer(model, optimizer_configs)
@@ -189,9 +189,9 @@ def _get_rewrite_params(
def _get_optimizer_configuration(
optimization_type: str,
optimization_target: int | float | str,
+ rewrite_parameters: dict,
layers_to_optimize: list[str] | None = None,
dataset: Path | None = None,
- training_parameters: dict | None = None,
) -> OptimizerConfiguration:
"""Get optimizer configuration for provided parameters."""
_check_optimizer_params(optimization_type, optimization_target)
@@ -212,12 +212,14 @@ def _get_optimizer_configuration(
if opt_type == "rewrite":
if isinstance(optimization_target, str):
- rewrite_params = _get_rewrite_params(training_parameters)
return RewriteConfiguration(
optimization_target=str(optimization_target),
layers_to_optimize=layers_to_optimize,
dataset=dataset,
- train_params=rewrite_params,
+ train_params=_get_rewrite_params(rewrite_parameters["train_params"]),
+ rewrite_specific_params=rewrite_parameters.get(
+ "rewrite_specific_params"
+ ),
)
raise ConfigurationError(