diff options
-rw-r--r-- | .pre-commit-config.yaml | 8 | ||||
-rw-r--r-- | pre_commit_hooks/check_copyright_header.py | 34 | ||||
-rw-r--r-- | pyproject.toml | 4 | ||||
-rw-r--r-- | setup.cfg | 8 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/rewrite.py | 201 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/train.py | 178 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/library/fc_clustering_layer.py | 26 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/library/fc_sparsity24_layer.py | 23 | ||||
-rw-r--r-- | src/mlia/nn/select.py | 23 | ||||
-rw-r--r-- | src/mlia/target/common/optimization.py | 13 | ||||
-rw-r--r-- | tests/test_cli_commands.py | 29 | ||||
-rw-r--r-- | tests/test_common_optimization.py | 18 | ||||
-rw-r--r-- | tests/test_nn_rewrite_core_rewrite.py | 166 | ||||
-rw-r--r-- | tests/test_nn_rewrite_core_train.py | 26 | ||||
-rw-r--r-- | tests/test_nn_select.py | 12 | ||||
-rw-r--r-- | tests/test_target_cortex_a_advisor.py | 2 | ||||
-rw-r--r-- | tests/test_target_tosa_advisor.py | 2 |
17 files changed, 629 insertions, 144 deletions
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b601b03..3788326 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -124,3 +124,11 @@ repos: hooks: - id: commitizen-branch args: [--rev-range, HEAD~1..HEAD] + +- repo: local + hooks: + - id: check-copyright-header + name: Check Copyright header years + entry: python pre_commit_hooks/check_copyright_header.py + language: python + verbose: true diff --git a/pre_commit_hooks/check_copyright_header.py b/pre_commit_hooks/check_copyright_header.py new file mode 100644 index 0000000..ded7675 --- /dev/null +++ b/pre_commit_hooks/check_copyright_header.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Pre-commit hook that checks the current year is in the Copyright header of a file. + +If the header is out of date it will print a warning. +""" +import datetime +import subprocess # nosec + + +class CopyrightHeaderChecker: + """Class that wraps the checker for the Copyright header.""" + + def check_files_have_updated_header(self, filenames: list) -> None: + """Check whether input files have the current year in the copyright string.""" + current_year = str(datetime.datetime.now().year) + for filename in filenames: + with open(filename, encoding="utf-8") as file: + first_line = file.readline() + second_line = file.readline() + if filename.endswith(".md") and current_year not in second_line: + print(f"WARNING: The Copyright header of {filename} is out of date!") + + if not filename.endswith(".md") and current_year not in first_line: + print(f"WARNING: The Copyright header of {filename} is out of date!") + + +if __name__ == "__main__": + staged_files = ( + subprocess.check_output(["git", "diff", "--cached", "--name-only"]) # nosec + .decode() + .splitlines() + ) + CopyrightHeaderChecker().check_files_have_updated_header(filenames=staged_files) diff --git a/pyproject.toml b/pyproject.toml index 0c4cc8c..cf2db54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,6 +90,6 @@ update_changelog_on_bump = true schema_pattern = "(?s)(build|ci|docs|feat|fix|perf|refactor|style|test)(\\(\\S+\\))?!?:( [A-Z][^\\n\\r]+)((\\n\\n.*)|(\\s*))?$" schema = "<type>(<scope>): <Subject-capitalized>\n<BLANK LINE>\n<body>\n<BLANK LINE>\n(BREAKING CHANGE: )<footer>" # Commit parser is used to render the commits for RELEASES.md -commit_parser = "^((?P<change_type>feat|fix|refactor|perf|BREAKING CHANGE)(?:\\((?P<scope>[^()\\r\\n]*)\\)|\\()?(?P<breaking>!)?|\\w+!):\\s(?P<message>.*)?" +commit_parser = "^((?P<change_type>build|ci|docs|feat|fix|perf|refactor|style|test|BREAKING CHANGE)(?:\\((?P<scope>[^()\\r\\n]*)\\)|\\()?(?P<breaking>!)?|\\w+!):\\s(?P<message>.*)?" # Change type map to render the title for that category as per {tag:title} -change_type_map = {'feat' = 'Feature changes', 'fix' = 'Bug fix', 'perf' = 'Performance improvements'} +change_type_map = {'feat' = 'Feature changes', 'fix' = 'Bug fix', 'perf' = 'Performance improvements', 'build' = 'Development changes'} @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-FileCopyrightText: Copyright (c) 2020 Troy Comi # SPDX-License-Identifier: Apache-2.0 AND MIT @@ -28,8 +28,12 @@ python_requires = >=3.9.0 package_dir = = src packages = find_namespace: +# Pinning tensorflow & h5py to work around build issue on aarch64: +# https://github.com/h5py/h5py/issues/2408 +# Idea is to unpin these when it's resolved. install_requires = - tensorflow~=2.15.1 + tensorflow==2.15.1 + h5py==3.10.0 tensorflow-model-optimization~=0.7.5 ethos-u-vela~=3.11.0 flaky~=3.7.0 diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index c7d13ba..e2c097c 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -3,15 +3,17 @@ """Contains class RewritingOptimizer to replace a subgraph/layer of a model.""" from __future__ import annotations -import importlib import logging import tempfile +from abc import ABC +from abc import abstractmethod from dataclasses import dataclass from pathlib import Path from typing import Any from typing import Callable -from typing import cast +import numpy as np +import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from mlia.core.errors import ConfigurationError @@ -22,6 +24,13 @@ from mlia.nn.common import Optimizer from mlia.nn.common import OptimizerConfiguration from mlia.nn.rewrite.core.train import train from mlia.nn.rewrite.core.train import TrainingParameters +from mlia.nn.rewrite.library.fc_clustering_layer import ( + get_keras_model_clus as fc_clustering_rewrite, +) +from mlia.nn.rewrite.library.fc_layer import get_keras_model as fc_rewrite +from mlia.nn.rewrite.library.fc_sparsity24_layer import ( + get_keras_model as fc_rewrite_sparsity24, +) from mlia.nn.tensorflow.config import TFLiteModel from mlia.utils.registry import Registry @@ -30,8 +39,8 @@ logger = logging.getLogger(__name__) RewriteCallable = Callable[[Any, Any], keras.Model] -class Rewrite: - """Graph rewrite logic to be used by RewritingOptimizer.""" +class Rewrite(ABC): + """Abstract class for rewrite logic to be used by RewritingOptimizer.""" def __init__(self, name: str, rewrite_fn: RewriteCallable): """Initialize a Rewrite instance with a given name and an optional function.""" @@ -45,34 +54,138 @@ class Rewrite: except Exception as ex: raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex + def quantize(self, model: keras.Model) -> keras.Model: + """Return a quantized model if required.""" + return model -@dataclass -class DynamicallyLoadedRewrite(Rewrite): - """A rewrite which can load logic from a function loaded dynamically.""" + @abstractmethod + def training_callbacks(self) -> list: + """Return default rewrite callbacks.""" - def __init__(self, name: str, function_name: str): - """Initialize.""" + @abstractmethod + def post_process(self, model: keras.Model) -> keras.Model: + """Return default post-processing rewrite options.""" - def load_and_run(input_shape: Any, output_shape: Any) -> keras.Model: - """Load the function from a file dynamically.""" - self.load_function(function_name) - return self.function(input_shape, output_shape) + @abstractmethod + def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool: + """Check if the optimization has produced the correct result.""" - super().__init__(name, load_and_run) - def load_function(self, function_name: str) -> RewriteCallable: - """Return the rewrite function. Import using the auto_load attr if necessary.""" - try: - name_parts = function_name.split(".") - module_name = ".".join(name_parts[:-1]) - fn_name = name_parts[-1] - module = importlib.import_module(module_name) - self.function = cast(RewriteCallable, getattr(module, fn_name)) - return self.function - except Exception as ex: - raise RuntimeError( - f"Unable to load rewrite function '{function_name}' for '{self.name}'." - ) from ex +class GenericRewrite(Rewrite): + """Graph rewrite logic for fully-connected rewrite.""" + + def quantize(self, model: keras.Model) -> keras.Model: + """Return a quantized model if required.""" + return tfmot.quantization.keras.quantize_model(model) + + def training_callbacks(self) -> list: + """Return default rewrite callbacks.""" + return [] + + def post_process(self, model: keras.Model) -> keras.Model: + """Return default post-processing rewrite options.""" + return model + + def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + """Not needed here.""" + return True + + +class QuantizeAwareTrainingRewrite(Rewrite, ABC): + """Abstract class for rewrites that perform QAT.""" + + @abstractmethod + def preserved_quantize(self, model: keras.Model) -> keras.Model: + """Apply optimization-aware quantization to a given model.""" + return model + + +class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): + """Graph rewrite logic for fully-connected-sparsity24 rewrite.""" + + pruning_callback = tfmot.sparsity.keras.UpdatePruningStep + + strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning) + + def quantize(self, model: keras.Model) -> keras.Model: + """Skip quantization when using pruning rewrite.""" + return model + + def training_callbacks(self) -> list: + """Return pruning-specific rewrite callback.""" + return [self.pruning_callback()] + + def post_process(self, model: keras.Model) -> keras.Model: + """Pruning-specific post-processing rewrite options.""" + return self.strip_pruning_wrapper(model) + + def preserved_quantize( + self, + model: keras.Model, + ) -> keras.Model: + """Apply pruning-preserved quantization training to a given model.""" + model = tfmot.quantization.keras.quantize_annotate_model(model) + model = tfmot.quantization.keras.quantize_apply( + model, + tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(), + ) + + return model + + def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + """Not needed here.""" + return True + + +class ClusteringRewrite(QuantizeAwareTrainingRewrite): + """Graph clustering rewrite logic to be used by RewritingOptimizer.""" + + _strip_clustering_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering) + + def preserved_quantize(self, model: keras.Model) -> keras.Model: + """Apply clustering-preserved quantization to a given model.""" + quant_aware_model = tfmot.quantization.keras.quantize_annotate_model(model) + cqat_model = tfmot.quantization.keras.quantize_apply( + quant_aware_model, + tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme(), + ) + return cqat_model + + def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + """Check if clustering has produced the correct result.""" + number_of_clusters = kwargs.get("number_of_clusters") + if not number_of_clusters: + raise ValueError( + """ + Expected check_preserved_quantize to have argument number_of_clusters. + """ + ) + + for layer in model.layers: + for weight in layer.weights: + if "kernel" in weight.name: + if "kernel_min" in weight.name or "kernel_max" in weight.name: + continue + number_of_found_clusters = len(np.unique(weight)) + if number_of_found_clusters != number_of_clusters: + logger.warning( + "\nWARNING: Expected %d cluster(s), found %d " + "cluster(s) in layer %s for weight %s \n", + number_of_clusters, + number_of_found_clusters, + layer.name, + weight.name, + ) + return False + return True + + def training_callbacks(self) -> list: + """Return default rewrite callbacks.""" + return [] + + def post_process(self, model: keras.Model) -> keras.Model: + """Return the clustering stripped model.""" + return self._strip_clustering_wrapper(model) class RewriteRegistry(Registry[Rewrite]): @@ -113,9 +226,9 @@ class RewritingOptimizer(Optimizer): registry = RewriteRegistry( [ - DynamicallyLoadedRewrite( - "fully-connected", "mlia.nn.rewrite.library.fc_layer.get_keras_model" - ) + GenericRewrite("fully-connected", fc_rewrite), + Sparsity24Rewrite("fully-connected-sparsity24", fc_rewrite_sparsity24), + ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite), ] ) @@ -149,22 +262,35 @@ class RewritingOptimizer(Optimizer): raise ConfigurationError( "Input and output tensor names need to be set for rewrite." ) - orig_vs_repl_stats, total_stats = train( source_model=tflite_model, unmodified_model=tflite_model if use_unmodified_model else None, output_model=str(tmp_output), input_tfrec=str(tfrecord), - replace_fn=rewrite, + rewrite=rewrite, + is_qat=isinstance(rewrite, QuantizeAwareTrainingRewrite), input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], train_params=self.optimizer_configuration.train_params, ) if orig_vs_repl_stats: - orig_vs_repl = ["Replaced sub-graph only"] + [ - f"{stat:.3f}" for stat in orig_vs_repl_stats - ] + model_stats: list = [] + cp_param = self.optimizer_configuration.train_params.checkpoint_at + checkpoints = ( + [ + "At checkpoint " + str(checkpoint) + " steps" + for checkpoint in cp_param + ] + if cp_param + else [] + ) + checkpoints.append("All Steps") + for checkpoint, orig_vs_repl_stat in zip(checkpoints, orig_vs_repl_stats): + model_stats.append( + ["Replaced sub-graph: " + checkpoint] + + [f"{stat:.3f}" for stat in orig_vs_repl_stat] + ) total = ["Total model"] + [f"{stat:.3f}" for stat in total_stats] notes = ( "These metrics show the difference between original model\n" @@ -178,19 +304,20 @@ class RewritingOptimizer(Optimizer): table = Table( columns=[ Column( - "Original vs. optimized", + "Original vs. Optimized", alias="metric", fmt=Format(wrap_width=40), ), Column("MAE", alias="value", fmt=Format(wrap_width=15)), Column("NRMSE", alias="value", fmt=Format(wrap_width=15)), ], - rows=[orig_vs_repl, total], + rows=[*model_stats, total], name="Rewrite performance metrics", alias="rewrite_performance_metrics", notes=notes, ) logger.info(table.to_plain_text(show_title=True)) + self.model = TFLiteModel(tmp_output) def get_model(self) -> TFLiteModel: """Return optimized model.""" diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index 60c39ae..88efa23 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Sequential trainer.""" +# pylint: disable=too-many-arguments # pylint: disable=too-many-locals # pylint: disable=too-many-statements from __future__ import annotations @@ -22,7 +23,6 @@ from typing import Literal import numpy as np import tensorflow as tf -import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from numpy.random import Generator @@ -73,12 +73,13 @@ class TrainingParameters: checkpoint_at: list | None = None -def train( +def train( # pylint: disable=too-many-arguments source_model: str, unmodified_model: Any, output_model: str, input_tfrec: str, - replace_fn: Callable, + rewrite: Callable, + is_qat: bool, input_tensors: list, output_tensors: list, train_params: TrainingParameters = TrainingParameters(), @@ -118,7 +119,8 @@ def train( train_dir=train_dir, baseline_dir=unmodified_model_dir_path, output_filename=Path(train_dir, "new.tflite"), - replace_fn=replace_fn, + rewrite=rewrite, + is_qat=is_qat, train_params=train_params, ) @@ -159,10 +161,10 @@ def train( if unmodified_model_dir: cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup() - return (results if train_params.checkpoint_at else results[0]), [ + return results, [ mae, nrmse, - ] # only return a list if multiple checkpoints are asked for + ] def eval_in_dir( @@ -345,7 +347,8 @@ def train_in_dir( train_dir: str, baseline_dir: Any, output_filename: Path, - replace_fn: Callable, + rewrite: Callable, + is_qat: bool, train_params: TrainingParameters = TrainingParameters(), ) -> list[str]: """Train a replacement for replace.tflite using the input.tfrec \ @@ -380,15 +383,15 @@ def train_in_dir( ) input_shape = teacher.shape_from_name[input_name][1:] - output_shape = teacher.shape_from_name[output_name][1:] - model = replace_fn(input_shape, output_shape) + output_shape = teacher.shape_from_name[output_name][1:] optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate) loss_fn = keras.losses.MeanSquaredError() - if model_is_quantized: - model = tfmot.quantization.keras.quantize_model(model) - model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"]) + + model = create_model( + rewrite, input_shape, output_shape, optimizer, loss_fn, model_is_quantized + ) logger.info(model.summary()) @@ -428,11 +431,127 @@ def train_in_dir( elif train_params.learning_rate_schedule == "constant": callbacks = [] - output_filenames = [] + callbacks.extend(rewrite.training_callbacks()) # type: ignore[attr-defined] + output_filenames: list = [] checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [ train_params.steps ] + model, output_filenames = model_fit( + model, + train_params, + checkpoints, + optimizer, + dataset, + callbacks, + output_filename, + rewrite, + replace, + input_name, + output_name, + model_is_quantized, + output_filenames, + input_shape, + output_shape, + loss_fn, + post_process=True, + ) + + # Placeholder for now, will be parametrized later (MLIA-1114) + # rewrite.check_optimization( # type: ignore[attr-defined] + # model, number_of_clusters=32 + # ) + if model_is_quantized and is_qat: + model = rewrite.preserved_quantize(model) # type: ignore[attr-defined] + checkpoints = ( + train_params.checkpoint_at if train_params.checkpoint_at else [] + ) + [train_params.steps] + output_filenames = [] + + if len(rewrite.training_callbacks()) > 0 and set( # type: ignore[attr-defined] + rewrite.training_callbacks() # type: ignore[attr-defined] + ).issubset(callbacks): + callbacks.pop(-1) + + optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate) + model = model_compile(model, optimizer, loss_fn) + + model, output_filenames = model_fit( + model, + train_params, + checkpoints, + optimizer, + dataset, + callbacks, + output_filename, + rewrite, + replace, + input_name, + output_name, + model_is_quantized, + output_filenames, + input_shape, + output_shape, + loss_fn, + ) + # Placeholder for now, will be parametrized later (MLIA-1114) + # rewrite.check_optimization( # type: ignore[attr-defined] + # model, number_of_clusters=32 + # ) + + teacher.close() + return output_filenames + +def model_compile( + model: keras.Model, + optimizer: keras.optimizers.Nadam, + loss_fn: keras.losses.Loss, +) -> keras.Model: + """Compiles a tflite model.""" + model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"]) + return model + + +def create_model( # pylint: disable=too-many-arguments + rewrite: Callable, + input_shape: int, + output_shape: int, + optimizer: Callable, + loss_fn: Callable, + model_is_quantized: bool, + model_to_load_from: keras.model | None = None, +) -> keras.Model: + """Create a model, optionally from another.""" + model = rewrite(input_shape, output_shape) + if model_is_quantized: + model = rewrite.quantize(model) # type: ignore[attr-defined] + model = model_compile(model, optimizer=optimizer, loss_fn=loss_fn) + if model_to_load_from: + model.set_weights(model_to_load_from.get_weights()) + return model + + +def model_fit( # pylint: disable=too-many-arguments + model: keras.Model, + train_params: TrainingParameters, + checkpoints: list, + optimizer: tf.optimizers.Nadam, + dataset: tf.data.Dataset, + callbacks: list, + output_filename: Path, + rewrite: Callable, + replace: TFLiteModel, + input_name: str, + output_name: str, + model_is_quantized: bool, + output_filenames: list, + input_shape: int, + output_shape: int, + loss_fn: Callable, + post_process: bool = False, +) -> keras.Model: + """Train a tflite model.""" + steps_so_far = 0 while steps_so_far < train_params.steps: steps_to_train = checkpoints.pop(0) - steps_so_far lr_start = optimizer.learning_rate.numpy() @@ -452,15 +571,39 @@ def train_in_dir( ) if steps_so_far < train_params.steps: - filename, ext = Path(output_filename).parts[1:] - checkpoint_filename = filename + (f"_@{steps_so_far}") + ext + filename = Path(output_filename).stem + filename_dir = Path(output_filename).parent.as_posix() + ext = Path(output_filename).suffix + checkpoint_filename = ( + filename_dir + "/" + filename + (f"_@{steps_so_far}") + ext + ) + # If post processing we are stripping the clustering/pruning layers below + # Thus copy the model before saving, so training can continue + if post_process: + model_to_save = create_model( + rewrite, + input_shape, + output_shape, + optimizer, + loss_fn, + model_is_quantized, + model_to_load_from=model, + ) + else: + model_to_save = model else: checkpoint_filename = str(output_filename) + model_to_save = model + with log_action( f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}" ): + if post_process: + model_to_save = rewrite.post_process( # type: ignore[attr-defined] + model_to_save + ) save_as_tflite( - model, + model_to_save, checkpoint_filename, input_name, replace.shape_from_name[input_name], @@ -470,8 +613,7 @@ def train_in_dir( ) output_filenames.append(checkpoint_filename) - teacher.close() - return output_filenames + return model_to_save, output_filenames def save_as_tflite( diff --git a/src/mlia/nn/rewrite/library/fc_clustering_layer.py b/src/mlia/nn/rewrite/library/fc_clustering_layer.py new file mode 100644 index 0000000..7cc383e --- /dev/null +++ b/src/mlia/nn/rewrite/library/fc_clustering_layer.py @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Example rewrite with one fully connected clustered layer.""" +from typing import Any + +import tensorflow_model_optimization as tfmot +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 + + +def get_keras_model_clus(input_shape: Any, output_shape: Any) -> keras.Model: + """Generate TensorFlow Lite model for clustering rewrite.""" + rewrite_params = { + "number_of_clusters": 32, + "cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization.LINEAR, + } + model = tfmot.clustering.keras.cluster_weights( + to_cluster=keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Flatten(), + keras.layers.Dense(units=output_shape), + ] + ), + **rewrite_params + ) + return model diff --git a/src/mlia/nn/rewrite/library/fc_sparsity24_layer.py b/src/mlia/nn/rewrite/library/fc_sparsity24_layer.py new file mode 100644 index 0000000..531b34a --- /dev/null +++ b/src/mlia/nn/rewrite/library/fc_sparsity24_layer.py @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Example rewrite with one fully connected 2:4 sparsity layer.""" +from typing import Any + +import tensorflow_model_optimization as tfmot +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 + + +def get_keras_model(input_shape: Any, output_shape: Any) -> keras.Model: + """Generate TensorFlow Lite model for rewrite.""" + model = tfmot.sparsity.keras.prune_low_magnitude( + to_prune=keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Reshape([-1]), + keras.layers.Dense(output_shape), + ] + ), + sparsity_m_by_n=(2, 4), + ) + + return model diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py index 81a614f..b61e713 100644 --- a/src/mlia/nn/select.py +++ b/src/mlia/nn/select.py @@ -117,7 +117,7 @@ class MultiStageOptimizer(Optimizer): def get_optimizer( model: keras.Model | KerasModel | TFLiteModel, config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings], - training_parameters: list[dict | None] | None = None, + training_parameters: dict | None = None, ) -> Optimizer: """Get optimizer for provided configuration.""" if isinstance(model, KerasModel): @@ -151,7 +151,7 @@ def get_optimizer( def _get_optimizer( model: keras.Model | Path, optimization_settings: OptimizationSettings | list[OptimizationSettings], - training_parameters: list[dict | None] | None = None, + training_parameters: dict | None = None, ) -> Optimizer: if isinstance(optimization_settings, OptimizationSettings): optimization_settings = [optimization_settings] @@ -173,22 +173,17 @@ def _get_optimizer( def _get_rewrite_params( - training_parameters: list[dict | None] | None = None, -) -> list: + training_parameters: dict | None = None, +) -> TrainingParameters: """Get the rewrite TrainingParameters. Return the default constructed TrainingParameters() per default, but can be overwritten in the unit tests. """ - if training_parameters is None: - return [TrainingParameters()] + if not training_parameters: + return TrainingParameters() - if training_parameters[0] is None: - train_params = TrainingParameters() - else: - train_params = TrainingParameters(**training_parameters[0]) - - return [train_params] + return TrainingParameters(**training_parameters) def _get_optimizer_configuration( @@ -196,7 +191,7 @@ def _get_optimizer_configuration( optimization_target: int | float | str, layers_to_optimize: list[str] | None = None, dataset: Path | None = None, - training_parameters: list[dict | None] | None = None, + training_parameters: dict | None = None, ) -> OptimizerConfiguration: """Get optimizer configuration for provided parameters.""" _check_optimizer_params(optimization_type, optimization_target) @@ -222,7 +217,7 @@ def _get_optimizer_configuration( optimization_target=str(optimization_target), layers_to_optimize=layers_to_optimize, dataset=dataset, - train_params=rewrite_params[0], + train_params=rewrite_params, ) raise ConfigurationError( diff --git a/src/mlia/target/common/optimization.py b/src/mlia/target/common/optimization.py index 8c5d184..1423189 100644 --- a/src/mlia/target/common/optimization.py +++ b/src/mlia/target/common/optimization.py @@ -86,7 +86,7 @@ class OptimizingDataCollector(ContextAwareDataCollector): def optimize_model( self, opt_settings: list[OptimizationSettings], - training_parameters: list[dict | None], + training_parameters: dict | None, model: KerasModel | TFLiteModel, ) -> Any: """Run optimization.""" @@ -123,12 +123,12 @@ class OptimizingDataCollector(ContextAwareDataCollector): context=context, ) - def _get_training_settings(self, context: Context) -> list[dict]: + def _get_training_settings(self, context: Context) -> dict: """Get optimization settings.""" return self.get_parameter( # type: ignore OptimizingDataCollector.name(), "training_parameters", - expected_type=list, + expected_type=dict, expected=False, context=context, ) @@ -228,9 +228,8 @@ def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) - raise TypeError("Optimization targets value has wrong format.") rewrite_parameters = extra_args.get("optimization_profile") - if not rewrite_parameters: - training_parameters = None - else: + training_parameters = None + if rewrite_parameters: if not isinstance(rewrite_parameters, dict): raise TypeError("Training Parameter values has wrong format.") training_parameters = extra_args["optimization_profile"].get("training") @@ -239,7 +238,7 @@ def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) - { "common_optimizations": { "optimizations": [optimization_targets], - "training_parameters": [training_parameters], + "training_parameters": training_parameters, }, } ) diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py index 9cda27c..93a05bd 100644 --- a/tests/test_cli_commands.py +++ b/tests/test_cli_commands.py @@ -84,6 +84,19 @@ def test_performance_unknown_target( ], [ "ethos-u55-256", + False, + False, + None, + None, + None, + True, + "fully-connected-sparsity24", + "sequential/flatten/Reshape", + "StatefulPartitionedCall:0", + does_not_raise(), + ], + [ + "ethos-u55-256", True, False, None, @@ -126,7 +139,8 @@ def test_performance_unknown_target( Exception, match=re.escape( "Invalid rewrite target: 'random'. " - "Supported rewrites: ['fully-connected']" + "Supported rewrites: ['fully-connected'," + " 'fully-connected-clustering', 'fully-connected-sparsity24']" ), ), ], @@ -168,6 +182,19 @@ def test_performance_unknown_target( ), ), ], + [ + "ethos-u55-256", + False, + False, + None, + None, + None, + True, + "fully-connected-clustering", + "sequential/flatten/Reshape", + "StatefulPartitionedCall:0", + does_not_raise(), + ], ], ) def test_opt_valid_optimization_target( # pylint: disable=too-many-locals,too-many-arguments diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py index 05a5b55..341e0d2 100644 --- a/tests/test_common_optimization.py +++ b/tests/test_common_optimization.py @@ -57,7 +57,7 @@ def test_optimizing_data_collector( config_parameters={ "common_optimizations": { "optimizations": optimizations, - "training_parameters": [training_parameters], + "training_parameters": training_parameters, } } ) @@ -94,7 +94,7 @@ def test_optimizing_data_collector( collector.set_context(context) collector.collect_data() assert optimize_model_mock.call_args.args[0] == opt_settings[0] - assert optimize_model_mock.call_args.args[1] == [training_parameters] + assert optimize_model_mock.call_args.args[1] == training_parameters assert fake_optimizer.invocation_count == 1 @@ -158,10 +158,12 @@ def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) - ] if not extra_args.get("optimization_profile"): - assert advisor_parameters["common_optimizations"][ - "training_parameters" - ] == [None] + assert ( + advisor_parameters["common_optimizations"]["training_parameters"] + is None + ) else: - assert advisor_parameters["common_optimizations"][ - "training_parameters" - ] == list(extra_args["optimization_profile"].values()) + assert ( + advisor_parameters["common_optimizations"]["training_parameters"] + == extra_args["optimization_profile"]["training"] + ) diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py index b32fafd..e502842 100644 --- a/tests/test_nn_rewrite_core_rewrite.py +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -10,45 +10,102 @@ from typing import cast from unittest.mock import MagicMock import pytest +import tensorflow_model_optimization as tfmot +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 +from tensorflow_model_optimization.python.core.clustering.keras.cluster_wrapper import ( # pylint: disable=no-name-in-module + ClusterWeights, +) -from mlia.nn.rewrite.core.rewrite import DynamicallyLoadedRewrite +from mlia.nn.rewrite.core.rewrite import ClusteringRewrite +from mlia.nn.rewrite.core.rewrite import GenericRewrite from mlia.nn.rewrite.core.rewrite import Rewrite from mlia.nn.rewrite.core.rewrite import RewriteCallable from mlia.nn.rewrite.core.rewrite import RewriteConfiguration from mlia.nn.rewrite.core.rewrite import RewriteRegistry from mlia.nn.rewrite.core.rewrite import RewritingOptimizer +from mlia.nn.rewrite.core.rewrite import Sparsity24Rewrite from mlia.nn.rewrite.core.rewrite import TrainingParameters from mlia.nn.rewrite.core.train import train_in_dir +from mlia.nn.rewrite.library.fc_clustering_layer import ( + get_keras_model_clus as fc_clustering_rewrite, +) from mlia.nn.tensorflow.config import TFLiteModel from tests.utils.rewrite import MockTrainingParameters +class TestRewrite(Rewrite): + """Test rewrite class.""" + + def quantize(self, model: keras.Model) -> keras.Model: + """Return a quantized model if required.""" + return tfmot.quantization.keras.quantize_model(model) + + def preserved_quantize(self, model: keras.Model) -> keras.Model: + """Not needed.""" + return model + + def training_callbacks(self) -> list: + """Return default rewrite callbacks.""" + return [] + + def post_process(self, model: keras.Model) -> keras.Model: + """Return default post-processing rewrite options.""" + return model + + def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool: + """Not needed here.""" + return True + + def mock_rewrite_function(*_: Any) -> Any: """Mock function to test autoloading of rewrite functions.""" def test_rewrite() -> None: - """Test the Rewrite class.""" + """Test a derived Rewrite class.""" def bad_rewrite_func() -> Any: raise NotImplementedError() - rewrite = Rewrite("BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func)) + rewrite = TestRewrite( + "BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func) + ) with pytest.raises(RuntimeError): rewrite((1, 2), (1, 2)) @pytest.mark.parametrize( + "rewrite_name, callbacks_length, instance", + [ + ("fully-connected", 0, GenericRewrite), + ("fully-connected-clustering", 0, ClusteringRewrite), + ("fully-connected-sparsity24", 1, Sparsity24Rewrite), + ], +) +def test_rewrite_selection( + rewrite_name: str, callbacks_length: int, instance: Rewrite +) -> None: + """Test that the correct rewrite class is instantiated.""" + rewrite = RewritingOptimizer.registry.items[rewrite_name] + assert rewrite.name == rewrite_name + assert isinstance(rewrite, instance) # type: ignore + assert len(rewrite.training_callbacks()) == callbacks_length + + +@pytest.mark.parametrize( "rewrite_name, expected_error", [ ("fully-connected", does_not_raise()), + ("fully-connected-sparsity24", does_not_raise()), + ("fully-connected-clustering", does_not_raise()), ("random", does_not_raise()), ], ) def test_rewrite_configuration( test_tflite_model_fp32: Path, rewrite_name: str, expected_error: Any ) -> None: - """Test get_rewrite function only supports rewrite type fully-connected.""" + """Test get_rewrite function only supports rewrite type fully-connected, + fully-connected-clustering and fully-connected-sparsity24.""" with expected_error: config_obj = RewriteConfiguration( rewrite_name, @@ -63,19 +120,69 @@ def test_rewrite_configuration( assert isinstance(rewriter_obj, RewritingOptimizer) -def test_rewriting_optimizer( +def test_rewrite_fully_connected_clustering() -> None: + """Check that model has the set number of clusters""" + + rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite) + model = rewrite(input_shape=(28, 28), output_shape=10) + model = rewrite.post_process(model) + assert rewrite.check_optimization(model, number_of_clusters=32) + + +def test_rewrite_fully_connected_clustering_error_handling() -> None: + """Check that model has the set number of clusters + and that when quantized the number of clusters + remain.""" + + rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite) + model = rewrite(input_shape=(28, 28), output_shape=10) + with pytest.raises( + ValueError, + match=( + r"Expected check_preserved_quantize to have argument number_of_clusters" + ), + ): + rewrite.check_optimization(model, bad_arg_name=25) + + +@pytest.mark.parametrize( + "rewrite_type, expected_layers, quant", + [ + ["fully-connected", [keras.layers.Reshape, keras.layers.Dense], False], + ["fully-connected-clustering", [ClusterWeights, ClusterWeights], False], + ["fully-connected-clustering", [ClusterWeights, ClusterWeights], True], + ], +) +def test_rewriting_optimizer( # pylint: disable=too-many-locals test_tflite_model_fp32: Path, test_tfrecord_fp32: Path, + test_tflite_model: Path, + test_tfrecord: Path, + rewrite_type: str, + expected_layers: list[object], + quant: bool, ) -> None: """Test fc_layer rewrite process with rewrite type fully-connected.""" + + tfrecord = test_tfrecord if quant else test_tfrecord_fp32 + tflite_model = test_tflite_model if quant else test_tflite_model_fp32 + config_obj = RewriteConfiguration( - "fully-connected", + rewrite_type, ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"], - test_tfrecord_fp32, + tfrecord, train_params=MockTrainingParameters(), ) - test_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj) + test_obj = RewritingOptimizer(tflite_model, config_obj) + rewrite_function = RewritingOptimizer.registry.items[ + test_obj.optimizer_configuration.optimization_target + ] + # Input, output shape does not matter, just need the test the layers are as expected + rewrite_model = rewrite_function(input_shape=(28, 28, 1), output_shape=12) + for idx, layer in enumerate(rewrite_model.layers): + assert isinstance(layer, expected_layers[idx]) # type: ignore + test_obj.apply_optimization() trained_model = test_obj.get_model() @@ -87,11 +194,11 @@ def test_rewriting_optimizer( def test_register_rewrite_function() -> None: - """Test adding rewrite functions and verify the are reported via the registry.""" + """Test adding rewrite functions and verify they are reported via the registry.""" registry = RewriteRegistry() - rewrite1 = Rewrite("r1", cast(RewriteCallable, lambda: 1)) - rewrite2 = Rewrite("r2", cast(RewriteCallable, lambda: 2)) + rewrite1 = TestRewrite("r1", cast(RewriteCallable, lambda: 1)) + rewrite2 = TestRewrite("r2", cast(RewriteCallable, lambda: 2)) registry.register_rewrite(rewrite1) registry.register_rewrite(rewrite2) @@ -100,38 +207,11 @@ def test_register_rewrite_function() -> None: def test_builtin_rewrite_names() -> None: """Test if all builtin rewrites are properly registered and returned.""" - assert RewritingOptimizer.builtin_rewrite_names() == ["fully-connected"] - - -def test_rewrite_function_autoload() -> None: - """Test rewrite function loading.""" - function_name = "tests.test_nn_rewrite_core_rewrite.mock_rewrite_function" - rewrite = DynamicallyLoadedRewrite(name="mock_rewrite", function_name=function_name) - assert rewrite.name == "mock_rewrite" - - assert rewrite.function is not mock_rewrite_function - assert rewrite.load_function(function_name) is mock_rewrite_function - assert rewrite.function is mock_rewrite_function - - -def test_rewrite_function_autoload_fail() -> None: - """Test rewrite function loading failure.""" - function_name = "invalid_module.invalid_function" - rewrite = DynamicallyLoadedRewrite( - name="mock_rewrite", - function_name="invalid_module.invalid_function", - ) - assert rewrite.name == "mock_rewrite" - - with pytest.raises(Exception) as exc_info: - rewrite.load_function(function_name) - - message = exc_info.value.args[0] - - assert message == ( - "Unable to load rewrite function 'invalid_module.invalid_function'" - " for 'mock_rewrite'." - ) + assert RewritingOptimizer.builtin_rewrite_names() == [ + "fully-connected", + "fully-connected-clustering", + "fully-connected-sparsity24", + ] def test_rewrite_configuration_train_params( diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py index 6d24133..94c99ff 100644 --- a/tests/test_nn_rewrite_core_train.py +++ b/tests/test_nn_rewrite_core_train.py @@ -20,6 +20,7 @@ from mlia.nn.rewrite.core.train import LearningRateSchedule from mlia.nn.rewrite.core.train import mixup from mlia.nn.rewrite.core.train import train from mlia.nn.rewrite.core.train import TrainingParameters +from tests.test_nn_rewrite_core_rewrite import TestRewrite from tests.utils.rewrite import MockTrainingParameters @@ -53,18 +54,23 @@ def check_train( """Test the train() function.""" with TemporaryDirectory() as tmp_dir: output_file = Path(tmp_dir, "out.tflite") + mock_rewrite = TestRewrite("replace", replace_fully_connected_with_conv) result = train( source_model=str(tflite_model), unmodified_model=str(tflite_model) if use_unmodified_model else None, output_model=str(output_file), input_tfrec=str(tfrecord), - replace_fn=replace_fully_connected_with_conv, + rewrite=mock_rewrite, + is_qat=False, input_tensors=["sequential/flatten/Reshape"], output_tensors=["StatefulPartitionedCall:0"], train_params=train_params, ) - assert len(result) == 2 - assert all(res >= 0.0 for res in result[0]), f"Results out of bound: {result}" + + assert len(result[0][0]) == 2 + assert all( + res >= 0.0 for res in result[0][0] + ), f"Results out of bound: {result}" assert output_file.is_file() if quantized: @@ -229,3 +235,17 @@ def test_augment_fn_twins(augmentations: tuple, expected_error: Any) -> None: with expected_error: fn_twins = augment_fn_twins(dataset, augmentations) # type: ignore assert len(fn_twins) == 2 + + +def test_train_checkpoint( + test_tflite_model: Path, + test_tfrecord: Path, +) -> None: + """Test the train() function with valid checkpoint parameters.""" + check_train( + tflite_model=test_tflite_model, + tfrecord=test_tfrecord, + train_params=MockTrainingParameters(steps=64, checkpoint_at=[24, 32]), + use_unmodified_model=False, + quantized=True, + ) diff --git a/tests/test_nn_select.py b/tests/test_nn_select.py index aac07b4..4095076 100644 --- a/tests/test_nn_select.py +++ b/tests/test_nn_select.py @@ -183,11 +183,11 @@ def test_get_optimizer( @pytest.mark.parametrize( "rewrite_parameters", - [[None], [{"batch_size": 64, "learning_rate": 0.003}]], + [None, {"batch_size": 64, "learning_rate": 0.003}], ) @pytest.mark.skip_set_training_steps def test_get_optimizer_training_parameters( - rewrite_parameters: list[dict], test_tflite_model: Path + rewrite_parameters: dict | None, test_tflite_model: Path ) -> None: """Test function get_optimzer with various combinations of parameters.""" config = OptimizationSettings( @@ -198,20 +198,18 @@ def test_get_optimizer_training_parameters( ) optimizer = cast( RewritingOptimizer, - get_optimizer(test_tflite_model, config, list(rewrite_parameters)), + get_optimizer(test_tflite_model, config, rewrite_parameters), ) - assert len(rewrite_parameters) == 1 - assert isinstance( optimizer.optimizer_configuration.train_params, TrainingParameters ) - if not rewrite_parameters[0]: + if not rewrite_parameters: assert asdict(TrainingParameters()) == asdict( optimizer.optimizer_configuration.train_params ) else: - assert asdict(TrainingParameters()) | rewrite_parameters[0] == asdict( + assert asdict(TrainingParameters()) | rewrite_parameters == asdict( optimizer.optimizer_configuration.train_params ) diff --git a/tests/test_target_cortex_a_advisor.py b/tests/test_target_cortex_a_advisor.py index 59d54b5..7bb57c3 100644 --- a/tests/test_target_cortex_a_advisor.py +++ b/tests/test_target_cortex_a_advisor.py @@ -47,7 +47,7 @@ def test_configure_and_get_cortex_a_advisor(test_tflite_model: Path) -> None: }, ] ], - "training_parameters": [None], + "training_parameters": None, }, } diff --git a/tests/test_target_tosa_advisor.py b/tests/test_target_tosa_advisor.py index cc47321..020acc5 100644 --- a/tests/test_target_tosa_advisor.py +++ b/tests/test_target_tosa_advisor.py @@ -47,7 +47,7 @@ def test_configure_and_get_tosa_advisor( }, ] ], - "training_parameters": [None], + "training_parameters": None, }, "tosa_inference_advisor": { "model": str(test_tflite_model), |