Diffstat (limited to 'src/mlia')
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/record.py                                   |  10
-rw-r--r--  src/mlia/nn/rewrite/core/rewrite.py                                             | 220
-rw-r--r--  src/mlia/nn/rewrite/core/train.py                                               | 209
-rw-r--r--  src/mlia/nn/rewrite/library/clustering.py                                       |  51
-rw-r--r--  src/mlia/nn/rewrite/library/fc_layer.py                                         |   6
-rw-r--r--  src/mlia/nn/rewrite/library/helper_functions.py                                 |  32
-rw-r--r--  src/mlia/nn/rewrite/library/sparsity.py                                         |  45
-rw-r--r--  src/mlia/nn/select.py                                                           |  23
-rw-r--r--  src/mlia/resources/optimization_profiles/optimization.toml                      |   1
-rw-r--r--  src/mlia/resources/optimization_profiles/optimization_custom_augmentation.toml |  13
-rw-r--r--  src/mlia/target/common/optimization.py                                          |  68
11 files changed, 584 insertions(+), 94 deletions(-)
diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py
index f85433d..7d9f219 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/record.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/record.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Save subgraph data."""
# pylint: disable=too-many-locals
@@ -32,7 +32,7 @@ def dequantized_path(filename: str | Path) -> Path:
return path
-def record_model(
+def record_model( # pylint: disable=too-many-arguments
input_filename: str | Path,
model_filename: str | Path,
output_filename: str | Path,
@@ -41,6 +41,7 @@ def record_model(
num_procs: int = 1,
num_threads: int = 0,
dequantize_output: bool = False,
+ quantize_input: bool = False,
) -> None:
"""Model recorder.
@@ -92,7 +93,10 @@ def record_model(
for _, named_x in enumerate(
track(dataset.as_numpy_iterator(), total=total, disable=not show_progress)
):
- named_y = model(named_x)
+ if quantize_input:
+ named_y = model(model.quantize_inputs(named_x))
+ else:
+ named_y = model(named_x)
write(writer, named_y)
if dequantize_output:
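
The new quantize_input flag complements the existing dequantize_output: together they let float data be recorded through a fully quantized model. A minimal sketch of such a call (file paths are placeholders, not from this patch):

    from mlia.nn.rewrite.core.graph_edit.record import record_model

    # Hypothetical paths: quantize float inputs on the way into an int8
    # model, and also write dequantized float outputs for comparison.
    record_model(
        input_filename="input.tfrec",
        model_filename="model_int8.tflite",
        output_filename="end.tfrec",
        quantize_input=True,
        dequantize_output=True,
    )
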
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index c7d13ba..6d915c6 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -3,16 +3,21 @@
"""Contains class RewritingOptimizer to replace a subgraph/layer of a model."""
from __future__ import annotations
-import importlib
import logging
import tempfile
+from abc import ABC
+from abc import abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Callable
-from typing import cast
+import numpy as np
+import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
+from tensorflow_model_optimization.python.core.sparsity.keras.pruning_utils import ( # pylint: disable=no-name-in-module
+ is_pruned_m_by_n,
+)
from mlia.core.errors import ConfigurationError
from mlia.core.reporting import Column
@@ -22,16 +27,20 @@ from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
from mlia.nn.rewrite.core.train import train
from mlia.nn.rewrite.core.train import TrainingParameters
+from mlia.nn.rewrite.library.clustering import conv2d_clustering_rewrite
+from mlia.nn.rewrite.library.clustering import fc_clustering_rewrite
+from mlia.nn.rewrite.library.fc_layer import fc_rewrite
+from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_rewrite
+from mlia.nn.rewrite.library.sparsity import fc_sparsity_rewrite
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.utils.registry import Registry
-
logger = logging.getLogger(__name__)
RewriteCallable = Callable[[Any, Any], keras.Model]
-class Rewrite:
- """Graph rewrite logic to be used by RewritingOptimizer."""
+class Rewrite(ABC):
+ """Abstract class for rewrite logic to be used by RewritingOptimizer."""
def __init__(self, name: str, rewrite_fn: RewriteCallable):
"""Initialize a Rewrite instance with a given name and an optional function."""
@@ -39,40 +48,157 @@ class Rewrite:
self.function = rewrite_fn
def __call__(self, input_shape: Any, output_shape: Any) -> keras.Model:
- """Perform the rewrite operation using the configured function."""
+ """Return an instance of the rewrite model."""
try:
return self.function(input_shape, output_shape)
except Exception as ex:
raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex
+ def quantize(self, model: keras.Model) -> keras.Model:
+ """Return a quantized model if required."""
+ return model
-@dataclass
-class DynamicallyLoadedRewrite(Rewrite):
- """A rewrite which can load logic from a function loaded dynamically."""
+ @abstractmethod
+ def training_callbacks(self) -> list:
+ """Return rewrite callbacks."""
- def __init__(self, name: str, function_name: str):
- """Initialize."""
+ @abstractmethod
+ def post_process(self, model: keras.Model) -> keras.Model:
+ """Return post-processing rewrite option."""
- def load_and_run(input_shape: Any, output_shape: Any) -> keras.Model:
- """Load the function from a file dynamically."""
- self.load_function(function_name)
- return self.function(input_shape, output_shape)
+ @abstractmethod
+ def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+ """Check if the optimization has produced the correct result."""
- super().__init__(name, load_and_run)
- def load_function(self, function_name: str) -> RewriteCallable:
- """Return the rewrite function. Import using the auto_load attr if necessary."""
- try:
- name_parts = function_name.split(".")
- module_name = ".".join(name_parts[:-1])
- fn_name = name_parts[-1]
- module = importlib.import_module(module_name)
- self.function = cast(RewriteCallable, getattr(module, fn_name))
- return self.function
- except Exception as ex:
- raise RuntimeError(
- f"Unable to load rewrite function '{function_name}' for '{self.name}'."
- ) from ex
+class GenericRewrite(Rewrite):
+ """Rewrite class for generic rewrites e.g. fully-connected."""
+
+ def quantize(self, model: keras.Model) -> keras.Model:
+ """Return a quantized model if required."""
+ return tfmot.quantization.keras.quantize_model(model)
+
+ def training_callbacks(self) -> list:
+ """Return default rewrite callbacks."""
+ return []
+
+ def post_process(self, model: keras.Model) -> keras.Model:
+ """Return default post-processing rewrite option."""
+ return model
+
+ def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+ """Not needed here."""
+ return True
+
+
+class QuantizeAwareTrainingRewrite(Rewrite, ABC):
+ """Abstract class for rewrites that perform QAT."""
+
+ @abstractmethod
+ def preserved_quantize(self, model: keras.Model) -> keras.Model:
+ """Apply optimization-aware quantization to a given model."""
+ return model
+
+
+class Sparsity24Rewrite(QuantizeAwareTrainingRewrite):
+ """Rewrite class for sparsity rewrite e.g. fully-connected-sparsity24."""
+
+ pruning_callback = tfmot.sparsity.keras.UpdatePruningStep
+
+ strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning)
+
+ def quantize(self, model: keras.Model) -> keras.Model:
+ """Skip quantization when using sparsity rewrite."""
+ return model
+
+ def training_callbacks(self) -> list:
+ """Return pruning-specific rewrite callback."""
+ return [self.pruning_callback()]
+
+ def post_process(self, model: keras.Model) -> keras.Model:
+ """Pruning-specific post-processing rewrite option."""
+ return self.strip_pruning_wrapper(model)
+
+ def preserved_quantize(
+ self,
+ model: keras.Model,
+ ) -> keras.Model:
+ """Apply pruning-preserved quantization training to a given model."""
+ model = tfmot.quantization.keras.quantize_annotate_model(model)
+ model = tfmot.quantization.keras.quantize_apply(
+ model,
+ tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(),
+ )
+
+ return model
+
+ def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+ """Check if sparity has produced the correct result."""
+ for layer in model.layers:
+ for weight in layer.weights:
+ if "kernel" in weight.name:
+ if "kernel_min" in weight.name or "kernel_max" in weight.name:
+ continue
+ if not is_pruned_m_by_n(weight, m_by_n=(2, 4)):
+ logger.warning(
+ "\nWARNING: Could not find (2,4) sparsity, "
+ "in layer %s for weight %s \n",
+ layer.name,
+ weight.name,
+ )
+ return False
+ return True
+
+
+class ClusteringRewrite(QuantizeAwareTrainingRewrite):
+ """Rewrite class for clustering rewrite e.g. fully-connected-clustering."""
+
+ _strip_clustering_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering)
+
+ def preserved_quantize(self, model: keras.Model) -> keras.Model:
+ """Apply clustering-preserved quantization to a given model."""
+ quant_aware_model = tfmot.quantization.keras.quantize_annotate_model(model)
+ cqat_model = tfmot.quantization.keras.quantize_apply(
+ quant_aware_model,
+ tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme(),
+ )
+ return cqat_model
+
+ def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+ """Check if clustering has produced the correct result."""
+ number_of_clusters = kwargs.get("number_of_clusters")
+ if not number_of_clusters:
+ raise ValueError(
+ """
+ Expected check_optimization to have argument number_of_clusters.
+ """
+ )
+
+ for layer in model.layers:
+ for weight in layer.weights:
+ if "kernel" in weight.name:
+ if "kernel_min" in weight.name or "kernel_max" in weight.name:
+ continue
+ number_of_found_clusters = len(np.unique(weight))
+ if number_of_found_clusters != number_of_clusters:
+ logger.warning(
+ "\nWARNING: Expected %d cluster(s), found %d "
+ "cluster(s) in layer %s for weight %s \n",
+ number_of_clusters,
+ number_of_found_clusters,
+ layer.name,
+ weight.name,
+ )
+ return False
+ return True
+
+ def training_callbacks(self) -> list:
+ """Return default rewrite callbacks."""
+ return []
+
+ def post_process(self, model: keras.Model) -> keras.Model:
+ """Clustering-specific post-processing rewrite option."""
+ return self._strip_clustering_wrapper(model)
class RewriteRegistry(Registry[Rewrite]):
@@ -113,9 +239,11 @@ class RewritingOptimizer(Optimizer):
registry = RewriteRegistry(
[
- DynamicallyLoadedRewrite(
- "fully-connected", "mlia.nn.rewrite.library.fc_layer.get_keras_model"
- )
+ GenericRewrite("fully-connected", fc_rewrite),
+ Sparsity24Rewrite("fully-connected-sparsity24", fc_sparsity_rewrite),
+ ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite),
+ ClusteringRewrite("conv2d-clustering", conv2d_clustering_rewrite),
+ Sparsity24Rewrite("conv2d-sparsity24", conv2d_sparsity_rewrite),
]
)
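
Each registry entry is now a concrete Rewrite subclass rather than a dotted import path. A hypothetical minimal subclass illustrates the contract an entry must satisfy (illustrative only, not part of this patch):

    from typing import Any

    from keras.api._v2 import keras

    class IdentityRewrite(Rewrite):
        """Hypothetical rewrite that trains a pass-through model."""

        def training_callbacks(self) -> list:
            return []  # no extra Keras callbacks required

        def post_process(self, model: keras.Model) -> keras.Model:
            return model  # no wrapper layers to strip

        def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
            return True  # nothing to verify

    # Registered like the built-ins above, e.g.:
    # IdentityRewrite("identity", lambda input_shape, output_shape:
    #     keras.Sequential([keras.layers.InputLayer(input_shape=input_shape)]))
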
@@ -149,22 +277,35 @@ class RewritingOptimizer(Optimizer):
raise ConfigurationError(
"Input and output tensor names need to be set for rewrite."
)
-
orig_vs_repl_stats, total_stats = train(
source_model=tflite_model,
unmodified_model=tflite_model if use_unmodified_model else None,
output_model=str(tmp_output),
input_tfrec=str(tfrecord),
- replace_fn=rewrite,
+ rewrite=rewrite,
+ is_qat=isinstance(rewrite, QuantizeAwareTrainingRewrite),
input_tensors=[self.optimizer_configuration.layers_to_optimize[0]],
output_tensors=[self.optimizer_configuration.layers_to_optimize[1]],
train_params=self.optimizer_configuration.train_params,
)
if orig_vs_repl_stats:
- orig_vs_repl = ["Replaced sub-graph only"] + [
- f"{stat:.3f}" for stat in orig_vs_repl_stats
- ]
+ model_stats: list = []
+ cp_param = self.optimizer_configuration.train_params.checkpoint_at
+ checkpoints = (
+ [
+ "At checkpoint " + str(checkpoint) + " steps"
+ for checkpoint in cp_param
+ ]
+ if cp_param
+ else []
+ )
+ checkpoints.append("All Steps")
+ for checkpoint, orig_vs_repl_stat in zip(checkpoints, orig_vs_repl_stats):
+ model_stats.append(
+ ["Replaced sub-graph: " + checkpoint]
+ + [f"{stat:.3f}" for stat in orig_vs_repl_stat]
+ )
total = ["Total model"] + [f"{stat:.3f}" for stat in total_stats]
notes = (
"These metrics show the difference between original model\n"
@@ -178,19 +319,20 @@ class RewritingOptimizer(Optimizer):
table = Table(
columns=[
Column(
- "Original vs. optimized",
+ "Original vs. Optimized",
alias="metric",
fmt=Format(wrap_width=40),
),
Column("MAE", alias="value", fmt=Format(wrap_width=15)),
Column("NRMSE", alias="value", fmt=Format(wrap_width=15)),
],
- rows=[orig_vs_repl, total],
+ rows=[*model_stats, total],
name="Rewrite performance metrics",
alias="rewrite_performance_metrics",
notes=notes,
)
logger.info(table.to_plain_text(show_title=True))
+ self.model = TFLiteModel(tmp_output)
def get_model(self) -> TFLiteModel:
"""Return optimized model."""
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 60c39ae..4204978 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Sequential trainer."""
+# pylint: disable=too-many-arguments
# pylint: disable=too-many-locals
# pylint: disable=too-many-statements
from __future__ import annotations
@@ -22,7 +23,6 @@ from typing import Literal
import numpy as np
import tensorflow as tf
-import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from numpy.random import Generator
@@ -62,7 +62,7 @@ LEARNING_RATE_SCHEDULES = get_args(LearningRateSchedule)
class TrainingParameters:
"""Define default parameters for the training."""
- augmentations: tuple[float | None, float | None] = AUGMENTATION_PRESETS["gaussian"]
+ augmentations: tuple[float | None, float | None] = AUGMENTATION_PRESETS["none"]
batch_size: int = 32
steps: int = 48000
learning_rate: float = 1e-3
@@ -73,12 +73,13 @@ class TrainingParameters:
checkpoint_at: list | None = None
-def train(
+def train( # pylint: disable=too-many-arguments
source_model: str,
unmodified_model: Any,
output_model: str,
input_tfrec: str,
- replace_fn: Callable,
+ rewrite: Callable,
+ is_qat: bool,
input_tensors: list,
output_tensors: list,
train_params: TrainingParameters = TrainingParameters(),
@@ -118,7 +119,8 @@ def train(
train_dir=train_dir,
baseline_dir=unmodified_model_dir_path,
output_filename=Path(train_dir, "new.tflite"),
- replace_fn=replace_fn,
+ rewrite=rewrite,
+ is_qat=is_qat,
train_params=train_params,
)
@@ -145,7 +147,8 @@ def train(
# Assess the output diff between the parts after the rewrite subgraph
# in original and optimized model
optimized_end_path = Path(train_dir, "optimized_end.tfrec")
- end_path = Path(train_dir, "end.tfrec")
+ optimized_end_path_dequant = Path(train_dir, "optimized_end_dequant.tfrec")
+ end_path = Path(train_dir, "end_dequant.tfrec")
record_model(
str(input_tfrec),
@@ -153,16 +156,18 @@ def train(
optimized_end_path,
num_procs=train_params.num_procs,
num_threads=train_params.num_threads,
+ dequantize_output=True,
)
- mae, nrmse = diff_stats(end_path, str(optimized_end_path))
+
+ mae, nrmse = diff_stats(end_path, optimized_end_path_dequant)
if unmodified_model_dir:
cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup()
- return (results if train_params.checkpoint_at else results[0]), [
+ return results, [
mae,
nrmse,
- ] # only return a list if multiple checkpoints are asked for
+ ]
def eval_in_dir(
@@ -177,24 +182,27 @@ def eval_in_dir(
model_input = (
model_input_path
if model_input_path.exists()
- else ExtractPaths.tfrec.input(target_dir, False)
+ else ExtractPaths.tfrec.input(target_dir, True)
)
output = (
model_output_path
if model_output_path.exists()
- else ExtractPaths.tfrec.output(target_dir, False)
+ else ExtractPaths.tfrec.output(target_dir, True)
)
with tempfile.TemporaryDirectory() as tmp_dir:
predict = Path(tmp_dir, "predict.tfrec")
+ predict_dequant = Path(tmp_dir, "predict_dequant.tfrec")
record_model(
str(model_input),
new_part,
str(predict),
num_procs=num_procs,
num_threads=num_threads,
+ dequantize_output=True,
+ quantize_input=True,
)
- mae, nrmse = diff_stats(str(output), str(predict))
+ mae, nrmse = diff_stats(str(output), predict_dequant)
return mae, nrmse
@@ -247,7 +255,7 @@ def set_up_data_pipeline(
augmentations: tuple[float | None, float | None],
steps: int,
batch_size: int = 32,
-) -> tf.data.Dataset:
+) -> tuple[tf.data.Dataset, int]:
"""Create a data pipeline for training of the replacement model."""
_check_model_compatibility(teacher, replace)
@@ -338,14 +346,15 @@ def set_up_data_pipeline(
dataset = dataset.map(restore_shapes)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
- return dataset
+ return dataset, steps_per_epoch
def train_in_dir(
train_dir: str,
baseline_dir: Any,
output_filename: Path,
- replace_fn: Callable,
+ rewrite: Callable,
+ is_qat: bool,
train_params: TrainingParameters = TrainingParameters(),
) -> list[str]:
"""Train a replacement for replace.tflite using the input.tfrec \
@@ -370,7 +379,7 @@ def train_in_dir(
if model_is_quantized:
replace.check_datatypes(np.int8)
- dataset = set_up_data_pipeline(
+ dataset, steps_per_epoch = set_up_data_pipeline(
teacher,
replace,
train_dir,
@@ -380,15 +389,15 @@ def train_in_dir(
)
input_shape = teacher.shape_from_name[input_name][1:]
- output_shape = teacher.shape_from_name[output_name][1:]
- model = replace_fn(input_shape, output_shape)
+ output_shape = teacher.shape_from_name[output_name][1:]
optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
loss_fn = keras.losses.MeanSquaredError()
- if model_is_quantized:
- model = tfmot.quantization.keras.quantize_model(model)
- model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
+
+ model = create_model(
+ rewrite, input_shape, output_shape, optimizer, loss_fn, model_is_quantized
+ )
logger.info(model.summary())
@@ -428,11 +437,130 @@ def train_in_dir(
elif train_params.learning_rate_schedule == "constant":
callbacks = []
- output_filenames = []
+ callbacks.extend(rewrite.training_callbacks()) # type: ignore[attr-defined]
+ output_filenames: list = []
checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [
train_params.steps
]
+ model, output_filenames = model_fit(
+ model,
+ train_params,
+ checkpoints,
+ optimizer,
+ dataset,
+ callbacks,
+ output_filename,
+ rewrite,
+ replace,
+ input_name,
+ output_name,
+ model_is_quantized,
+ output_filenames,
+ input_shape,
+ output_shape,
+ loss_fn,
+ steps_per_epoch,
+ post_process=True,
+ )
+
+ # Placeholder for now, will be parametrized later (MLIA-1114)
+ # rewrite.check_optimization( # type: ignore[attr-defined]
+ # model, number_of_clusters=32
+ # )
+ if model_is_quantized and is_qat:
+ model = rewrite.preserved_quantize(model) # type: ignore[attr-defined]
+ checkpoints = (
+ train_params.checkpoint_at if train_params.checkpoint_at else []
+ ) + [train_params.steps]
+ output_filenames = []
+
+ if len(rewrite.training_callbacks()) > 0 and set( # type: ignore[attr-defined]
+ rewrite.training_callbacks() # type: ignore[attr-defined]
+ ).issubset(callbacks):
+ callbacks.pop(-1)
+
+ optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
+ model = model_compile(model, optimizer, loss_fn)
+
+ model, output_filenames = model_fit(
+ model,
+ train_params,
+ checkpoints,
+ optimizer,
+ dataset,
+ callbacks,
+ output_filename,
+ rewrite,
+ replace,
+ input_name,
+ output_name,
+ model_is_quantized,
+ output_filenames,
+ input_shape,
+ output_shape,
+ loss_fn,
+ steps_per_epoch,
+ )
+ # Placeholder for now, will be parametrized later (MLIA-1114)
+ # rewrite.check_optimization( # type: ignore[attr-defined]
+ # model, number_of_clusters=32
+ # )
+
+ teacher.close()
+ return output_filenames
+
+def model_compile(
+ model: keras.Model,
+ optimizer: keras.optimizers.Nadam,
+ loss_fn: keras.losses.Loss,
+) -> keras.Model:
+ """Compiles a tflite model."""
+ model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
+ return model
+
+
+def create_model( # pylint: disable=too-many-arguments
+ rewrite: Callable,
+ input_shape: int,
+ output_shape: int,
+ optimizer: Callable,
+ loss_fn: Callable,
+ model_is_quantized: bool,
+ model_to_load_from: keras.Model | None = None,
+) -> keras.Model:
+ """Create a model, optionally from another."""
+ model = rewrite(input_shape, output_shape)
+ if model_is_quantized:
+ model = rewrite.quantize(model) # type: ignore[attr-defined]
+ model = model_compile(model, optimizer=optimizer, loss_fn=loss_fn)
+ if model_to_load_from:
+ model.set_weights(model_to_load_from.get_weights())
+ return model
+
+
+def model_fit( # pylint: disable=too-many-arguments
+ model: keras.Model,
+ train_params: TrainingParameters,
+ checkpoints: list,
+ optimizer: tf.optimizers.Nadam,
+ dataset: tf.data.Dataset,
+ callbacks: list,
+ output_filename: Path,
+ rewrite: Callable,
+ replace: TFLiteModel,
+ input_name: str,
+ output_name: str,
+ model_is_quantized: bool,
+ output_filenames: list,
+ input_shape: int,
+ output_shape: int,
+ loss_fn: Callable,
+ steps_per_epoch: int,
+ post_process: bool = False,
+) -> tuple[keras.Model, list]:
+ """Train a Keras model and return it with the checkpoint filenames."""
+ steps_so_far = 0
while steps_so_far < train_params.steps:
steps_to_train = checkpoints.pop(0) - steps_so_far
lr_start = optimizer.learning_rate.numpy()
@@ -452,15 +580,43 @@ def train_in_dir(
)
if steps_so_far < train_params.steps:
- filename, ext = Path(output_filename).parts[1:]
- checkpoint_filename = filename + (f"_@{steps_so_far}") + ext
+ filename = Path(output_filename).stem
+ filename_dir = Path(output_filename).parent.as_posix()
+ ext = Path(output_filename).suffix
+ checkpoint_filename = (
+ filename_dir + "/" + filename + (f"_@{steps_so_far}") + ext
+ )
+ # If post processing we are stripping the clustering/pruning layers below
+ # Thus copy the model before saving, so training can continue
+ if post_process:
+ model_to_save = create_model(
+ rewrite,
+ input_shape,
+ output_shape,
+ optimizer,
+ loss_fn,
+ model_is_quantized,
+ model_to_load_from=model,
+ )
+ else:
+ model_to_save = model
else:
checkpoint_filename = str(output_filename)
+ logger.info("Evaluate final Keras Model using %d steps", steps_per_epoch)
+ model.evaluate(
+ dataset,
+ steps=steps_per_epoch,
+ )
+ model_to_save = model
with log_action(
f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}"
):
+ if post_process:
+ model_to_save = rewrite.post_process( # type: ignore[attr-defined]
+ model_to_save
+ )
save_as_tflite(
- model,
+ model_to_save,
checkpoint_filename,
input_name,
replace.shape_from_name[input_name],
@@ -470,8 +626,7 @@ def train_in_dir(
)
output_filenames.append(checkpoint_filename)
- teacher.close()
- return output_filenames
+ return model_to_save, output_filenames
def save_as_tflite(
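
The second model_fit pass above hinges on tfmot's optimization-preserving quantization schemes, which the patch uses verbatim. A self-contained illustration of the sequence Sparsity24Rewrite drives (layer shapes are arbitrary):

    import tensorflow_model_optimization as tfmot
    from keras.api._v2 import keras

    base = keras.Sequential(
        [keras.layers.InputLayer(input_shape=(16,)), keras.layers.Dense(8)]
    )
    # Pass 1: train with 2:4 pruning wrappers, then strip them (post_process).
    pruned = tfmot.sparsity.keras.prune_low_magnitude(base, sparsity_m_by_n=(2, 4))
    stripped = tfmot.sparsity.keras.strip_pruning(pruned)

    # Pass 2 (QAT rewrites on quantized models only): re-annotate, apply the
    # prune-preserving scheme, recompile, train again - as preserved_quantize does.
    annotated = tfmot.quantization.keras.quantize_annotate_model(stripped)
    qat_model = tfmot.quantization.keras.quantize_apply(
        annotated,
        tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(),
    )
    qat_model.compile(optimizer="nadam", loss="mse", metrics=["mae"])
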
diff --git a/src/mlia/nn/rewrite/library/clustering.py b/src/mlia/nn/rewrite/library/clustering.py
new file mode 100644
index 0000000..81bfd90
--- /dev/null
+++ b/src/mlia/nn/rewrite/library/clustering.py
@@ -0,0 +1,51 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Rewrite functions used to return layers ready for clustering."""
+from typing import Any
+
+import tensorflow_model_optimization as tfmot
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
+
+from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters
+
+
+def fc_clustering_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
+ """Fully connected TensorFlow Lite model ready for clustering."""
+ rewrite_params = {
+ "number_of_clusters": 4,
+ "cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization.LINEAR,
+ }
+ model = tfmot.clustering.keras.cluster_weights(
+ to_cluster=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Flatten(),
+ keras.layers.Dense(units=output_shape),
+ ]
+ ),
+ **rewrite_params
+ )
+ return model
+
+
+def conv2d_clustering_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
+ """Conv2d TensorFlow Lite model ready for clustering."""
+ rewrite_params = {
+ "number_of_clusters": 4,
+ "cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization.LINEAR,
+ }
+ conv2d_parameters = compute_conv2d_parameters(
+ input_shape=input_shape, output_shape=output_shape
+ )
+ model = tfmot.clustering.keras.cluster_weights(
+ to_cluster=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Conv2D(**conv2d_parameters),
+ keras.layers.BatchNormalization(),
+ keras.layers.ReLU(),
+ ]
+ ),
+ **rewrite_params
+ )
+ return model
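
Both builders return models already wrapped by tfmot's clustering API, so they can be compiled and trained directly. A quick smoke test, assuming the 3-D feature-map shapes MLIA extracts for conv rewrites (values are illustrative):

    import numpy as np
    from mlia.nn.rewrite.library.clustering import conv2d_clustering_rewrite

    model = conv2d_clustering_rewrite(
        input_shape=np.array([32, 32, 3]), output_shape=np.array([16, 16, 8])
    )
    model.compile(optimizer="nadam", loss="mse", metrics=["mae"])
    model.summary()  # layers appear inside cluster_weights wrappers
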
diff --git a/src/mlia/nn/rewrite/library/fc_layer.py b/src/mlia/nn/rewrite/library/fc_layer.py
index 041ce85..92195d1 100644
--- a/src/mlia/nn/rewrite/library/fc_layer.py
+++ b/src/mlia/nn/rewrite/library/fc_layer.py
@@ -1,13 +1,13 @@
# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
-"""Example rewrite with one fully connected layer."""
+"""Rewrite function used to return regular layers."""
from typing import Any
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
-def get_keras_model(input_shape: Any, output_shape: Any) -> keras.Model:
- """Generate TensorFlow Lite model for rewrite."""
+def fc_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
+ """Fully connected TensorFlow Lite model for rewrite."""
model = keras.Sequential(
(
keras.layers.InputLayer(input_shape=input_shape),
diff --git a/src/mlia/nn/rewrite/library/helper_functions.py b/src/mlia/nn/rewrite/library/helper_functions.py
new file mode 100644
index 0000000..4f08170
--- /dev/null
+++ b/src/mlia/nn/rewrite/library/helper_functions.py
@@ -0,0 +1,32 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Helper functions for the rewrite library."""
+import math
+from typing import Any
+
+import numpy as np
+
+
+def compute_conv2d_parameters(
+ input_shape: np.ndarray, output_shape: np.ndarray
+) -> dict[str, Any]:
+ """Compute needed kernel size and strides for a given input and output_shape."""
+ input_shape = input_shape.tolist()
+ output_shape = output_shape.tolist()
+ assert len(input_shape) == 3
+ assert len(output_shape) == 3
+ num_filters = output_shape[-1]  # (out - in) + in channels reduces to out channels
+ padding = "valid"
+ kernel_size = (3, 3)
+ stride_h = round(input_shape[0] / output_shape[0])
+ check_output_size_h = math.floor((input_shape[0] - kernel_size[0]) / stride_h) + 1
+ stride_w = round(input_shape[1] / output_shape[1])
+ check_output_size_w = math.floor((input_shape[1] - kernel_size[1]) / stride_w) + 1
+ if check_output_size_h != output_shape[0] or check_output_size_w != output_shape[1]:
+ padding = "same"
+ return {
+ "filters": num_filters,
+ "kernel_size": kernel_size,
+ "padding": padding,
+ "strides": (stride_h, stride_w),
+ }
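
Worked example: mapping a 32x32x3 input to a 16x16x8 output (shapes chosen for illustration). Valid padding would yield floor((32 - 3) / 2) + 1 = 15 rows rather than 16, so the helper falls back to 'same':

    import numpy as np
    from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters

    params = compute_conv2d_parameters(
        input_shape=np.array([32, 32, 3]), output_shape=np.array([16, 16, 8])
    )
    # {'filters': 8, 'kernel_size': (3, 3), 'padding': 'same', 'strides': (2, 2)}
    print(params)
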
diff --git a/src/mlia/nn/rewrite/library/sparsity.py b/src/mlia/nn/rewrite/library/sparsity.py
new file mode 100644
index 0000000..745fa8b
--- /dev/null
+++ b/src/mlia/nn/rewrite/library/sparsity.py
@@ -0,0 +1,45 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Rewrite functions used to return layers ready for sparse pruning."""
+from typing import Any
+
+import tensorflow_model_optimization as tfmot
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
+
+from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters
+
+
+def fc_sparsity_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
+ """Fully connected TensorFlow Lite model ready for sparse pruning."""
+ model = tfmot.sparsity.keras.prune_low_magnitude(
+ to_prune=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Reshape([-1]),
+ keras.layers.Dense(output_shape),
+ ]
+ ),
+ sparsity_m_by_n=(2, 4),
+ )
+
+ return model
+
+
+def conv2d_sparsity_rewrite(input_shape: Any, output_shape: Any) -> keras.Model:
+ """Conv2d TensorFlow Lite model ready for sparse pruning."""
+ conv2d_parameters = compute_conv2d_parameters(
+ input_shape=input_shape, output_shape=output_shape
+ )
+ model = tfmot.sparsity.keras.prune_low_magnitude(
+ to_prune=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Conv2D(**conv2d_parameters),
+ keras.layers.BatchNormalization(),
+ keras.layers.ReLU(),
+ ]
+ ),
+ sparsity_m_by_n=(2, 4),
+ )
+
+ return model
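
As with clustering, the pruning wrappers stay on during training; Sparsity24Rewrite later strips them in post_process and verifies the pattern with is_pruned_m_by_n. A minimal usage sketch (shapes are illustrative):

    from mlia.nn.rewrite.library.sparsity import fc_sparsity_rewrite

    model = fc_sparsity_rewrite(input_shape=(28, 28, 1), output_shape=10)
    model.compile(optimizer="nadam", loss="mse", metrics=["mae"])
    # Training must include tfmot.sparsity.keras.UpdatePruningStep(), the
    # callback Sparsity24Rewrite registers, or pruning will not advance.
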
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py
index 81a614f..b61e713 100644
--- a/src/mlia/nn/select.py
+++ b/src/mlia/nn/select.py
@@ -117,7 +117,7 @@ class MultiStageOptimizer(Optimizer):
def get_optimizer(
model: keras.Model | KerasModel | TFLiteModel,
config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings],
- training_parameters: list[dict | None] | None = None,
+ training_parameters: dict | None = None,
) -> Optimizer:
"""Get optimizer for provided configuration."""
if isinstance(model, KerasModel):
@@ -151,7 +151,7 @@ def get_optimizer(
def _get_optimizer(
model: keras.Model | Path,
optimization_settings: OptimizationSettings | list[OptimizationSettings],
- training_parameters: list[dict | None] | None = None,
+ training_parameters: dict | None = None,
) -> Optimizer:
if isinstance(optimization_settings, OptimizationSettings):
optimization_settings = [optimization_settings]
@@ -173,22 +173,17 @@ def _get_optimizer(
def _get_rewrite_params(
- training_parameters: list[dict | None] | None = None,
-) -> list:
+ training_parameters: dict | None = None,
+) -> TrainingParameters:
"""Get the rewrite TrainingParameters.
Return a default-constructed TrainingParameters() by default; unit tests
may override it.
"""
- if training_parameters is None:
- return [TrainingParameters()]
+ if not training_parameters:
+ return TrainingParameters()
- if training_parameters[0] is None:
- train_params = TrainingParameters()
- else:
- train_params = TrainingParameters(**training_parameters[0])
-
- return [train_params]
+ return TrainingParameters(**training_parameters)
def _get_optimizer_configuration(
@@ -196,7 +191,7 @@ def _get_optimizer_configuration(
optimization_target: int | float | str,
layers_to_optimize: list[str] | None = None,
dataset: Path | None = None,
- training_parameters: list[dict | None] | None = None,
+ training_parameters: dict | None = None,
) -> OptimizerConfiguration:
"""Get optimizer configuration for provided parameters."""
_check_optimizer_params(optimization_type, optimization_target)
@@ -222,7 +217,7 @@ def _get_optimizer_configuration(
optimization_target=str(optimization_target),
layers_to_optimize=layers_to_optimize,
dataset=dataset,
- train_params=rewrite_params[0],
+ train_params=rewrite_params,
)
raise ConfigurationError(
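
With the list wrapper gone, the profile's training table flows through as a single dict and expands straight into the dataclass. For example (field values are arbitrary):

    from mlia.nn.rewrite.core.train import TrainingParameters

    training_parameters = {"steps": 12000, "batch_size": 64}
    train_params = TrainingParameters(**training_parameters)
    assert train_params.steps == 12000  # unspecified fields keep their defaults
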
diff --git a/src/mlia/resources/optimization_profiles/optimization.toml b/src/mlia/resources/optimization_profiles/optimization.toml
index 623a763..42b64f0 100644
--- a/src/mlia/resources/optimization_profiles/optimization.toml
+++ b/src/mlia/resources/optimization_profiles/optimization.toml
@@ -7,5 +7,6 @@ learning_rate = 1e-3
show_progress = true
steps = 48000
learning_rate_schedule = "cosine"
+augmentations = "gaussian"
num_procs = 1
num_threads = 0
diff --git a/src/mlia/resources/optimization_profiles/optimization_custom_augmentation.toml b/src/mlia/resources/optimization_profiles/optimization_custom_augmentation.toml
new file mode 100644
index 0000000..5d1f917
--- /dev/null
+++ b/src/mlia/resources/optimization_profiles/optimization_custom_augmentation.toml
@@ -0,0 +1,13 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+[training]
+batch_size = 32
+learning_rate = 1e-3
+show_progress = true
+steps = 48000
+learning_rate_schedule = "cosine"
+num_procs = 1
+num_threads = 0
+augmentations.gaussian_strength = 0.1
+augmentations.mixup_strength = 0.1
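
The dotted keys deserialize to a nested table, so the training section arrives as a dict whose 'augmentations' entry is itself a dict - exactly the shape parse_augmentations below accepts. A sketch using the standard-library loader (MLIA's own profile loading may differ):

    import tomllib  # Python 3.11+

    with open("optimization_custom_augmentation.toml", "rb") as file:
        profile = tomllib.load(file)

    print(profile["training"]["augmentations"])
    # {'gaussian_strength': 0.1, 'mixup_strength': 0.1}
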
diff --git a/src/mlia/target/common/optimization.py b/src/mlia/target/common/optimization.py
index 8c5d184..a139a7d 100644
--- a/src/mlia/target/common/optimization.py
+++ b/src/mlia/target/common/optimization.py
@@ -17,6 +17,7 @@ from mlia.core.errors import FunctionalityNotSupportedError
from mlia.core.performance import estimate_performance
from mlia.core.performance import P
from mlia.core.performance import PerformanceEstimator
+from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS
from mlia.nn.select import get_optimizer
from mlia.nn.select import OptimizationSettings
from mlia.nn.tensorflow.config import get_keras_model
@@ -86,7 +87,7 @@ class OptimizingDataCollector(ContextAwareDataCollector):
def optimize_model(
self,
opt_settings: list[OptimizationSettings],
- training_parameters: list[dict | None],
+ training_parameters: dict | None,
model: KerasModel | TFLiteModel,
) -> Any:
"""Run optimization."""
@@ -123,12 +124,12 @@ class OptimizingDataCollector(ContextAwareDataCollector):
context=context,
)
- def _get_training_settings(self, context: Context) -> list[dict]:
+ def _get_training_settings(self, context: Context) -> dict:
"""Get optimization settings."""
return self.get_parameter( # type: ignore
OptimizingDataCollector.name(),
"training_parameters",
- expected_type=list,
+ expected_type=dict,
expected=False,
context=context,
)
@@ -218,7 +219,54 @@ _DEFAULT_OPTIMIZATION_TARGETS = [
]
-def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) -> None:
+def parse_augmentations(
+ augmentations: dict | str | None,
+) -> tuple[float | None, float | None]:
+ """Parse augmentations from optimization-profile and return a valid tuple."""
+ if isinstance(augmentations, str):
+ match_augmentation = AUGMENTATION_PRESETS.get(augmentations)
+ if not match_augmentation:
+ match_augmentation = AUGMENTATION_PRESETS["none"]
+ return match_augmentation
+ if isinstance(augmentations, dict):
+ augmentation_keys_test_for_valid = list(augmentations.keys())
+ augmentation_keys_test_for_float = list(augmentations.keys())
+ valid_keys = ["mixup_strength", "gaussian_strength"]
+ tuple_to_return = []
+ for valid_key in valid_keys.copy():
+ if augmentations.get(valid_key):
+ del augmentation_keys_test_for_valid[
+ augmentation_keys_test_for_valid.index(valid_key)
+ ]
+ if isinstance(augmentations.get(valid_key), float):
+ tuple_to_return.append(augmentations[valid_key])
+ del augmentation_keys_test_for_float[
+ augmentation_keys_test_for_float.index(valid_key)
+ ]
+ else:
+ tuple_to_return.append(None)
+ else:
+ tuple_to_return.append(None)
+
+ if len(augmentation_keys_test_for_valid) > 0:
+ logger.warning(
+ "Warning! Expected augmentation parameters to be 'gaussian_strength' "
+ "and/or 'mixup_strength' got %s. "
+ "Removing invalid augmentations",
+ str(list(augmentations.keys())),
+ )
+ elif len(augmentation_keys_test_for_float) > 0:
+ logger.warning(
+ "Warning! Not all augmentation parameters were floats, "
+ "removing non-float augmentations"
+ )
+ return (tuple_to_return[0], tuple_to_return[1])
+ return AUGMENTATION_PRESETS["none"]
+
+
+def add_common_optimization_params( # pylint: disable=too-many-branches
+ advisor_parameters: dict, extra_args: dict
+) -> None:
"""Add common optimization parameters."""
optimization_targets = extra_args.get("optimization_targets")
if not optimization_targets:
@@ -228,18 +276,22 @@ def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) -
raise TypeError("Optimization targets value has wrong format.")
rewrite_parameters = extra_args.get("optimization_profile")
- if not rewrite_parameters:
- training_parameters = None
- else:
+ training_parameters = None
+ if rewrite_parameters:
if not isinstance(rewrite_parameters, dict):
raise TypeError("Training Parameter values has wrong format.")
training_parameters = extra_args["optimization_profile"].get("training")
+ if training_parameters:
+ training_parameters["augmentations"] = parse_augmentations(
+ training_parameters.get("augmentations")
+ )
+
advisor_parameters.update(
{
"common_optimizations": {
"optimizations": [optimization_targets],
- "training_parameters": [training_parameters],
+ "training_parameters": training_parameters,
},
}
)
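
End-to-end, parse_augmentations accepts three input shapes (illustrative calls; preset tuples come from AUGMENTATION_PRESETS in train.py):

    from mlia.target.common.optimization import parse_augmentations

    parse_augmentations("gaussian")  # named preset; unknown names fall back to "none"

    # Dict form returns (mixup_strength, gaussian_strength); invalid keys and
    # non-float values are dropped with a warning.
    parse_augmentations({"gaussian_strength": 0.1, "mixup_strength": 0.1})
    # -> (0.1, 0.1)

    parse_augmentations(None)  # anything else -> AUGMENTATION_PRESETS["none"]
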