aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md20
-rw-r--r--src/mlia/nn/rewrite/core/rewrite.py9
-rw-r--r--src/mlia/nn/rewrite/core/train.py52
-rw-r--r--src/mlia/nn/rewrite/library/clustering.py7
-rw-r--r--src/mlia/nn/rewrite/library/helper_functions.py26
-rw-r--r--src/mlia/nn/rewrite/library/sparsity.py12
-rw-r--r--src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml1
-rw-r--r--src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml1
-rw-r--r--tests/conftest.py32
-rw-r--r--tests/test_nn_rewrite_core_train.py39
-rw-r--r--tests/test_nn_rewrite_library_helper_functions.py44
11 files changed, 231 insertions, 12 deletions
diff --git a/README.md b/README.md
index a0e07f4..c5889ee 100644
--- a/README.md
+++ b/README.md
@@ -229,19 +229,19 @@ There are a number of predefined profiles for rewrites shown below:
| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Num Clusters | Cluster Centroids Init |
| :-------------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :----------: | :--------------------------------: |
-| optimization-fully-connected-clustering | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 32 | "CentroidInitialization.LINEAR" |
+| optimization-fully-connected-clustering | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 16 | "CentroidInitialization.LINEAR" |
| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Sparsity M | Sparsity N |
| :-----------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :--------: | :--------: |
| optimization-fully-connected-pruning | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 2 | 4 |
-| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Num Clusters | Cluster Centroids Init |
-| :-------------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :----------: | :--------------------------------: |
-| optimization-conv2d-clustering | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 32 | "CentroidInitialization.LINEAR" |
+| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Num Clusters | Cluster Centroids Init | Activation |
+| :-------------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :----------: | :--------------------------------: | :--------: |
+| optimization-conv2d-clustering | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 16 | "CentroidInitialization.LINEAR" | "relu" |
-| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Sparsity M | Sparsity N |
-| :-----------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :--------: | :--------: |
-| optimization-conv2d-pruning | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 2 | 4 |
+| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Sparsity M | Sparsity N | Activation |
+| :-----------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :--------: | :--------: | :--------: |
+| optimization-conv2d-pruning | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 2 | 4 | "relu" |
These are summarized below:
@@ -251,6 +251,12 @@ These are summarized below:
* optimization-conv2d-clustering - Provides training parameters for rewrites and cluster specific parameters for the conv2d-clustering rewrite
* optimization-conv2d-pruning - Provides training parameters for rewrites and pruning specific parameters for the conv2d-sparsity rewrite
+Note for convolutional rewrites (e.g. optimization-conv2d-pruning). The activation function for the rewrite can be selected in the optimization profile from the following list:
+
+* "relu" - Standard ReLU activation function
+* "relu6" - ReLU6 activation function i.e. ReLU activation function capped at 6
+* "none" - No activation function
+
The user can also specify custom augmentations as part of the training parameters. An example of this can be found in the following optimization profile:
| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Augmentations - gaussian_strength | Augmentations - mixup_strength |
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index 78fa533..a802c51 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -55,7 +55,7 @@ class Rewrite(ABC):
try:
return self.function(input_shape, output_shape, **kwargs)
except TypeError as ex:
- expected_args = getfullargspec(self.function).args
+ expected_args = self.return_rewrite_func_args()
if "input_shape" in expected_args:
expected_args.remove("input_shape")
if "output_shape" in expected_args:
@@ -72,6 +72,10 @@ class Rewrite(ABC):
"""Return a quantized model if required."""
return model
+ def return_rewrite_func_args(self) -> list[str]:
+ """Return the expected args of the rewrite function."""
+ return getfullargspec(self.function).args
+
@abstractmethod
def training_callbacks(self) -> list:
"""Return rewrite callbacks."""
@@ -304,6 +308,9 @@ class RewritingOptimizer(Optimizer):
output_tensors=[self.optimizer_configuration.layers_to_optimize[1]],
train_params=self.optimizer_configuration.train_params,
rewrite_specific_params=self.optimizer_configuration.rewrite_specific_params, # pylint: disable=line-too-long
+ detect_activation_function=(
+ "activation" in rewrite.return_rewrite_func_args()
+ ),
)
if orig_vs_repl_stats:
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index e99c7e9..570968a 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -34,13 +34,13 @@ from mlia.nn.rewrite.core.graph_edit.record import record_model
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
+from mlia.nn.rewrite.library.helper_functions import ACTIVATION_FUNCTION_LIST
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
from mlia.nn.tensorflow.tflite_graph import load_fb
from mlia.nn.tensorflow.tflite_graph import save_fb
from mlia.utils.logging import log_action
-
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
logger = logging.getLogger(__name__)
@@ -84,6 +84,7 @@ def train( # pylint: disable=too-many-arguments
output_tensors: list,
train_params: TrainingParameters = TrainingParameters(),
rewrite_specific_params: dict | None = None,
+ detect_activation_function: bool = False,
) -> Any:
"""Extract and train a model, and return the results."""
if unmodified_model:
@@ -124,6 +125,7 @@ def train( # pylint: disable=too-many-arguments
is_qat=is_qat,
train_params=train_params,
rewrite_specific_params=rewrite_specific_params,
+ detect_activation_function=detect_activation_function,
)
for i, filename in enumerate(tflite_filenames):
@@ -351,6 +353,41 @@ def set_up_data_pipeline(
return dataset, steps_per_epoch
+def detect_activation_from_rewrite_function(model_path: str) -> str:
+ """Given a rewrite model, choose the most common activation function."""
+ interpreter = tf.lite.Interpreter(model_path=model_path)
+ interpreter.allocate_tensors()
+ act_func_match_list = []
+ for tensor_details in interpreter.get_tensor_details():
+ for act_func in ACTIVATION_FUNCTION_LIST:
+ tensor_name = tensor_details["name"].lower()
+ if act_func in tensor_name:
+ act_func_idx = tensor_name.index(act_func)
+ if (
+ len(tensor_name) == act_func_idx + len(act_func)
+ or tensor_name[act_func_idx + len(act_func)] == ";"
+ ):
+ act_func_match_list.append(
+ tensor_name[
+ act_func_idx : act_func_idx + len(act_func) # noqa: E203
+ ]
+ )
+ act_func_match = "relu"
+ if len(act_func_match_list) == 0:
+ logger.info(
+ "No activation function specified, setting activation function to ReLU"
+ )
+ else:
+ act_func_match = max(set(act_func_match_list), key=act_func_match.count)
+ logger.info(
+ "No activation function specified, "
+ "setting activation function to most "
+ "common activation detected in rewrite graph: %s",
+ act_func_match,
+ )
+ return act_func_match
+
+
def train_in_dir(
train_dir: str,
baseline_dir: Any,
@@ -359,6 +396,7 @@ def train_in_dir(
is_qat: bool,
train_params: TrainingParameters = TrainingParameters(),
rewrite_specific_params: dict | None = None,
+ detect_activation_function: bool = False,
) -> list[str]:
"""Train a replacement for replace.tflite using the input.tfrec \
and output.tfrec in train_dir.
@@ -375,6 +413,18 @@ def train_in_dir(
)
replace = TFLiteModel(ExtractPaths.tflite.replace(train_dir))
+ if detect_activation_function and (
+ rewrite_specific_params is None
+ or "activation" not in list(rewrite_specific_params.keys())
+ ):
+ detected_activation_function = detect_activation_from_rewrite_function(
+ ExtractPaths.tflite.replace(train_dir).as_posix()
+ )
+ if rewrite_specific_params:
+ rewrite_specific_params["activation"] = detected_activation_function
+ else:
+ rewrite_specific_params = {"activation": detected_activation_function}
+
input_name, output_name = _get_io_tensors(teacher)
model_is_quantized = replace.is_tensor_quantized(name=input_name)
diff --git a/src/mlia/nn/rewrite/library/clustering.py b/src/mlia/nn/rewrite/library/clustering.py
index a81d2d4..b159763 100644
--- a/src/mlia/nn/rewrite/library/clustering.py
+++ b/src/mlia/nn/rewrite/library/clustering.py
@@ -7,6 +7,7 @@ import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters
+from mlia.nn.rewrite.library.helper_functions import get_activation_function
def fc_clustering_rewrite(
@@ -42,6 +43,7 @@ def conv2d_clustering_rewrite(
cluster_centroids_init: tfmot.clustering.keras.CentroidInitialization = tfmot.clustering.keras.CentroidInitialization( # pylint: disable=line-too-long
"CentroidInitialization.LINEAR"
),
+ activation: str = "relu",
) -> keras.Model:
"""Conv2d TensorFlow Lite model ready for clustering."""
rewrite_params = {
@@ -51,13 +53,16 @@ def conv2d_clustering_rewrite(
conv2d_parameters = compute_conv2d_parameters(
input_shape=input_shape, output_shape=output_shape
)
+ activation_function, activation_function_extra_args = get_activation_function(
+ activation
+ )
model = tfmot.clustering.keras.cluster_weights(
to_cluster=keras.Sequential(
[
keras.layers.InputLayer(input_shape=input_shape),
keras.layers.Conv2D(**conv2d_parameters),
keras.layers.BatchNormalization(),
- keras.layers.ReLU(),
+ activation_function(**activation_function_extra_args),
]
),
**rewrite_params
diff --git a/src/mlia/nn/rewrite/library/helper_functions.py b/src/mlia/nn/rewrite/library/helper_functions.py
index 4f08170..58d84b1 100644
--- a/src/mlia/nn/rewrite/library/helper_functions.py
+++ b/src/mlia/nn/rewrite/library/helper_functions.py
@@ -5,6 +5,32 @@ import math
from typing import Any
import numpy as np
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
+
+ACTIVATION_FUNCTION_PRESETS = {
+ "relu": {"layer_func": keras.layers.ReLU, "extra_args": {}},
+ "relu6": {"layer_func": keras.layers.ReLU, "extra_args": {"max_value": 6}},
+ "none": {"layer_func": keras.layers.Identity, "extra_args": {}},
+}
+ACTIVATION_FUNCTION_LIST = [
+ act_func for act_func, _ in ACTIVATION_FUNCTION_PRESETS.items()
+]
+
+
+def get_activation_function(
+ activation: str = "relu",
+) -> tuple[type[keras.layers.Layer], dict]:
+ """Get the activation function from a key."""
+ if activation not in ACTIVATION_FUNCTION_LIST:
+ raise KeyError(
+ "Expected activation function to be "
+ f"in {ACTIVATION_FUNCTION_LIST}, found {activation}"
+ )
+ activation_function = ACTIVATION_FUNCTION_PRESETS[activation]["layer_func"]
+ activation_function_extra_args = ACTIVATION_FUNCTION_PRESETS[activation][
+ "extra_args"
+ ]
+ return activation_function, activation_function_extra_args
def compute_conv2d_parameters(
diff --git a/src/mlia/nn/rewrite/library/sparsity.py b/src/mlia/nn/rewrite/library/sparsity.py
index 2342e3d..95f99a7 100644
--- a/src/mlia/nn/rewrite/library/sparsity.py
+++ b/src/mlia/nn/rewrite/library/sparsity.py
@@ -7,6 +7,7 @@ import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters
+from mlia.nn.rewrite.library.helper_functions import get_activation_function
def fc_sparsity_rewrite(
@@ -31,19 +32,26 @@ def fc_sparsity_rewrite(
def conv2d_sparsity_rewrite(
- input_shape: Any, output_shape: Any, sparsity_m: int = 2, sparsity_n: int = 4
+ input_shape: Any,
+ output_shape: Any,
+ sparsity_m: int = 2,
+ sparsity_n: int = 4,
+ activation: str = "relu",
) -> keras.Model:
"""Conv2d TensorFlow Lite model ready for sparse pruning."""
conv2d_parameters = compute_conv2d_parameters(
input_shape=input_shape, output_shape=output_shape
)
+ activation_function, activation_function_extra_args = get_activation_function(
+ activation
+ )
model = tfmot.sparsity.keras.prune_low_magnitude(
to_prune=keras.Sequential(
[
keras.layers.InputLayer(input_shape=input_shape),
keras.layers.Conv2D(**conv2d_parameters),
keras.layers.BatchNormalization(),
- keras.layers.ReLU(),
+ activation_function(**activation_function_extra_args),
]
),
sparsity_m_by_n=(
diff --git a/src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml b/src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml
index 4e09e0f..fe50c31 100644
--- a/src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml
+++ b/src/mlia/resources/optimization_profiles/optimization-conv2d-clustering.toml
@@ -15,3 +15,4 @@ augmentations.mixup_strength = 0.0
[rewrite.conv2d-clustering]
num_clusters = 16
cluster_centroids_init = "CentroidInitialization.LINEAR"
+activation = "relu"
diff --git a/src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml b/src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml
index 87822ed..d0e05a7 100644
--- a/src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml
+++ b/src/mlia/resources/optimization_profiles/optimization-conv2d-pruning.toml
@@ -15,3 +15,4 @@ augmentations.mixup_strength = 0.0
[rewrite.conv2d-sparsity]
sparsity_m = 2
sparsity_n = 4
+activation = "relu"
diff --git a/tests/conftest.py b/tests/conftest.py
index a35ad4d..a64f320 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -126,9 +126,28 @@ def get_test_keras_model() -> keras.Model:
return model
+def get_test_keras_model_no_activation() -> keras.Model:
+ """Return test Keras model."""
+ model = keras.Sequential(
+ [
+ keras.Input(shape=(28, 28, 1), batch_size=1, name="input"),
+ keras.layers.Reshape((28, 28, 1)),
+ keras.layers.Conv2D(filters=12, kernel_size=(3, 3), name="conv1"),
+ keras.layers.Conv2D(filters=12, kernel_size=(3, 3), name="conv2"),
+ keras.layers.MaxPool2D(2, 2),
+ keras.layers.Flatten(),
+ keras.layers.Dense(10, name="output"),
+ ]
+ )
+
+ model.compile(optimizer="sgd", loss="mean_squared_error")
+ return model
+
+
TEST_MODEL_KERAS_FILE = "test_model.h5"
TEST_MODEL_TFLITE_FP32_FILE = "test_model_fp32.tflite"
TEST_MODEL_TFLITE_INT8_FILE = "test_model_int8.tflite"
+TEST_MODEL_TFLITE_NO_ACT_FILE = "test_model_no_act.tflite"
TEST_MODEL_TFLITE_VELA_FILE = "test_model_vela.tflite"
TEST_MODEL_TF_SAVED_MODEL_FILE = "tf_model_test_model"
TEST_MODEL_INVALID_FILE = "invalid.tflite"
@@ -153,6 +172,13 @@ def fixture_test_models_path(
keras_model, quantized=False, output_path=tmp_path / TEST_MODEL_TFLITE_FP32_FILE
)
+ # Un-quantized TensorFlow Lite model with ReLU activation (fp32)
+ convert_to_tflite(
+ get_test_keras_model_no_activation(),
+ quantized=False,
+ output_path=tmp_path / TEST_MODEL_TFLITE_NO_ACT_FILE,
+ )
+
# Quantized TensorFlow Lite model (int8)
tflite_model_path = tmp_path / TEST_MODEL_TFLITE_INT8_FILE
convert_to_tflite(keras_model, quantized=True, output_path=tflite_model_path)
@@ -195,6 +221,12 @@ def fixture_test_tflite_vela_model(test_models_path: Path) -> Path:
return test_models_path / TEST_MODEL_TFLITE_VELA_FILE
+@pytest.fixture(scope="session", name="test_tflite_no_act_model")
+def fixture_test_tflite_no_act_model(test_models_path: Path) -> Path:
+ """Return test TensorFlow Lite model with relu activation."""
+ return test_models_path / TEST_MODEL_TFLITE_NO_ACT_FILE
+
+
@pytest.fixture(scope="session", name="test_tf_model")
def fixture_test_tf_model(test_models_path: Path) -> Path:
"""Return test TensorFlow Lite model."""
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index cabe55f..03b230f 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -16,6 +16,7 @@ from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from mlia.nn.rewrite.core.train import augment_fn_twins
from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS
+from mlia.nn.rewrite.core.train import detect_activation_from_rewrite_function
from mlia.nn.rewrite.core.train import LearningRateSchedule
from mlia.nn.rewrite.core.train import mixup
from mlia.nn.rewrite.core.train import train
@@ -250,3 +251,41 @@ def test_train_checkpoint(
use_unmodified_model=False,
quantized=True,
)
+
+
+def test_detect_activation_from_rewrite_function_no_activation(
+ caplog: pytest.LogCaptureFixture, test_tflite_no_act_model: Path
+) -> None:
+ """
+ Test function detect_activation_from_rewrite_function()
+ with a model with no activation functions.
+ """
+ caplog.set_level(level=20)
+ activation = detect_activation_from_rewrite_function(
+ test_tflite_no_act_model.as_posix()
+ )
+ log_records = caplog.get_records(when="call")
+ logging_messages = [x.message for x in log_records if x.levelno == 20]
+ assert activation == "relu"
+ assert (
+ "No activation function specified, setting activation function to ReLU"
+ in logging_messages
+ )
+
+
+def test_detect_activation_from_rewrite_function_relu_activation(
+ caplog: pytest.LogCaptureFixture, test_tflite_model: Path
+) -> None:
+ """
+ Test function detect_activation_from_rewrite_function()
+ with a model with ReLU activation functions.
+ """
+ caplog.set_level(level=20)
+ activation = detect_activation_from_rewrite_function(test_tflite_model.as_posix())
+ log_records = caplog.get_records(when="call")
+ logging_messages = [x.message for x in log_records if x.levelno == 20]
+ assert activation == "relu"
+ assert (
+ "No activation function specified, setting activation function "
+ "to most common activation detected in rewrite graph: relu" in logging_messages
+ )
diff --git a/tests/test_nn_rewrite_library_helper_functions.py b/tests/test_nn_rewrite_library_helper_functions.py
index c3117f0..9e880aa 100644
--- a/tests/test_nn_rewrite_library_helper_functions.py
+++ b/tests/test_nn_rewrite_library_helper_functions.py
@@ -3,13 +3,16 @@
"""Tests for module mlia.nn.rewrite.library.helper_functions."""
from __future__ import annotations
+from contextlib import ExitStack as does_not_raise
from typing import Any
import numpy as np
import pytest
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
+from mlia.nn.rewrite.library.helper_functions import ACTIVATION_FUNCTION_LIST
from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters
+from mlia.nn.rewrite.library.helper_functions import get_activation_function
def compute_conv_output(
@@ -49,3 +52,44 @@ def test_compute_conv2d_parameters(
np.random.rand(1, *input_shape), input_shape, conv_parameters
)
assert np.equal(computed_output_shape, output_shape).all()
+
+
+@pytest.mark.parametrize(
+ "activation, expected_function_type, expected_extra_args, expected_error",
+ [
+ ("relu", keras.layers.ReLU, {}, does_not_raise()),
+ ("relu6", keras.layers.ReLU, {"max_value": 6}, does_not_raise()),
+ ("none", keras.layers.Identity, {}, does_not_raise()),
+ (
+ "wrong_key",
+ keras.layers.Identity,
+ {},
+ pytest.raises(
+ KeyError,
+ match=(
+ "Expected activation function to be "
+ rf"in \{ACTIVATION_FUNCTION_LIST}\, found wrong_key"
+ ),
+ ),
+ ),
+ ],
+)
+def test_get_activation_functions(
+ activation: str,
+ expected_function_type: type[keras.layers.Layer],
+ expected_extra_args: dict,
+ expected_error: Any,
+) -> None:
+ """
+ Check the get_activation_function returns
+ the expected layer and extra arguments.
+ """
+ with expected_error:
+ activation_function, activation_function_extra_args = get_activation_function(
+ activation
+ )
+ assert isinstance(
+ activation_function(**activation_function_extra_args),
+ expected_function_type,
+ )
+ assert expected_extra_args == activation_function_extra_args