-rw-r--r--  .pre-commit-config.yaml                              |   8
-rw-r--r--  pre_commit_hooks/check_copyright_header.py           |  34
-rw-r--r--  pyproject.toml                                       |   4
-rw-r--r--  setup.cfg                                            |   8
-rw-r--r--  src/mlia/nn/rewrite/core/rewrite.py                  | 201
-rw-r--r--  src/mlia/nn/rewrite/core/train.py                    | 178
-rw-r--r--  src/mlia/nn/rewrite/library/fc_clustering_layer.py   |  26
-rw-r--r--  src/mlia/nn/rewrite/library/fc_sparsity24_layer.py   |  23
-rw-r--r--  src/mlia/nn/select.py                                |  23
-rw-r--r--  src/mlia/target/common/optimization.py               |  13
-rw-r--r--  tests/test_cli_commands.py                           |  29
-rw-r--r--  tests/test_common_optimization.py                    |  18
-rw-r--r--  tests/test_nn_rewrite_core_rewrite.py                | 166
-rw-r--r--  tests/test_nn_rewrite_core_train.py                  |  26
-rw-r--r--  tests/test_nn_select.py                              |  12
-rw-r--r--  tests/test_target_cortex_a_advisor.py                |   2
-rw-r--r--  tests/test_target_tosa_advisor.py                    |   2
17 files changed, 629 insertions, 144 deletions
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b601b03..3788326 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -124,3 +124,11 @@ repos:
hooks:
- id: commitizen-branch
args: [--rev-range, HEAD~1..HEAD]
+
+- repo: local
+ hooks:
+ - id: check-copyright-header
+ name: Check Copyright header years
+ entry: python pre_commit_hooks/check_copyright_header.py
+ language: python
+ verbose: true
diff --git a/pre_commit_hooks/check_copyright_header.py b/pre_commit_hooks/check_copyright_header.py
new file mode 100644
index 0000000..ded7675
--- /dev/null
+++ b/pre_commit_hooks/check_copyright_header.py
@@ -0,0 +1,34 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Pre-commit hook that checks the current year is in the Copyright header of a file.
+
+If the header is out of date, it will print a warning.
+"""
+import datetime
+import subprocess # nosec
+
+
+class CopyrightHeaderChecker:
+ """Class that wraps the checker for the Copyright header."""
+
+ def check_files_have_updated_header(self, filenames: list) -> None:
+ """Check whether input files have the current year in the copyright string."""
+ current_year = str(datetime.datetime.now().year)
+ for filename in filenames:
+ with open(filename, encoding="utf-8") as file:
+ first_line = file.readline()
+ second_line = file.readline()
+ if filename.endswith(".md") and current_year not in second_line:
+ print(f"WARNING: The Copyright header of {filename} is out of date!")
+
+ if not filename.endswith(".md") and current_year not in first_line:
+ print(f"WARNING: The Copyright header of {filename} is out of date!")
+
+
+if __name__ == "__main__":
+ staged_files = (
+ subprocess.check_output(["git", "diff", "--cached", "--name-only"]) # nosec
+ .decode()
+ .splitlines()
+ )
+ CopyrightHeaderChecker().check_files_have_updated_header(filenames=staged_files)
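
For reference, the checker can also be exercised directly from Python, assuming pre_commit_hooks is importable from the repository root; a minimal sketch with an illustrative file path:

    from pre_commit_hooks.check_copyright_header import CopyrightHeaderChecker

    # Prints one warning per file whose header lacks the current year, e.g.:
    # WARNING: The Copyright header of src/mlia/nn/select.py is out of date!
    CopyrightHeaderChecker().check_files_have_updated_header(
        filenames=["src/mlia/nn/select.py"]
    )
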
diff --git a/pyproject.toml b/pyproject.toml
index 0c4cc8c..cf2db54 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -90,6 +90,6 @@ update_changelog_on_bump = true
schema_pattern = "(?s)(build|ci|docs|feat|fix|perf|refactor|style|test)(\\(\\S+\\))?!?:( [A-Z][^\\n\\r]+)((\\n\\n.*)|(\\s*))?$"
schema = "<type>(<scope>): <Subject-capitalized>\n<BLANK LINE>\n<body>\n<BLANK LINE>\n(BREAKING CHANGE: )<footer>"
# Commit parser is used to render the commits for RELEASES.md
-commit_parser = "^((?P<change_type>feat|fix|refactor|perf|BREAKING CHANGE)(?:\\((?P<scope>[^()\\r\\n]*)\\)|\\()?(?P<breaking>!)?|\\w+!):\\s(?P<message>.*)?"
+commit_parser = "^((?P<change_type>build|ci|docs|feat|fix|perf|refactor|style|test|BREAKING CHANGE)(?:\\((?P<scope>[^()\\r\\n]*)\\)|\\()?(?P<breaking>!)?|\\w+!):\\s(?P<message>.*)?"
# Change type map to render the title for that category as per {tag:title}
-change_type_map = {'feat' = 'Feature changes', 'fix' = 'Bug fix', 'perf' = 'Performance improvements'}
+change_type_map = {'feat' = 'Feature changes', 'fix' = 'Bug fix', 'perf' = 'Performance improvements', 'build' = 'Development changes'}
diff --git a/setup.cfg b/setup.cfg
index 6917747..0714caf 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-FileCopyrightText: Copyright (c) 2020 Troy Comi
# SPDX-License-Identifier: Apache-2.0 AND MIT
@@ -28,8 +28,12 @@ python_requires = >=3.9.0
package_dir =
= src
packages = find_namespace:
+# Pinning tensorflow & h5py to work around build issue on aarch64:
+# https://github.com/h5py/h5py/issues/2408
+# The idea is to unpin these once the issue is resolved.
install_requires =
- tensorflow~=2.15.1
+ tensorflow==2.15.1
+ h5py==3.10.0
tensorflow-model-optimization~=0.7.5
ethos-u-vela~=3.11.0
flaky~=3.7.0
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index c7d13ba..e2c097c 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -3,15 +3,17 @@
"""Contains class RewritingOptimizer to replace a subgraph/layer of a model."""
from __future__ import annotations
-import importlib
import logging
import tempfile
+from abc import ABC
+from abc import abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Callable
-from typing import cast
+import numpy as np
+import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from mlia.core.errors import ConfigurationError
@@ -22,6 +24,13 @@ from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
from mlia.nn.rewrite.core.train import train
from mlia.nn.rewrite.core.train import TrainingParameters
+from mlia.nn.rewrite.library.fc_clustering_layer import (
+ get_keras_model_clus as fc_clustering_rewrite,
+)
+from mlia.nn.rewrite.library.fc_layer import get_keras_model as fc_rewrite
+from mlia.nn.rewrite.library.fc_sparsity24_layer import (
+ get_keras_model as fc_rewrite_sparsity24,
+)
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.utils.registry import Registry
@@ -30,8 +39,8 @@ logger = logging.getLogger(__name__)
RewriteCallable = Callable[[Any, Any], keras.Model]
-class Rewrite:
- """Graph rewrite logic to be used by RewritingOptimizer."""
+class Rewrite(ABC):
+ """Abstract class for rewrite logic to be used by RewritingOptimizer."""
def __init__(self, name: str, rewrite_fn: RewriteCallable):
"""Initialize a Rewrite instance with a given name and an optional function."""
@@ -45,34 +54,138 @@ class Rewrite:
except Exception as ex:
raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex
+ def quantize(self, model: keras.Model) -> keras.Model:
+ """Return a quantized model if required."""
+ return model
-@dataclass
-class DynamicallyLoadedRewrite(Rewrite):
- """A rewrite which can load logic from a function loaded dynamically."""
+ @abstractmethod
+ def training_callbacks(self) -> list:
+ """Return default rewrite callbacks."""
- def __init__(self, name: str, function_name: str):
- """Initialize."""
+ @abstractmethod
+ def post_process(self, model: keras.Model) -> keras.Model:
+ """Return default post-processing rewrite options."""
- def load_and_run(input_shape: Any, output_shape: Any) -> keras.Model:
- """Load the function from a file dynamically."""
- self.load_function(function_name)
- return self.function(input_shape, output_shape)
+ @abstractmethod
+ def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+ """Check if the optimization has produced the correct result."""
- super().__init__(name, load_and_run)
- def load_function(self, function_name: str) -> RewriteCallable:
- """Return the rewrite function. Import using the auto_load attr if necessary."""
- try:
- name_parts = function_name.split(".")
- module_name = ".".join(name_parts[:-1])
- fn_name = name_parts[-1]
- module = importlib.import_module(module_name)
- self.function = cast(RewriteCallable, getattr(module, fn_name))
- return self.function
- except Exception as ex:
- raise RuntimeError(
- f"Unable to load rewrite function '{function_name}' for '{self.name}'."
- ) from ex
+class GenericRewrite(Rewrite):
+ """Graph rewrite logic for fully-connected rewrite."""
+
+ def quantize(self, model: keras.Model) -> keras.Model:
+ """Return a quantized model if required."""
+ return tfmot.quantization.keras.quantize_model(model)
+
+ def training_callbacks(self) -> list:
+ """Return default rewrite callbacks."""
+ return []
+
+ def post_process(self, model: keras.Model) -> keras.Model:
+ """Return default post-processing rewrite options."""
+ return model
+
+ def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+ """Not needed here."""
+ return True
+
+
+class QuantizeAwareTrainingRewrite(Rewrite, ABC):
+ """Abstract class for rewrites that perform QAT."""
+
+ @abstractmethod
+ def preserved_quantize(self, model: keras.Model) -> keras.Model:
+ """Apply optimization-aware quantization to a given model."""
+ return model
+
+
+class Sparsity24Rewrite(QuantizeAwareTrainingRewrite):
+ """Graph rewrite logic for fully-connected-sparsity24 rewrite."""
+
+ pruning_callback = tfmot.sparsity.keras.UpdatePruningStep
+
+ strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning)
+
+ def quantize(self, model: keras.Model) -> keras.Model:
+ """Skip quantization when using pruning rewrite."""
+ return model
+
+ def training_callbacks(self) -> list:
+ """Return pruning-specific rewrite callback."""
+ return [self.pruning_callback()]
+
+ def post_process(self, model: keras.Model) -> keras.Model:
+ """Pruning-specific post-processing rewrite options."""
+ return self.strip_pruning_wrapper(model)
+
+ def preserved_quantize(
+ self,
+ model: keras.Model,
+ ) -> keras.Model:
+ """Apply pruning-preserved quantization training to a given model."""
+ model = tfmot.quantization.keras.quantize_annotate_model(model)
+ model = tfmot.quantization.keras.quantize_apply(
+ model,
+ tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(),
+ )
+
+ return model
+
+ def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+ """Not needed here."""
+ return True
+
+
+class ClusteringRewrite(QuantizeAwareTrainingRewrite):
+ """Graph clustering rewrite logic to be used by RewritingOptimizer."""
+
+ _strip_clustering_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering)
+
+ def preserved_quantize(self, model: keras.Model) -> keras.Model:
+ """Apply clustering-preserved quantization to a given model."""
+ quant_aware_model = tfmot.quantization.keras.quantize_annotate_model(model)
+ cqat_model = tfmot.quantization.keras.quantize_apply(
+ quant_aware_model,
+ tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme(),
+ )
+ return cqat_model
+
+ def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+ """Check if clustering has produced the correct result."""
+ number_of_clusters = kwargs.get("number_of_clusters")
+ if not number_of_clusters:
+ raise ValueError(
+ """
+ Expected check_optimization to have argument number_of_clusters.
+ """
+ )
+
+ for layer in model.layers:
+ for weight in layer.weights:
+ if "kernel" in weight.name:
+ if "kernel_min" in weight.name or "kernel_max" in weight.name:
+ continue
+ number_of_found_clusters = len(np.unique(weight))
+ if number_of_found_clusters != number_of_clusters:
+ logger.warning(
+ "\nWARNING: Expected %d cluster(s), found %d "
+ "cluster(s) in layer %s for weight %s \n",
+ number_of_clusters,
+ number_of_found_clusters,
+ layer.name,
+ weight.name,
+ )
+ return False
+ return True
+
+ def training_callbacks(self) -> list:
+ """Return default rewrite callbacks."""
+ return []
+
+ def post_process(self, model: keras.Model) -> keras.Model:
+ """Return the clustering stripped model."""
+ return self._strip_clustering_wrapper(model)
class RewriteRegistry(Registry[Rewrite]):
@@ -113,9 +226,9 @@ class RewritingOptimizer(Optimizer):
registry = RewriteRegistry(
[
- DynamicallyLoadedRewrite(
- "fully-connected", "mlia.nn.rewrite.library.fc_layer.get_keras_model"
- )
+ GenericRewrite("fully-connected", fc_rewrite),
+ Sparsity24Rewrite("fully-connected-sparsity24", fc_rewrite_sparsity24),
+ ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite),
]
)
@@ -149,22 +262,35 @@ class RewritingOptimizer(Optimizer):
raise ConfigurationError(
"Input and output tensor names need to be set for rewrite."
)
-
orig_vs_repl_stats, total_stats = train(
source_model=tflite_model,
unmodified_model=tflite_model if use_unmodified_model else None,
output_model=str(tmp_output),
input_tfrec=str(tfrecord),
- replace_fn=rewrite,
+ rewrite=rewrite,
+ is_qat=isinstance(rewrite, QuantizeAwareTrainingRewrite),
input_tensors=[self.optimizer_configuration.layers_to_optimize[0]],
output_tensors=[self.optimizer_configuration.layers_to_optimize[1]],
train_params=self.optimizer_configuration.train_params,
)
if orig_vs_repl_stats:
- orig_vs_repl = ["Replaced sub-graph only"] + [
- f"{stat:.3f}" for stat in orig_vs_repl_stats
- ]
+ model_stats: list = []
+ cp_param = self.optimizer_configuration.train_params.checkpoint_at
+ checkpoints = (
+ [
+ "At checkpoint " + str(checkpoint) + " steps"
+ for checkpoint in cp_param
+ ]
+ if cp_param
+ else []
+ )
+ checkpoints.append("All Steps")
+ for checkpoint, orig_vs_repl_stat in zip(checkpoints, orig_vs_repl_stats):
+ model_stats.append(
+ ["Replaced sub-graph: " + checkpoint]
+ + [f"{stat:.3f}" for stat in orig_vs_repl_stat]
+ )
total = ["Total model"] + [f"{stat:.3f}" for stat in total_stats]
notes = (
"These metrics show the difference between original model\n"
@@ -178,19 +304,20 @@ class RewritingOptimizer(Optimizer):
table = Table(
columns=[
Column(
- "Original vs. optimized",
+ "Original vs. Optimized",
alias="metric",
fmt=Format(wrap_width=40),
),
Column("MAE", alias="value", fmt=Format(wrap_width=15)),
Column("NRMSE", alias="value", fmt=Format(wrap_width=15)),
],
- rows=[orig_vs_repl, total],
+ rows=[*model_stats, total],
name="Rewrite performance metrics",
alias="rewrite_performance_metrics",
notes=notes,
)
logger.info(table.to_plain_text(show_title=True))
+ self.model = TFLiteModel(tmp_output)
def get_model(self) -> TFLiteModel:
"""Return optimized model."""
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 60c39ae..88efa23 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Sequential trainer."""
+# pylint: disable=too-many-arguments
# pylint: disable=too-many-locals
# pylint: disable=too-many-statements
from __future__ import annotations
@@ -22,7 +23,6 @@ from typing import Literal
import numpy as np
import tensorflow as tf
-import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from numpy.random import Generator
@@ -73,12 +73,13 @@ class TrainingParameters:
checkpoint_at: list | None = None
-def train(
+def train( # pylint: disable=too-many-arguments
source_model: str,
unmodified_model: Any,
output_model: str,
input_tfrec: str,
- replace_fn: Callable,
+ rewrite: Callable,
+ is_qat: bool,
input_tensors: list,
output_tensors: list,
train_params: TrainingParameters = TrainingParameters(),
@@ -118,7 +119,8 @@ def train(
train_dir=train_dir,
baseline_dir=unmodified_model_dir_path,
output_filename=Path(train_dir, "new.tflite"),
- replace_fn=replace_fn,
+ rewrite=rewrite,
+ is_qat=is_qat,
train_params=train_params,
)
@@ -159,10 +161,10 @@ def train(
if unmodified_model_dir:
cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup()
- return (results if train_params.checkpoint_at else results[0]), [
+ return results, [
mae,
nrmse,
- ] # only return a list if multiple checkpoints are asked for
+ ]
def eval_in_dir(
@@ -345,7 +347,8 @@ def train_in_dir(
train_dir: str,
baseline_dir: Any,
output_filename: Path,
- replace_fn: Callable,
+ rewrite: Callable,
+ is_qat: bool,
train_params: TrainingParameters = TrainingParameters(),
) -> list[str]:
"""Train a replacement for replace.tflite using the input.tfrec \
@@ -380,15 +383,15 @@ def train_in_dir(
)
input_shape = teacher.shape_from_name[input_name][1:]
- output_shape = teacher.shape_from_name[output_name][1:]
- model = replace_fn(input_shape, output_shape)
+ output_shape = teacher.shape_from_name[output_name][1:]
optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
loss_fn = keras.losses.MeanSquaredError()
- if model_is_quantized:
- model = tfmot.quantization.keras.quantize_model(model)
- model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
+
+ model = create_model(
+ rewrite, input_shape, output_shape, optimizer, loss_fn, model_is_quantized
+ )
logger.info(model.summary())
@@ -428,11 +431,127 @@ def train_in_dir(
elif train_params.learning_rate_schedule == "constant":
callbacks = []
- output_filenames = []
+ callbacks.extend(rewrite.training_callbacks()) # type: ignore[attr-defined]
+ output_filenames: list = []
checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [
train_params.steps
]
+ model, output_filenames = model_fit(
+ model,
+ train_params,
+ checkpoints,
+ optimizer,
+ dataset,
+ callbacks,
+ output_filename,
+ rewrite,
+ replace,
+ input_name,
+ output_name,
+ model_is_quantized,
+ output_filenames,
+ input_shape,
+ output_shape,
+ loss_fn,
+ post_process=True,
+ )
+
+ # Placeholder for now, will be parametrized later (MLIA-1114)
+ # rewrite.check_optimization( # type: ignore[attr-defined]
+ # model, number_of_clusters=32
+ # )
+ if model_is_quantized and is_qat:
+ model = rewrite.preserved_quantize(model) # type: ignore[attr-defined]
+ checkpoints = (
+ train_params.checkpoint_at if train_params.checkpoint_at else []
+ ) + [train_params.steps]
+ output_filenames = []
+
+ if len(rewrite.training_callbacks()) > 0 and set( # type: ignore[attr-defined]
+ rewrite.training_callbacks() # type: ignore[attr-defined]
+ ).issubset(callbacks):
+ callbacks.pop(-1)
+
+ optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
+ model = model_compile(model, optimizer, loss_fn)
+
+ model, output_filenames = model_fit(
+ model,
+ train_params,
+ checkpoints,
+ optimizer,
+ dataset,
+ callbacks,
+ output_filename,
+ rewrite,
+ replace,
+ input_name,
+ output_name,
+ model_is_quantized,
+ output_filenames,
+ input_shape,
+ output_shape,
+ loss_fn,
+ )
+ # Placeholder for now, will be parametrized later (MLIA-1114)
+ # rewrite.check_optimization( # type: ignore[attr-defined]
+ # model, number_of_clusters=32
+ # )
+
+ teacher.close()
+ return output_filenames
+
+def model_compile(
+ model: keras.Model,
+ optimizer: keras.optimizers.Nadam,
+ loss_fn: keras.losses.Loss,
+) -> keras.Model:
+ """Compiles a tflite model."""
+ model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
+ return model
+
+
+def create_model( # pylint: disable=too-many-arguments
+ rewrite: Callable,
+ input_shape: int,
+ output_shape: int,
+ optimizer: Callable,
+ loss_fn: Callable,
+ model_is_quantized: bool,
+ model_to_load_from: keras.Model | None = None,
+) -> keras.Model:
+ """Create a model, optionally from another."""
+ model = rewrite(input_shape, output_shape)
+ if model_is_quantized:
+ model = rewrite.quantize(model) # type: ignore[attr-defined]
+ model = model_compile(model, optimizer=optimizer, loss_fn=loss_fn)
+ if model_to_load_from:
+ model.set_weights(model_to_load_from.get_weights())
+ return model
+
+
+def model_fit( # pylint: disable=too-many-arguments
+ model: keras.Model,
+ train_params: TrainingParameters,
+ checkpoints: list,
+ optimizer: tf.optimizers.Nadam,
+ dataset: tf.data.Dataset,
+ callbacks: list,
+ output_filename: Path,
+ rewrite: Callable,
+ replace: TFLiteModel,
+ input_name: str,
+ output_name: str,
+ model_is_quantized: bool,
+ output_filenames: list,
+ input_shape: int,
+ output_shape: int,
+ loss_fn: Callable,
+ post_process: bool = False,
+) -> tuple[keras.Model, list]:
+ """Train a Keras model; return the trained model and checkpoint filenames."""
+ steps_so_far = 0
while steps_so_far < train_params.steps:
steps_to_train = checkpoints.pop(0) - steps_so_far
lr_start = optimizer.learning_rate.numpy()
@@ -452,15 +571,39 @@ def train_in_dir(
)
if steps_so_far < train_params.steps:
- filename, ext = Path(output_filename).parts[1:]
- checkpoint_filename = filename + (f"_@{steps_so_far}") + ext
+ filename = Path(output_filename).stem
+ filename_dir = Path(output_filename).parent.as_posix()
+ ext = Path(output_filename).suffix
+ checkpoint_filename = (
+ filename_dir + "/" + filename + (f"_@{steps_so_far}") + ext
+ )
+ # If post-processing, the clustering/pruning layers are stripped below,
+ # so copy the model before saving to let training continue
+ if post_process:
+ model_to_save = create_model(
+ rewrite,
+ input_shape,
+ output_shape,
+ optimizer,
+ loss_fn,
+ model_is_quantized,
+ model_to_load_from=model,
+ )
+ else:
+ model_to_save = model
else:
checkpoint_filename = str(output_filename)
+ model_to_save = model
+
with log_action(
f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}"
):
+ if post_process:
+ model_to_save = rewrite.post_process( # type: ignore[attr-defined]
+ model_to_save
+ )
save_as_tflite(
- model,
+ model_to_save,
checkpoint_filename,
input_name,
replace.shape_from_name[input_name],
@@ -470,8 +613,7 @@ def train_in_dir(
)
output_filenames.append(checkpoint_filename)
- teacher.close()
- return output_filenames
+ return model_to_save, output_filenames
def save_as_tflite(
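
The checkpoint filename fix above keeps intermediate models in the output directory instead of mangling the path. A runnable sketch of the naming scheme it implements, assuming output_model is out/new.tflite and checkpoint_at=[1000, 2000]:

    from pathlib import Path

    output_filename = Path("out/new.tflite")
    for steps_so_far in (1000, 2000):
        # Mirrors the logic above: <dir>/<stem>_@<steps><suffix>
        checkpoint_filename = (
            output_filename.parent.as_posix()
            + "/"
            + output_filename.stem
            + f"_@{steps_so_far}"
            + output_filename.suffix
        )
        print(checkpoint_filename)  # out/new_@1000.tflite, out/new_@2000.tflite
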
diff --git a/src/mlia/nn/rewrite/library/fc_clustering_layer.py b/src/mlia/nn/rewrite/library/fc_clustering_layer.py
new file mode 100644
index 0000000..7cc383e
--- /dev/null
+++ b/src/mlia/nn/rewrite/library/fc_clustering_layer.py
@@ -0,0 +1,26 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Example rewrite with one fully connected clustered layer."""
+from typing import Any
+
+import tensorflow_model_optimization as tfmot
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
+
+
+def get_keras_model_clus(input_shape: Any, output_shape: Any) -> keras.Model:
+ """Generate TensorFlow Lite model for clustering rewrite."""
+ rewrite_params = {
+ "number_of_clusters": 32,
+ "cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization.LINEAR,
+ }
+ model = tfmot.clustering.keras.cluster_weights(
+ to_cluster=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Flatten(),
+ keras.layers.Dense(units=output_shape),
+ ]
+ ),
+ **rewrite_params
+ )
+ return model
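
Combined with ClusteringRewrite above, the clustered layer can be checked end to end; a minimal sketch mirroring the new unit test (the shapes are illustrative):

    from mlia.nn.rewrite.core.rewrite import ClusteringRewrite
    from mlia.nn.rewrite.library.fc_clustering_layer import get_keras_model_clus

    rewrite = ClusteringRewrite("fully-connected-clustering", get_keras_model_clus)
    model = rewrite(input_shape=(28, 28), output_shape=10)  # builds the clustered model
    model = rewrite.post_process(model)  # strips the clustering wrappers
    # Each kernel should now contain at most 32 unique values:
    assert rewrite.check_optimization(model, number_of_clusters=32)
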
diff --git a/src/mlia/nn/rewrite/library/fc_sparsity24_layer.py b/src/mlia/nn/rewrite/library/fc_sparsity24_layer.py
new file mode 100644
index 0000000..531b34a
--- /dev/null
+++ b/src/mlia/nn/rewrite/library/fc_sparsity24_layer.py
@@ -0,0 +1,23 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Example rewrite with one fully connected 2:4 sparsity layer."""
+from typing import Any
+
+import tensorflow_model_optimization as tfmot
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
+
+
+def get_keras_model(input_shape: Any, output_shape: Any) -> keras.Model:
+ """Generate TensorFlow Lite model for rewrite."""
+ model = tfmot.sparsity.keras.prune_low_magnitude(
+ to_prune=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Reshape([-1]),
+ keras.layers.Dense(output_shape),
+ ]
+ ),
+ sparsity_m_by_n=(2, 4),
+ )
+
+ return model
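
The pruning wrapper returned here is why Sparsity24Rewrite supplies UpdatePruningStep from training_callbacks() and strips the wrappers in post_process(). A minimal training sketch (the random data and shapes are illustrative):

    import numpy as np
    import tensorflow_model_optimization as tfmot

    from mlia.nn.rewrite.library.fc_sparsity24_layer import get_keras_model

    model = get_keras_model(input_shape=(28, 28), output_shape=10)
    model.compile(optimizer="nadam", loss="mse")
    x = np.random.rand(32, 28, 28).astype("float32")
    y = np.random.rand(32, 10).astype("float32")
    # UpdatePruningStep must run during fit, or the pruning wrappers raise an error.
    model.fit(x, y, epochs=1, callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])
    model = tfmot.sparsity.keras.strip_pruning(model)  # remove wrappers for deployment
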
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py
index 81a614f..b61e713 100644
--- a/src/mlia/nn/select.py
+++ b/src/mlia/nn/select.py
@@ -117,7 +117,7 @@ class MultiStageOptimizer(Optimizer):
def get_optimizer(
model: keras.Model | KerasModel | TFLiteModel,
config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings],
- training_parameters: list[dict | None] | None = None,
+ training_parameters: dict | None = None,
) -> Optimizer:
"""Get optimizer for provided configuration."""
if isinstance(model, KerasModel):
@@ -151,7 +151,7 @@ def get_optimizer(
def _get_optimizer(
model: keras.Model | Path,
optimization_settings: OptimizationSettings | list[OptimizationSettings],
- training_parameters: list[dict | None] | None = None,
+ training_parameters: dict | None = None,
) -> Optimizer:
if isinstance(optimization_settings, OptimizationSettings):
optimization_settings = [optimization_settings]
@@ -173,22 +173,17 @@ def _get_optimizer(
def _get_rewrite_params(
- training_parameters: list[dict | None] | None = None,
-) -> list:
+ training_parameters: dict | None = None,
+) -> TrainingParameters:
"""Get the rewrite TrainingParameters.
Return the default constructed TrainingParameters() per default, but can be
overwritten in the unit tests.
"""
- if training_parameters is None:
- return [TrainingParameters()]
+ if not training_parameters:
+ return TrainingParameters()
- if training_parameters[0] is None:
- train_params = TrainingParameters()
- else:
- train_params = TrainingParameters(**training_parameters[0])
-
- return [train_params]
+ return TrainingParameters(**training_parameters)
def _get_optimizer_configuration(
@@ -196,7 +191,7 @@ def _get_optimizer_configuration(
optimization_target: int | float | str,
layers_to_optimize: list[str] | None = None,
dataset: Path | None = None,
- training_parameters: list[dict | None] | None = None,
+ training_parameters: dict | None = None,
) -> OptimizerConfiguration:
"""Get optimizer configuration for provided parameters."""
_check_optimizer_params(optimization_type, optimization_target)
@@ -222,7 +217,7 @@ def _get_optimizer_configuration(
optimization_target=str(optimization_target),
layers_to_optimize=layers_to_optimize,
dataset=dataset,
- train_params=rewrite_params[0],
+ train_params=rewrite_params,
)
raise ConfigurationError(
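
With this change, callers hand get_optimizer a single dict of training overrides (or None) instead of a one-element list. The mapping onto TrainingParameters reduces to (the override values are illustrative):

    from mlia.nn.rewrite.core.train import TrainingParameters

    overrides = {"batch_size": 64, "learning_rate": 0.003}
    # Equivalent to the new _get_rewrite_params():
    train_params = TrainingParameters(**overrides) if overrides else TrainingParameters()
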
diff --git a/src/mlia/target/common/optimization.py b/src/mlia/target/common/optimization.py
index 8c5d184..1423189 100644
--- a/src/mlia/target/common/optimization.py
+++ b/src/mlia/target/common/optimization.py
@@ -86,7 +86,7 @@ class OptimizingDataCollector(ContextAwareDataCollector):
def optimize_model(
self,
opt_settings: list[OptimizationSettings],
- training_parameters: list[dict | None],
+ training_parameters: dict | None,
model: KerasModel | TFLiteModel,
) -> Any:
"""Run optimization."""
@@ -123,12 +123,12 @@ class OptimizingDataCollector(ContextAwareDataCollector):
context=context,
)
- def _get_training_settings(self, context: Context) -> list[dict]:
+ def _get_training_settings(self, context: Context) -> dict:
"""Get optimization settings."""
return self.get_parameter( # type: ignore
OptimizingDataCollector.name(),
"training_parameters",
- expected_type=list,
+ expected_type=dict,
expected=False,
context=context,
)
@@ -228,9 +228,8 @@ def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) -
raise TypeError("Optimization targets value has wrong format.")
rewrite_parameters = extra_args.get("optimization_profile")
- if not rewrite_parameters:
- training_parameters = None
- else:
+ training_parameters = None
+ if rewrite_parameters:
if not isinstance(rewrite_parameters, dict):
raise TypeError("Training Parameter values has wrong format.")
training_parameters = extra_args["optimization_profile"].get("training")
@@ -239,7 +238,7 @@ def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) -
{
"common_optimizations": {
"optimizations": [optimization_targets],
- "training_parameters": [training_parameters],
+ "training_parameters": training_parameters,
},
}
)
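
For orientation, the shape that add_common_optimization_params now expects from extra_args (keys taken from the code above; the numeric values are illustrative):

    extra_args = {
        "optimization_profile": {
            # Forwarded as-is to TrainingParameters via the rewrite machinery.
            "training": {"batch_size": 32, "learning_rate": 1e-3, "steps": 48000},
        }
    }
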
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py
index 9cda27c..93a05bd 100644
--- a/tests/test_cli_commands.py
+++ b/tests/test_cli_commands.py
@@ -84,6 +84,19 @@ def test_performance_unknown_target(
],
[
"ethos-u55-256",
+ False,
+ False,
+ None,
+ None,
+ None,
+ True,
+ "fully-connected-sparsity24",
+ "sequential/flatten/Reshape",
+ "StatefulPartitionedCall:0",
+ does_not_raise(),
+ ],
+ [
+ "ethos-u55-256",
True,
False,
None,
@@ -126,7 +139,8 @@ def test_performance_unknown_target(
Exception,
match=re.escape(
"Invalid rewrite target: 'random'. "
- "Supported rewrites: ['fully-connected']"
+ "Supported rewrites: ['fully-connected',"
+ " 'fully-connected-clustering', 'fully-connected-sparsity24']"
),
),
],
@@ -168,6 +182,19 @@ def test_performance_unknown_target(
),
),
],
+ [
+ "ethos-u55-256",
+ False,
+ False,
+ None,
+ None,
+ None,
+ True,
+ "fully-connected-clustering",
+ "sequential/flatten/Reshape",
+ "StatefulPartitionedCall:0",
+ does_not_raise(),
+ ],
],
)
def test_opt_valid_optimization_target( # pylint: disable=too-many-locals,too-many-arguments
diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py
index 05a5b55..341e0d2 100644
--- a/tests/test_common_optimization.py
+++ b/tests/test_common_optimization.py
@@ -57,7 +57,7 @@ def test_optimizing_data_collector(
config_parameters={
"common_optimizations": {
"optimizations": optimizations,
- "training_parameters": [training_parameters],
+ "training_parameters": training_parameters,
}
}
)
@@ -94,7 +94,7 @@ def test_optimizing_data_collector(
collector.set_context(context)
collector.collect_data()
assert optimize_model_mock.call_args.args[0] == opt_settings[0]
- assert optimize_model_mock.call_args.args[1] == [training_parameters]
+ assert optimize_model_mock.call_args.args[1] == training_parameters
assert fake_optimizer.invocation_count == 1
@@ -158,10 +158,12 @@ def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) -
]
if not extra_args.get("optimization_profile"):
- assert advisor_parameters["common_optimizations"][
- "training_parameters"
- ] == [None]
+ assert (
+ advisor_parameters["common_optimizations"]["training_parameters"]
+ is None
+ )
else:
- assert advisor_parameters["common_optimizations"][
- "training_parameters"
- ] == list(extra_args["optimization_profile"].values())
+ assert (
+ advisor_parameters["common_optimizations"]["training_parameters"]
+ == extra_args["optimization_profile"]["training"]
+ )
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index b32fafd..e502842 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -10,45 +10,102 @@ from typing import cast
from unittest.mock import MagicMock
import pytest
+import tensorflow_model_optimization as tfmot
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
+from tensorflow_model_optimization.python.core.clustering.keras.cluster_wrapper import ( # pylint: disable=no-name-in-module
+ ClusterWeights,
+)
-from mlia.nn.rewrite.core.rewrite import DynamicallyLoadedRewrite
+from mlia.nn.rewrite.core.rewrite import ClusteringRewrite
+from mlia.nn.rewrite.core.rewrite import GenericRewrite
from mlia.nn.rewrite.core.rewrite import Rewrite
from mlia.nn.rewrite.core.rewrite import RewriteCallable
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import RewriteRegistry
from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
+from mlia.nn.rewrite.core.rewrite import Sparsity24Rewrite
from mlia.nn.rewrite.core.rewrite import TrainingParameters
from mlia.nn.rewrite.core.train import train_in_dir
+from mlia.nn.rewrite.library.fc_clustering_layer import (
+ get_keras_model_clus as fc_clustering_rewrite,
+)
from mlia.nn.tensorflow.config import TFLiteModel
from tests.utils.rewrite import MockTrainingParameters
+class TestRewrite(Rewrite):
+ """Test rewrite class."""
+
+ def quantize(self, model: keras.Model) -> keras.Model:
+ """Return a quantized model if required."""
+ return tfmot.quantization.keras.quantize_model(model)
+
+ def preserved_quantize(self, model: keras.Model) -> keras.Model:
+ """Not needed."""
+ return model
+
+ def training_callbacks(self) -> list:
+ """Return default rewrite callbacks."""
+ return []
+
+ def post_process(self, model: keras.Model) -> keras.Model:
+ """Return default post-processing rewrite options."""
+ return model
+
+ def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+ """Not needed here."""
+ return True
+
+
def mock_rewrite_function(*_: Any) -> Any:
"""Mock function to test autoloading of rewrite functions."""
def test_rewrite() -> None:
- """Test the Rewrite class."""
+ """Test a derived Rewrite class."""
def bad_rewrite_func() -> Any:
raise NotImplementedError()
- rewrite = Rewrite("BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func))
+ rewrite = TestRewrite(
+ "BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func)
+ )
with pytest.raises(RuntimeError):
rewrite((1, 2), (1, 2))
@pytest.mark.parametrize(
+ "rewrite_name, callbacks_length, instance",
+ [
+ ("fully-connected", 0, GenericRewrite),
+ ("fully-connected-clustering", 0, ClusteringRewrite),
+ ("fully-connected-sparsity24", 1, Sparsity24Rewrite),
+ ],
+)
+def test_rewrite_selection(
+ rewrite_name: str, callbacks_length: int, instance: Rewrite
+) -> None:
+ """Test that the correct rewrite class is instantiated."""
+ rewrite = RewritingOptimizer.registry.items[rewrite_name]
+ assert rewrite.name == rewrite_name
+ assert isinstance(rewrite, instance) # type: ignore
+ assert len(rewrite.training_callbacks()) == callbacks_length
+
+
+@pytest.mark.parametrize(
"rewrite_name, expected_error",
[
("fully-connected", does_not_raise()),
+ ("fully-connected-sparsity24", does_not_raise()),
+ ("fully-connected-clustering", does_not_raise()),
("random", does_not_raise()),
],
)
def test_rewrite_configuration(
test_tflite_model_fp32: Path, rewrite_name: str, expected_error: Any
) -> None:
- """Test get_rewrite function only supports rewrite type fully-connected."""
+ """Test get_rewrite function only supports rewrite type fully-connected,
+ fully-connected-clustering and fully-connected-sparsity24."""
with expected_error:
config_obj = RewriteConfiguration(
rewrite_name,
@@ -63,19 +120,69 @@ def test_rewrite_configuration(
assert isinstance(rewriter_obj, RewritingOptimizer)
-def test_rewriting_optimizer(
+def test_rewrite_fully_connected_clustering() -> None:
+ """Check that model has the set number of clusters"""
+
+ rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite)
+ model = rewrite(input_shape=(28, 28), output_shape=10)
+ model = rewrite.post_process(model)
+ assert rewrite.check_optimization(model, number_of_clusters=32)
+
+
+def test_rewrite_fully_connected_clustering_error_handling() -> None:
+ """Check that model has the set number of clusters
+ and that when quantized the number of clusters
+ remain."""
+
+ rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite)
+ model = rewrite(input_shape=(28, 28), output_shape=10)
+ with pytest.raises(
+ ValueError,
+ match=(
+ r"Expected check_preserved_quantize to have argument number_of_clusters"
+ ),
+ ):
+ rewrite.check_optimization(model, bad_arg_name=25)
+
+
+@pytest.mark.parametrize(
+ "rewrite_type, expected_layers, quant",
+ [
+ ["fully-connected", [keras.layers.Reshape, keras.layers.Dense], False],
+ ["fully-connected-clustering", [ClusterWeights, ClusterWeights], False],
+ ["fully-connected-clustering", [ClusterWeights, ClusterWeights], True],
+ ],
+)
+def test_rewriting_optimizer( # pylint: disable=too-many-locals
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,
+ test_tflite_model: Path,
+ test_tfrecord: Path,
+ rewrite_type: str,
+ expected_layers: list[object],
+ quant: bool,
) -> None:
"""Test fc_layer rewrite process with rewrite type fully-connected."""
+
+ tfrecord = test_tfrecord if quant else test_tfrecord_fp32
+ tflite_model = test_tflite_model if quant else test_tflite_model_fp32
+
config_obj = RewriteConfiguration(
- "fully-connected",
+ rewrite_type,
["sequential/flatten/Reshape", "StatefulPartitionedCall:0"],
- test_tfrecord_fp32,
+ tfrecord,
train_params=MockTrainingParameters(),
)
- test_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj)
+ test_obj = RewritingOptimizer(tflite_model, config_obj)
+ rewrite_function = RewritingOptimizer.registry.items[
+ test_obj.optimizer_configuration.optimization_target
+ ]
+ # Input and output shapes do not matter; we just need to check the layers are as expected
+ rewrite_model = rewrite_function(input_shape=(28, 28, 1), output_shape=12)
+ for idx, layer in enumerate(rewrite_model.layers):
+ assert isinstance(layer, expected_layers[idx]) # type: ignore
+
test_obj.apply_optimization()
trained_model = test_obj.get_model()
@@ -87,11 +194,11 @@ def test_rewriting_optimizer(
def test_register_rewrite_function() -> None:
- """Test adding rewrite functions and verify the are reported via the registry."""
+ """Test adding rewrite functions and verify they are reported via the registry."""
registry = RewriteRegistry()
- rewrite1 = Rewrite("r1", cast(RewriteCallable, lambda: 1))
- rewrite2 = Rewrite("r2", cast(RewriteCallable, lambda: 2))
+ rewrite1 = TestRewrite("r1", cast(RewriteCallable, lambda: 1))
+ rewrite2 = TestRewrite("r2", cast(RewriteCallable, lambda: 2))
registry.register_rewrite(rewrite1)
registry.register_rewrite(rewrite2)
@@ -100,38 +207,11 @@ def test_register_rewrite_function() -> None:
def test_builtin_rewrite_names() -> None:
"""Test if all builtin rewrites are properly registered and returned."""
- assert RewritingOptimizer.builtin_rewrite_names() == ["fully-connected"]
-
-
-def test_rewrite_function_autoload() -> None:
- """Test rewrite function loading."""
- function_name = "tests.test_nn_rewrite_core_rewrite.mock_rewrite_function"
- rewrite = DynamicallyLoadedRewrite(name="mock_rewrite", function_name=function_name)
- assert rewrite.name == "mock_rewrite"
-
- assert rewrite.function is not mock_rewrite_function
- assert rewrite.load_function(function_name) is mock_rewrite_function
- assert rewrite.function is mock_rewrite_function
-
-
-def test_rewrite_function_autoload_fail() -> None:
- """Test rewrite function loading failure."""
- function_name = "invalid_module.invalid_function"
- rewrite = DynamicallyLoadedRewrite(
- name="mock_rewrite",
- function_name="invalid_module.invalid_function",
- )
- assert rewrite.name == "mock_rewrite"
-
- with pytest.raises(Exception) as exc_info:
- rewrite.load_function(function_name)
-
- message = exc_info.value.args[0]
-
- assert message == (
- "Unable to load rewrite function 'invalid_module.invalid_function'"
- " for 'mock_rewrite'."
- )
+ assert RewritingOptimizer.builtin_rewrite_names() == [
+ "fully-connected",
+ "fully-connected-clustering",
+ "fully-connected-sparsity24",
+ ]
def test_rewrite_configuration_train_params(
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index 6d24133..94c99ff 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -20,6 +20,7 @@ from mlia.nn.rewrite.core.train import LearningRateSchedule
from mlia.nn.rewrite.core.train import mixup
from mlia.nn.rewrite.core.train import train
from mlia.nn.rewrite.core.train import TrainingParameters
+from tests.test_nn_rewrite_core_rewrite import TestRewrite
from tests.utils.rewrite import MockTrainingParameters
@@ -53,18 +54,23 @@ def check_train(
"""Test the train() function."""
with TemporaryDirectory() as tmp_dir:
output_file = Path(tmp_dir, "out.tflite")
+ mock_rewrite = TestRewrite("replace", replace_fully_connected_with_conv)
result = train(
source_model=str(tflite_model),
unmodified_model=str(tflite_model) if use_unmodified_model else None,
output_model=str(output_file),
input_tfrec=str(tfrecord),
- replace_fn=replace_fully_connected_with_conv,
+ rewrite=mock_rewrite,
+ is_qat=False,
input_tensors=["sequential/flatten/Reshape"],
output_tensors=["StatefulPartitionedCall:0"],
train_params=train_params,
)
- assert len(result) == 2
- assert all(res >= 0.0 for res in result[0]), f"Results out of bound: {result}"
+
+ assert len(result[0][0]) == 2
+ assert all(
+ res >= 0.0 for res in result[0][0]
+ ), f"Results out of bound: {result}"
assert output_file.is_file()
if quantized:
@@ -229,3 +235,17 @@ def test_augment_fn_twins(augmentations: tuple, expected_error: Any) -> None:
with expected_error:
fn_twins = augment_fn_twins(dataset, augmentations) # type: ignore
assert len(fn_twins) == 2
+
+
+def test_train_checkpoint(
+ test_tflite_model: Path,
+ test_tfrecord: Path,
+) -> None:
+ """Test the train() function with valid checkpoint parameters."""
+ check_train(
+ tflite_model=test_tflite_model,
+ tfrecord=test_tfrecord,
+ train_params=MockTrainingParameters(steps=64, checkpoint_at=[24, 32]),
+ use_unmodified_model=False,
+ quantized=True,
+ )
diff --git a/tests/test_nn_select.py b/tests/test_nn_select.py
index aac07b4..4095076 100644
--- a/tests/test_nn_select.py
+++ b/tests/test_nn_select.py
@@ -183,11 +183,11 @@ def test_get_optimizer(
@pytest.mark.parametrize(
"rewrite_parameters",
- [[None], [{"batch_size": 64, "learning_rate": 0.003}]],
+ [None, {"batch_size": 64, "learning_rate": 0.003}],
)
@pytest.mark.skip_set_training_steps
def test_get_optimizer_training_parameters(
- rewrite_parameters: list[dict], test_tflite_model: Path
+ rewrite_parameters: dict | None, test_tflite_model: Path
) -> None:
"""Test function get_optimzer with various combinations of parameters."""
config = OptimizationSettings(
@@ -198,20 +198,18 @@ def test_get_optimizer_training_parameters(
)
optimizer = cast(
RewritingOptimizer,
- get_optimizer(test_tflite_model, config, list(rewrite_parameters)),
+ get_optimizer(test_tflite_model, config, rewrite_parameters),
)
- assert len(rewrite_parameters) == 1
-
assert isinstance(
optimizer.optimizer_configuration.train_params, TrainingParameters
)
- if not rewrite_parameters[0]:
+ if not rewrite_parameters:
assert asdict(TrainingParameters()) == asdict(
optimizer.optimizer_configuration.train_params
)
else:
- assert asdict(TrainingParameters()) | rewrite_parameters[0] == asdict(
+ assert asdict(TrainingParameters()) | rewrite_parameters == asdict(
optimizer.optimizer_configuration.train_params
)
diff --git a/tests/test_target_cortex_a_advisor.py b/tests/test_target_cortex_a_advisor.py
index 59d54b5..7bb57c3 100644
--- a/tests/test_target_cortex_a_advisor.py
+++ b/tests/test_target_cortex_a_advisor.py
@@ -47,7 +47,7 @@ def test_configure_and_get_cortex_a_advisor(test_tflite_model: Path) -> None:
},
]
],
- "training_parameters": [None],
+ "training_parameters": None,
},
}
diff --git a/tests/test_target_tosa_advisor.py b/tests/test_target_tosa_advisor.py
index cc47321..020acc5 100644
--- a/tests/test_target_tosa_advisor.py
+++ b/tests/test_target_tosa_advisor.py
@@ -47,7 +47,7 @@ def test_configure_and_get_tosa_advisor(
},
]
],
- "training_parameters": [None],
+ "training_parameters": None,
},
"tosa_inference_advisor": {
"model": str(test_tflite_model),