-rw-r--r--  README.md                                                                     | 20
-rw-r--r--  src/mlia/nn/rewrite/core/rewrite.py                                           |  9
-rw-r--r--  src/mlia/nn/rewrite/core/train.py                                             | 52
-rw-r--r--  src/mlia/nn/rewrite/library/clustering.py                                     |  7
-rw-r--r--  src/mlia/nn/rewrite/library/helper_functions.py                               | 26
-rw-r--r--  src/mlia/nn/rewrite/library/sparsity.py                                       | 12
-rw-r--r--  src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml  |  1
-rw-r--r--  src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml     |  1
-rw-r--r--  tests/conftest.py                                                             | 32
-rw-r--r--  tests/test_nn_rewrite_core_train.py                                           | 39
-rw-r--r--  tests/test_nn_rewrite_library_helper_functions.py                             | 44
11 files changed, 231 insertions(+), 12 deletions(-)
diff --git a/README.md b/README.md
--- a/README.md
+++ b/README.md
@@ -229,19 +229,19 @@ There are a number of predefined profiles for rewrites shown below:
 | Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Num Clusters | Cluster Centroids Init |
 | :-------------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :----------: | :--------------------------------: |
-| optimization-fully-connected-clustering | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 32 | "CentroidInitialization.LINEAR" |
+| optimization-fully-connected-clustering | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 16 | "CentroidInitialization.LINEAR" |
 
 | Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Sparsity M | Sparsity N |
 | :-----------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :--------: | :--------: |
 | optimization-fully-connected-pruning | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 2 | 4 |
 
-| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Num Clusters | Cluster Centroids Init |
-| :-------------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :----------: | :--------------------------------: |
-| optimization-conv2d-clustering | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 32 | "CentroidInitialization.LINEAR" |
+| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Num Clusters | Cluster Centroids Init | Activation |
+| :-------------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :----------: | :--------------------------------: | :--------: |
+| optimization-conv2d-clustering | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 16 | "CentroidInitialization.LINEAR" | "relu" |
 
-| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Sparsity M | Sparsity N |
-| :-----------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :--------: | :--------: |
-| optimization-conv2d-pruning | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 2 | 4 |
+| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Sparsity M | Sparsity N | Activation |
+| :-----------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :--------: | :--------: | :--------: |
+| optimization-conv2d-pruning | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 2 | 4 | "relu" |
 
 These are summarized below:
@@ -251,6 +251,12 @@ These are summarized below:
 * optimization-conv2d-clustering - Provides training parameters for rewrites and cluster specific parameters for the conv2d-clustering rewrite
 * optimization-conv2d-pruning - Provides training parameters for rewrites and pruning specific parameters for the conv2d-sparsity rewrite
 
+Note for convolutional rewrites (e.g. optimization-conv2d-pruning): the activation function for the rewrite can be selected in the optimization profile from the following list:
+
+* "relu" - Standard ReLU activation function
+* "relu6" - ReLU6 activation function, i.e. the ReLU activation function capped at 6
+* "none" - No activation function
+
 The user can also specify custom augmentations as part of the training parameters. An example of this can be found in the following optimization profile:
 
 | Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Augmentations - gaussian_strength | Augmentations - mixup_strength |
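With this change, a custom optimization profile can pick the activation for the conv2d rewrites via the new `activation` key. A minimal sketch of the rewrite-specific section, mirroring the shipped optimization-conv2d-pruning.toml (the file name my-conv2d-pruning.toml is hypothetical):

```toml
# Hypothetical custom profile excerpt (my-conv2d-pruning.toml);
# the training parameters above this section are omitted.
# "relu6" maps to keras.layers.ReLU(max_value=6) in the rewrite library.
[rewrite.conv2d-sparsity]
sparsity_m = 2
sparsity_n = 4
activation = "relu6"
```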
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index 78fa533..a802c51 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -55,7 +55,7 @@ class Rewrite(ABC):
         try:
             return self.function(input_shape, output_shape, **kwargs)
         except TypeError as ex:
-            expected_args = getfullargspec(self.function).args
+            expected_args = self.return_rewrite_func_args()
             if "input_shape" in expected_args:
                 expected_args.remove("input_shape")
             if "output_shape" in expected_args:
@@ -72,6 +72,10 @@ class Rewrite(ABC):
         """Return a quantized model if required."""
         return model
 
+    def return_rewrite_func_args(self) -> list[str]:
+        """Return the expected args of the rewrite function."""
+        return getfullargspec(self.function).args
+
     @abstractmethod
     def training_callbacks(self) -> list:
         """Return rewrite callbacks."""
@@ -304,6 +308,9 @@ class RewritingOptimizer(Optimizer):
             output_tensors=[self.optimizer_configuration.layers_to_optimize[1]],
             train_params=self.optimizer_configuration.train_params,
             rewrite_specific_params=self.optimizer_configuration.rewrite_specific_params,  # pylint: disable=line-too-long
+            detect_activation_function=(
+                "activation" in rewrite.return_rewrite_func_args()
+            ),
         )
 
         if orig_vs_repl_stats:
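The `detect_activation_function` flag above is derived purely from the rewrite function's signature. A minimal sketch of the introspection involved (the function `conv2d_rewrite_stub` is hypothetical; only its signature matters):

```python
from inspect import getfullargspec


# Hypothetical stand-in for a conv2d rewrite function.
def conv2d_rewrite_stub(input_shape, output_shape, activation="relu"):
    """Stub; a real rewrite would return a keras.Model."""


# Mirrors Rewrite.return_rewrite_func_args(): list the declared argument names.
expected_args = getfullargspec(conv2d_rewrite_stub).args
print(expected_args)  # ['input_shape', 'output_shape', 'activation']
print("activation" in expected_args)  # True -> detect_activation_function=True
```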
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index e99c7e9..570968a 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -34,13 +34,13 @@ from mlia.nn.rewrite.core.graph_edit.record import record_model
 from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
 from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
 from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
+from mlia.nn.rewrite.library.helper_functions import ACTIVATION_FUNCTION_LIST
 from mlia.nn.tensorflow.config import TFLiteModel
 from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
 from mlia.nn.tensorflow.tflite_graph import load_fb
 from mlia.nn.tensorflow.tflite_graph import save_fb
 from mlia.utils.logging import log_action
 
-
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
 logger = logging.getLogger(__name__)
@@ -84,6 +84,7 @@ def train(  # pylint: disable=too-many-arguments
     output_tensors: list,
     train_params: TrainingParameters = TrainingParameters(),
     rewrite_specific_params: dict | None = None,
+    detect_activation_function: bool = False,
 ) -> Any:
     """Extract and train a model, and return the results."""
     if unmodified_model:
@@ -124,6 +125,7 @@
         is_qat=is_qat,
         train_params=train_params,
         rewrite_specific_params=rewrite_specific_params,
+        detect_activation_function=detect_activation_function,
     )
 
     for i, filename in enumerate(tflite_filenames):
@@ -351,6 +353,41 @@ def set_up_data_pipeline(
     return dataset, steps_per_epoch
 
 
+def detect_activation_from_rewrite_function(model_path: str) -> str:
+    """Given a rewrite model, choose the most common activation function."""
+    interpreter = tf.lite.Interpreter(model_path=model_path)
+    interpreter.allocate_tensors()
+    act_func_match_list = []
+    for tensor_details in interpreter.get_tensor_details():
+        for act_func in ACTIVATION_FUNCTION_LIST:
+            tensor_name = tensor_details["name"].lower()
+            if act_func in tensor_name:
+                act_func_idx = tensor_name.index(act_func)
+                if (
+                    len(tensor_name) == act_func_idx + len(act_func)
+                    or tensor_name[act_func_idx + len(act_func)] == ";"
+                ):
+                    act_func_match_list.append(
+                        tensor_name[
+                            act_func_idx : act_func_idx + len(act_func)  # noqa: E203
+                        ]
+                    )
+    act_func_match = "relu"
+    if len(act_func_match_list) == 0:
+        logger.info(
+            "No activation function specified, setting activation function to ReLU"
+        )
+    else:
+        act_func_match = max(set(act_func_match_list), key=act_func_match_list.count)
+        logger.info(
+            "No activation function specified, "
+            "setting activation function to most "
+            "common activation detected in rewrite graph: %s",
+            act_func_match,
+        )
+    return act_func_match
+
+
 def train_in_dir(
     train_dir: str,
     baseline_dir: Any,
@@ -359,6 +396,7 @@
     is_qat: bool,
     train_params: TrainingParameters = TrainingParameters(),
     rewrite_specific_params: dict | None = None,
+    detect_activation_function: bool = False,
 ) -> list[str]:
     """Train a replacement for replace.tflite using the input.tfrec \
     and output.tfrec in train_dir.
@@ -375,6 +413,18 @@
     )
     replace = TFLiteModel(ExtractPaths.tflite.replace(train_dir))
 
+    if detect_activation_function and (
+        rewrite_specific_params is None
+        or "activation" not in list(rewrite_specific_params.keys())
+    ):
+        detected_activation_function = detect_activation_from_rewrite_function(
+            ExtractPaths.tflite.replace(train_dir).as_posix()
+        )
+        if rewrite_specific_params:
+            rewrite_specific_params["activation"] = detected_activation_function
+        else:
+            rewrite_specific_params = {"activation": detected_activation_function}
+
     input_name, output_name = _get_io_tensors(teacher)
 
     model_is_quantized = replace.is_tensor_quantized(name=input_name)
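The detection above relies on TFLite fused-activation naming, where the activation either terminates the tensor name or is followed by a semicolon-separated list of fused op names. A standalone sketch of the same matching rule (the tensor names below are illustrative, not taken from a real model):

```python
ACTIVATION_FUNCTION_LIST = ["relu", "relu6", "none"]


def match_activations(tensor_names: list[str]) -> list[str]:
    """Collect activation names that end a tensor name or precede a ';'."""
    matches = []
    for name in (n.lower() for n in tensor_names):
        for act in ACTIVATION_FUNCTION_LIST:
            idx = name.find(act)
            if idx == -1:
                continue
            end = idx + len(act)
            # Match a whole name component only: end of string or ';' separator.
            if end == len(name) or name[end] == ";":
                matches.append(act)
    return matches


# Illustrative names in the style TFLite uses for fused operations.
names = [
    "sequential/conv1/Relu;sequential/conv1/BiasAdd",
    "sequential/conv2/Relu",
]
hits = match_activations(names)
print(max(set(hits), key=hits.count))  # -> relu
```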
diff --git a/src/mlia/nn/rewrite/library/clustering.py b/src/mlia/nn/rewrite/library/clustering.py
index a81d2d4..b159763 100644
--- a/src/mlia/nn/rewrite/library/clustering.py
+++ b/src/mlia/nn/rewrite/library/clustering.py
@@ -7,6 +7,7 @@ import tensorflow_model_optimization as tfmot
 from keras.api._v2 import keras  # Temporary workaround for now: MLIA-1107
 
 from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters
+from mlia.nn.rewrite.library.helper_functions import get_activation_function
 
 
 def fc_clustering_rewrite(
@@ -42,6 +43,7 @@ def conv2d_clustering_rewrite(
     cluster_centroids_init: tfmot.clustering.keras.CentroidInitialization = tfmot.clustering.keras.CentroidInitialization(  # pylint: disable=line-too-long
         "CentroidInitialization.LINEAR"
     ),
+    activation: str = "relu",
 ) -> keras.Model:
     """Conv2d TensorFlow Lite model ready for clustering."""
     rewrite_params = {
@@ -51,13 +53,16 @@ def conv2d_clustering_rewrite(
     conv2d_parameters = compute_conv2d_parameters(
         input_shape=input_shape, output_shape=output_shape
     )
+    activation_function, activation_function_extra_args = get_activation_function(
+        activation
+    )
     model = tfmot.clustering.keras.cluster_weights(
         to_cluster=keras.Sequential(
             [
                 keras.layers.InputLayer(input_shape=input_shape),
                 keras.layers.Conv2D(**conv2d_parameters),
                 keras.layers.BatchNormalization(),
-                keras.layers.ReLU(),
+                activation_function(**activation_function_extra_args),
             ]
         ),
         **rewrite_params
diff --git a/src/mlia/nn/rewrite/library/helper_functions.py b/src/mlia/nn/rewrite/library/helper_functions.py
index 4f08170..58d84b1 100644
--- a/src/mlia/nn/rewrite/library/helper_functions.py
+++ b/src/mlia/nn/rewrite/library/helper_functions.py
@@ -5,6 +5,32 @@ import math
 from typing import Any
 
 import numpy as np
+from keras.api._v2 import keras  # Temporary workaround for now: MLIA-1107
+
+ACTIVATION_FUNCTION_PRESETS = {
+    "relu": {"layer_func": keras.layers.ReLU, "extra_args": {}},
+    "relu6": {"layer_func": keras.layers.ReLU, "extra_args": {"max_value": 6}},
+    "none": {"layer_func": keras.layers.Identity, "extra_args": {}},
+}
+ACTIVATION_FUNCTION_LIST = [
+    act_func for act_func, _ in ACTIVATION_FUNCTION_PRESETS.items()
+]
+
+
+def get_activation_function(
+    activation: str = "relu",
+) -> tuple[type[keras.layers.Layer], dict]:
+    """Get the activation function from a key."""
+    if activation not in ACTIVATION_FUNCTION_LIST:
+        raise KeyError(
+            "Expected activation function to be "
+            f"in {ACTIVATION_FUNCTION_LIST}, found {activation}"
+        )
+    activation_function = ACTIVATION_FUNCTION_PRESETS[activation]["layer_func"]
+    activation_function_extra_args = ACTIVATION_FUNCTION_PRESETS[activation][
+        "extra_args"
+    ]
+    return activation_function, activation_function_extra_args
 
 
 def compute_conv2d_parameters(
diff --git a/src/mlia/nn/rewrite/library/sparsity.py b/src/mlia/nn/rewrite/library/sparsity.py
index 2342e3d..95f99a7 100644
--- a/src/mlia/nn/rewrite/library/sparsity.py
+++ b/src/mlia/nn/rewrite/library/sparsity.py
@@ -7,6 +7,7 @@ import tensorflow_model_optimization as tfmot
 from keras.api._v2 import keras  # Temporary workaround for now: MLIA-1107
 
 from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters
+from mlia.nn.rewrite.library.helper_functions import get_activation_function
 
 
 def fc_sparsity_rewrite(
@@ -31,19 +32,26 @@ def fc_sparsity_rewrite(
 
 
 def conv2d_sparsity_rewrite(
-    input_shape: Any, output_shape: Any, sparsity_m: int = 2, sparsity_n: int = 4
+    input_shape: Any,
+    output_shape: Any,
+    sparsity_m: int = 2,
+    sparsity_n: int = 4,
+    activation: str = "relu",
 ) -> keras.Model:
     """Conv2d TensorFlow Lite model ready for sparse pruning."""
     conv2d_parameters = compute_conv2d_parameters(
         input_shape=input_shape, output_shape=output_shape
     )
+    activation_function, activation_function_extra_args = get_activation_function(
+        activation
+    )
     model = tfmot.sparsity.keras.prune_low_magnitude(
         to_prune=keras.Sequential(
             [
                 keras.layers.InputLayer(input_shape=input_shape),
                 keras.layers.Conv2D(**conv2d_parameters),
                 keras.layers.BatchNormalization(),
-                keras.layers.ReLU(),
+                activation_function(**activation_function_extra_args),
             ]
         ),
         sparsity_m_by_n=(
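Note that `get_activation_function` returns a layer class plus its keyword arguments rather than a constructed layer, so each rewrite instantiates a fresh layer instead of sharing one instance across models. A minimal usage sketch:

```python
from keras.api._v2 import keras  # Same workaround import as the library modules

from mlia.nn.rewrite.library.helper_functions import get_activation_function

layer_cls, extra_args = get_activation_function("relu6")
relu6_layer = layer_cls(**extra_args)  # keras.layers.ReLU(max_value=6)

layer_cls, extra_args = get_activation_function("none")
identity_layer = layer_cls(**extra_args)  # keras.layers.Identity(), a passthrough

assert isinstance(relu6_layer, keras.layers.ReLU)
```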
diff --git a/src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml b/src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml
index 4e09e0f..fe50c31 100644
--- a/src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml
+++ b/src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml
@@ -15,3 +15,4 @@ augmentations.mixup_strength = 0.0
 [rewrite.conv2d-clustering]
 num_clusters = 16
 cluster_centroids_init = "CentroidInitialization.LINEAR"
+activation = "relu"
diff --git a/src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml b/src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml
index 87822ed..d0e05a7 100644
--- a/src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml
+++ b/src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml
@@ -15,3 +15,4 @@ augmentations.mixup_strength = 0.0
 [rewrite.conv2d-sparsity]
 sparsity_m = 2
 sparsity_n = 4
+activation = "relu"
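The profile value travels through `rewrite_specific_params` into the library function, so calling the rewrite directly is equivalent. A sketch under the assumption of a 32x32x3 input feature map (the shapes are illustrative; in MLIA they come from the extracted subgraph):

```python
from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_rewrite

# Illustrative shapes; conv2d kernel/stride are derived by
# compute_conv2d_parameters() from the input/output pair.
model = conv2d_sparsity_rewrite(
    input_shape=(32, 32, 3),
    output_shape=(16, 16, 16),
    sparsity_m=2,
    sparsity_n=4,
    activation="relu6",  # overriding the "relu" default set in the profile
)
model.summary()
```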
+ """ + caplog.set_level(level=20) + activation = detect_activation_from_rewrite_function( + test_tflite_no_act_model.as_posix() + ) + log_records = caplog.get_records(when="call") + logging_messages = [x.message for x in log_records if x.levelno == 20] + assert activation == "relu" + assert ( + "No activation function specified, setting activation function to ReLU" + in logging_messages + ) + + +def test_detect_activation_from_rewrite_function_relu_activation( + caplog: pytest.LogCaptureFixture, test_tflite_model: Path +) -> None: + """ + Test function detect_activation_from_rewrite_function() + with a model with ReLU activation functions. + """ + caplog.set_level(level=20) + activation = detect_activation_from_rewrite_function(test_tflite_model.as_posix()) + log_records = caplog.get_records(when="call") + logging_messages = [x.message for x in log_records if x.levelno == 20] + assert activation == "relu" + assert ( + "No activation function specified, setting activation function " + "to most common activation detected in rewrite graph: relu" in logging_messages + ) diff --git a/tests/test_nn_rewrite_library_helper_functions.py b/tests/test_nn_rewrite_library_helper_functions.py index c3117f0..9e880aa 100644 --- a/tests/test_nn_rewrite_library_helper_functions.py +++ b/tests/test_nn_rewrite_library_helper_functions.py @@ -3,13 +3,16 @@ """Tests for module mlia.nn.rewrite.library.helper_functions.""" from __future__ import annotations +from contextlib import ExitStack as does_not_raise from typing import Any import numpy as np import pytest from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 +from mlia.nn.rewrite.library.helper_functions import ACTIVATION_FUNCTION_LIST from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters +from mlia.nn.rewrite.library.helper_functions import get_activation_function def compute_conv_output( @@ -49,3 +52,44 @@ def test_compute_conv2d_parameters( np.random.rand(1, *input_shape), input_shape, conv_parameters ) assert np.equal(computed_output_shape, output_shape).all() + + +@pytest.mark.parametrize( + "activation, expected_function_type, expected_extra_args, expected_error", + [ + ("relu", keras.layers.ReLU, {}, does_not_raise()), + ("relu6", keras.layers.ReLU, {"max_value": 6}, does_not_raise()), + ("none", keras.layers.Identity, {}, does_not_raise()), + ( + "wrong_key", + keras.layers.Identity, + {}, + pytest.raises( + KeyError, + match=( + "Expected activation function to be " + rf"in \{ACTIVATION_FUNCTION_LIST}\, found wrong_key" + ), + ), + ), + ], +) +def test_get_activation_functions( + activation: str, + expected_function_type: type[keras.layers.Layer], + expected_extra_args: dict, + expected_error: Any, +) -> None: + """ + Check the get_activation_function returns + the expected layer and extra arguments. + """ + with expected_error: + activation_function, activation_function_extra_args = get_activation_function( + activation + ) + assert isinstance( + activation_function(**activation_function_extra_args), + expected_function_type, + ) + assert expected_extra_args == activation_function_extra_args |