aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core')
-rw-r--r--src/mlia/nn/rewrite/core/rewrite.py251
-rw-r--r--src/mlia/nn/rewrite/core/train.py120
2 files changed, 301 insertions, 70 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index e2c097c..a6bd306 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -8,13 +8,20 @@ import tempfile
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
+from inspect import getfullargspec
from pathlib import Path
+from statistics import fmean
from typing import Any
from typing import Callable
+from typing import Generator
import numpy as np
+import tensorflow as tf
import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
+from tensorflow_model_optimization.python.core.sparsity.keras.pruning_utils import ( # pylint: disable=no-name-in-module
+ is_pruned_m_by_n,
+)
from mlia.core.errors import ConfigurationError
from mlia.core.reporting import Column
@@ -24,33 +31,57 @@ from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
from mlia.nn.rewrite.core.train import train
from mlia.nn.rewrite.core.train import TrainingParameters
-from mlia.nn.rewrite.library.fc_clustering_layer import (
- get_keras_model_clus as fc_clustering_rewrite,
-)
-from mlia.nn.rewrite.library.fc_layer import get_keras_model as fc_rewrite
-from mlia.nn.rewrite.library.fc_sparsity24_layer import (
- get_keras_model as fc_rewrite_sparsity24,
-)
+from mlia.nn.rewrite.library.clustering import conv2d_clustering_rewrite
+from mlia.nn.rewrite.library.clustering import fc_clustering_rewrite
+from mlia.nn.rewrite.library.layers import conv2d_rewrite
+from mlia.nn.rewrite.library.layers import fc_rewrite
+from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_rewrite
+from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_unstructured_rewrite
+from mlia.nn.rewrite.library.sparsity import fc_sparsity_rewrite
+from mlia.nn.rewrite.library.sparsity import fc_sparsity_unstructured_rewrite
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.utils.registry import Registry
-
logger = logging.getLogger(__name__)
-RewriteCallable = Callable[[Any, Any], keras.Model]
+RewriteCallable = Callable[..., keras.Model]
class Rewrite(ABC):
"""Abstract class for rewrite logic to be used by RewritingOptimizer."""
- def __init__(self, name: str, rewrite_fn: RewriteCallable):
+ def __init__(
+ self,
+ name: str,
+ rewrite_fn: RewriteCallable,
+ rewrite_fn_extra_args: dict[str, Any] | None = None,
+ ):
"""Initialize a Rewrite instance with a given name and an optional function."""
self.name = name
self.function = rewrite_fn
+ self.function_extra_args = rewrite_fn_extra_args
- def __call__(self, input_shape: Any, output_shape: Any) -> keras.Model:
+ def __call__(
+ self, input_shape: Any, output_shape: Any, **kwargs: Any
+ ) -> keras.Model:
"""Perform the rewrite operation using the configured function."""
try:
- return self.function(input_shape, output_shape)
+ if self.function_extra_args:
+ return self.function(
+ input_shape, output_shape, **kwargs, **self.function_extra_args
+ )
+
+ return self.function(input_shape, output_shape, **kwargs)
+ except TypeError as ex:
+ expected_args = self.return_rewrite_func_args()
+ if "input_shape" in expected_args:
+ expected_args.remove("input_shape")
+ if "output_shape" in expected_args:
+ expected_args.remove("output_shape")
+ raise KeyError(
+ f"Found unexpected parameters for rewrite. Expected (sub)set "
+ f"of {expected_args} found unexpected parameter(s) "
+ f"{list(set(list(kwargs.keys())) - set(expected_args))}"
+ ) from ex
except Exception as ex:
raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex
@@ -58,21 +89,25 @@ class Rewrite(ABC):
"""Return a quantized model if required."""
return model
+ def return_rewrite_func_args(self) -> list[str]:
+ """Return the expected args of the rewrite function."""
+ return getfullargspec(self.function).args
+
@abstractmethod
def training_callbacks(self) -> list:
- """Return default rewrite callbacks."""
+ """Return rewrite callbacks."""
@abstractmethod
def post_process(self, model: keras.Model) -> keras.Model:
- """Return default post-processing rewrite options."""
+ """Return post-processing rewrite option."""
@abstractmethod
- def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool:
+ def check_optimization(self, model: keras.Model) -> bool:
"""Check if the optimization has produced the correct result."""
class GenericRewrite(Rewrite):
- """Graph rewrite logic for fully-connected rewrite."""
+ """Rewrite class for generic rewrites e.g. fully-connected."""
def quantize(self, model: keras.Model) -> keras.Model:
"""Return a quantized model if required."""
@@ -83,10 +118,10 @@ class GenericRewrite(Rewrite):
return []
def post_process(self, model: keras.Model) -> keras.Model:
- """Return default post-processing rewrite options."""
+ """Return default post-processing rewrite option."""
return model
- def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+ def check_optimization(self, model: keras.Model, **_: Any) -> bool:
"""Not needed here."""
return True
@@ -99,16 +134,31 @@ class QuantizeAwareTrainingRewrite(Rewrite, ABC):
"""Apply optimization-aware quantization to a given model."""
return model
+ def check_optimization_generator(
+ self, model: keras.Model
+ ) -> Generator[tuple[tf.Tensor, keras.layers.Layer], None, None]:
+ """Loop for check_optimization function."""
+ for layer in model.layers:
+ for weight in layer.weights:
+ if "kernel" in weight.name:
+ if (
+ "kernel_min" in weight.name
+ or "kernel_max" in weight.name
+ or "depthwise" in weight.name
+ ):
+ continue
+ yield weight, layer
+
-class Sparsity24Rewrite(QuantizeAwareTrainingRewrite):
- """Graph rewrite logic for fully-connected-sparsity24 rewrite."""
+class SparsityRewrite(QuantizeAwareTrainingRewrite):
+ """Base rewrite class for sparsity rewrites."""
pruning_callback = tfmot.sparsity.keras.UpdatePruningStep
strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning)
def quantize(self, model: keras.Model) -> keras.Model:
- """Skip quantization when using pruning rewrite."""
+ """Skip quantization when using sparsity rewrite."""
return model
def training_callbacks(self) -> list:
@@ -116,7 +166,7 @@ class Sparsity24Rewrite(QuantizeAwareTrainingRewrite):
return [self.pruning_callback()]
def post_process(self, model: keras.Model) -> keras.Model:
- """Pruning-specific post-processing rewrite options."""
+ """Pruning-specific post-processing rewrite option."""
return self.strip_pruning_wrapper(model)
def preserved_quantize(
@@ -129,16 +179,78 @@ class Sparsity24Rewrite(QuantizeAwareTrainingRewrite):
model,
tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(),
)
-
return model
- def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+ def check_optimization(self, model: keras.Model) -> bool:
+ """Not needed here."""
+ return True
+
+
+class UnstructuredSparsityRewrite(SparsityRewrite):
+ """
+ Rewrite class for unstructured sparsity rewrite.
+
+ e.g. fully-connected-unstructured-sparsity.
+ """
+
+ def check_optimization(
+ self, model: keras.Model, final_sparsity: float = 0.5, **_: Any
+ ) -> bool:
"""Not needed here."""
+ found_sparsity_list = []
+ num_dec_places = str(final_sparsity)[::-1].find(".")
+ for weight, _ in self.check_optimization_generator(model=model):
+ weight_np = weight.numpy()
+ found_sparsity_list.append(
+ round(np.count_nonzero(weight_np) / weight_np.size, num_dec_places)
+ )
+ if len(found_sparsity_list) == 0:
+ logger.warning(
+ "\nWARNING: Could not find any layers "
+ "in rewrite that could be sparsely pruned"
+ )
+ return False
+ found_sparsity = fmean(found_sparsity_list)
+ if found_sparsity != final_sparsity:
+ logger.warning(
+ "\nWARNING: Found total sparsity of "
+ "rewrite model: %.2f "
+ "expected total sparsity to be: "
+ "%.2f\n",
+ found_sparsity,
+ final_sparsity,
+ )
+ return False
+ return True
+
+
+class StructuredSparsityRewrite(SparsityRewrite):
+ """Rewrite class for structured sparsity rewrite e.g. fully-connected-sparsity."""
+
+ def check_optimization(
+ self,
+ model: keras.Model,
+ sparsity_m: int = 2,
+ sparsity_n: int = 4,
+ **_: Any,
+ ) -> bool:
+ """Check if sparity has produced the correct result."""
+ for weight, layer in self.check_optimization_generator(model=model):
+ if not is_pruned_m_by_n(weight, m_by_n=(sparsity_m, sparsity_n)):
+ logger.warning(
+ "\nWARNING: Could not find (%d, %d) sparsity, "
+ "in layer %s for weight %s \n",
+ sparsity_m,
+ sparsity_n,
+ layer.name,
+ weight.name,
+ )
+ return False
return True
class ClusteringRewrite(QuantizeAwareTrainingRewrite):
- """Graph clustering rewrite logic to be used by RewritingOptimizer."""
+ """Rewrite class for clustering rewrite e.g. fully-connected-clustering."""
_strip_clustering_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering)
@@ -151,32 +263,22 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite):
)
return cqat_model
- def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+ def check_optimization(
+ self, model: keras.Model, num_clusters: int = 2, **_: Any
+ ) -> bool:
"""Check if clustering has produced the correct result."""
- number_of_clusters = kwargs.get("number_of_clusters")
- if not number_of_clusters:
- raise ValueError(
- """
- Expected check_preserved_quantize to have argument number_of_clusters.
- """
- )
-
- for layer in model.layers:
- for weight in layer.weights:
- if "kernel" in weight.name:
- if "kernel_min" in weight.name or "kernel_max" in weight.name:
- continue
- number_of_found_clusters = len(np.unique(weight))
- if number_of_found_clusters != number_of_clusters:
- logger.warning(
- "\nWARNING: Expected %d cluster(s), found %d "
- "cluster(s) in layer %s for weight %s \n",
- number_of_clusters,
- number_of_found_clusters,
- layer.name,
- weight.name,
- )
- return False
+ for weight, layer in self.check_optimization_generator(model=model):
+ number_of_found_clusters = len(np.unique(weight))
+ if number_of_found_clusters != num_clusters:
+ logger.warning(
+ "\nWARNING: Expected %d cluster(s), found %d "
+ "cluster(s) in layer %s for weight %s \n",
+ num_clusters,
+ number_of_found_clusters,
+ layer.name,
+ weight.name,
+ )
+ return False
return True
def training_callbacks(self) -> list:
@@ -184,7 +286,7 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite):
return []
def post_process(self, model: keras.Model) -> keras.Model:
- """Return the clustering stripped model."""
+ """Clustering-specific post-processing rewrite option."""
return self._strip_clustering_wrapper(model)
@@ -215,6 +317,7 @@ class RewriteConfiguration(OptimizerConfiguration):
layers_to_optimize: list[str] | None = None
dataset: Path | None = None
train_params: TrainingParameters = TrainingParameters()
+ rewrite_specific_params: dict | None = None
def __str__(self) -> str:
"""Return string representation of the configuration."""
@@ -227,8 +330,38 @@ class RewritingOptimizer(Optimizer):
registry = RewriteRegistry(
[
GenericRewrite("fully-connected", fc_rewrite),
- Sparsity24Rewrite("fully-connected-sparsity24", fc_rewrite_sparsity24),
+ GenericRewrite("conv2d", conv2d_rewrite),
+ GenericRewrite(
+ "depthwise-separable-conv2d",
+ conv2d_rewrite,
+ {"layer_type": keras.layers.SeparableConv2D},
+ ),
+ StructuredSparsityRewrite("fully-connected-sparsity", fc_sparsity_rewrite),
ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite),
+ ClusteringRewrite("conv2d-clustering", conv2d_clustering_rewrite),
+ StructuredSparsityRewrite("conv2d-sparsity", conv2d_sparsity_rewrite),
+ UnstructuredSparsityRewrite(
+ "conv2d-unstructured-sparsity", conv2d_sparsity_unstructured_rewrite
+ ),
+ UnstructuredSparsityRewrite(
+ "fully-connected-unstructured-sparsity",
+ fc_sparsity_unstructured_rewrite,
+ ),
+ ClusteringRewrite(
+ "depthwise-separable-conv2d-clustering",
+ conv2d_clustering_rewrite,
+ {"layer_type": keras.layers.SeparableConv2D},
+ ),
+ StructuredSparsityRewrite(
+ "depthwise-separable-conv2d-sparsity",
+ conv2d_sparsity_rewrite,
+ {"layer_type": keras.layers.SeparableConv2D},
+ ),
+ UnstructuredSparsityRewrite(
+ "depthwise-separable-conv2d-unstructured-sparsity",
+ conv2d_sparsity_unstructured_rewrite,
+ {"layer_type": keras.layers.SeparableConv2D},
+ ),
]
)
@@ -250,11 +383,13 @@ class RewritingOptimizer(Optimizer):
rewrite = RewritingOptimizer.registry.items[
self.optimizer_configuration.optimization_target
]
-
use_unmodified_model = True
tflite_model = self.model.model_path
- tfrecord = str(self.optimizer_configuration.dataset)
-
+ tfrecord = (
+ str(self.optimizer_configuration.dataset)
+ if self.optimizer_configuration.dataset
+ else None
+ )
tmp_dir = tempfile.mkdtemp()
tmp_output = Path(tmp_dir, "output.tflite")
@@ -266,12 +401,16 @@ class RewritingOptimizer(Optimizer):
source_model=tflite_model,
unmodified_model=tflite_model if use_unmodified_model else None,
output_model=str(tmp_output),
- input_tfrec=str(tfrecord),
+ input_tfrec=tfrecord,
rewrite=rewrite,
is_qat=isinstance(rewrite, QuantizeAwareTrainingRewrite),
input_tensors=[self.optimizer_configuration.layers_to_optimize[0]],
output_tensors=[self.optimizer_configuration.layers_to_optimize[1]],
train_params=self.optimizer_configuration.train_params,
+ rewrite_specific_params=self.optimizer_configuration.rewrite_specific_params, # pylint: disable=line-too-long
+ detect_activation_function=(
+ "activation" in rewrite.return_rewrite_func_args()
+ ),
)
if orig_vs_repl_stats:
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 4204978..b20430b 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -33,14 +33,15 @@ from mlia.nn.rewrite.core.graph_edit.join import join_models
from mlia.nn.rewrite.core.graph_edit.record import record_model
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFWriter
from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
+from mlia.nn.rewrite.library.helper_functions import ACTIVATION_FUNCTION_LIST
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
from mlia.nn.tensorflow.tflite_graph import load_fb
from mlia.nn.tensorflow.tflite_graph import save_fb
from mlia.utils.logging import log_action
-
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
logger = logging.getLogger(__name__)
@@ -73,18 +74,48 @@ class TrainingParameters:
checkpoint_at: list | None = None
+def generate_random_dataset(source_model: str, dataset_path: str) -> str:
+ """Generate random dataset for model."""
+ model = TFLiteModel(model_path=source_model)
+ input_name = model.input_tensors()[0]
+ model_is_quantized = model.is_tensor_quantized(name=input_name)
+ input_shape = model.shape_from_name[input_name][1:]
+ rand_data_path = dataset_path + "/rand_data.tfrec"
+ with NumpyTFWriter(rand_data_path) as writer:
+ for _ in range(5000):
+ input_data = np.random.rand(1, *input_shape)
+ input_data = (
+ input_data.astype(np.int8)
+ if model_is_quantized
+ else input_data.astype(np.float32)
+ )
+ writer.write({input_name: input_data})
+ return rand_data_path
+
+
def train( # pylint: disable=too-many-arguments
source_model: str,
unmodified_model: Any,
output_model: str,
- input_tfrec: str,
+ input_tfrec: str | None,
rewrite: Callable,
is_qat: bool,
input_tensors: list,
output_tensors: list,
train_params: TrainingParameters = TrainingParameters(),
+ rewrite_specific_params: dict | None = None,
+ detect_activation_function: bool = False,
) -> Any:
"""Extract and train a model, and return the results."""
+ rand_data_dir_path = None
+ if not input_tfrec:
+ logger.info(
+ "INFO: No dataset given, using random data to perform the rewrite! "
+ )
+ rand_data_dir_path = (
+ tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
+ )
+ input_tfrec = generate_random_dataset(source_model, rand_data_dir_path.name)
if unmodified_model:
unmodified_model_dir = (
tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
@@ -122,6 +153,8 @@ def train( # pylint: disable=too-many-arguments
rewrite=rewrite,
is_qat=is_qat,
train_params=train_params,
+ rewrite_specific_params=rewrite_specific_params,
+ detect_activation_function=detect_activation_function,
)
for i, filename in enumerate(tflite_filenames):
@@ -163,7 +196,8 @@ def train( # pylint: disable=too-many-arguments
if unmodified_model_dir:
cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup()
-
+ if rand_data_dir_path:
+ cast(tempfile.TemporaryDirectory, rand_data_dir_path).cleanup()
return results, [
mae,
nrmse,
@@ -349,6 +383,41 @@ def set_up_data_pipeline(
return dataset, steps_per_epoch
+def detect_activation_from_rewrite_function(model_path: str) -> str:
+ """Given a rewrite model, choose the most common activation function."""
+ interpreter = tf.lite.Interpreter(model_path=model_path)
+ interpreter.allocate_tensors()
+ act_func_match_list = []
+ for tensor_details in interpreter.get_tensor_details():
+ for act_func in ACTIVATION_FUNCTION_LIST:
+ tensor_name = tensor_details["name"].lower()
+ if act_func in tensor_name:
+ act_func_idx = tensor_name.index(act_func)
+ if (
+ len(tensor_name) == act_func_idx + len(act_func)
+ or tensor_name[act_func_idx + len(act_func)] == ";"
+ ):
+ act_func_match_list.append(
+ tensor_name[
+ act_func_idx : act_func_idx + len(act_func) # noqa: E203
+ ]
+ )
+ act_func_match = "relu"
+ if len(act_func_match_list) == 0:
+ logger.info(
+ "No activation function specified, setting activation function to ReLU"
+ )
+ else:
+ act_func_match = max(set(act_func_match_list), key=act_func_match.count)
+ logger.info(
+ "No activation function specified, "
+ "setting activation function to most "
+ "common activation detected in rewrite graph: %s",
+ act_func_match,
+ )
+ return act_func_match
+
+
def train_in_dir(
train_dir: str,
baseline_dir: Any,
@@ -356,6 +425,8 @@ def train_in_dir(
rewrite: Callable,
is_qat: bool,
train_params: TrainingParameters = TrainingParameters(),
+ rewrite_specific_params: dict | None = None,
+ detect_activation_function: bool = False,
) -> list[str]:
"""Train a replacement for replace.tflite using the input.tfrec \
and output.tfrec in train_dir.
@@ -372,6 +443,18 @@ def train_in_dir(
)
replace = TFLiteModel(ExtractPaths.tflite.replace(train_dir))
+ if detect_activation_function and (
+ rewrite_specific_params is None
+ or "activation" not in list(rewrite_specific_params.keys())
+ ):
+ detected_activation_function = detect_activation_from_rewrite_function(
+ ExtractPaths.tflite.replace(train_dir).as_posix()
+ )
+ if rewrite_specific_params:
+ rewrite_specific_params["activation"] = detected_activation_function
+ else:
+ rewrite_specific_params = {"activation": detected_activation_function}
+
input_name, output_name = _get_io_tensors(teacher)
model_is_quantized = replace.is_tensor_quantized(name=input_name)
@@ -396,7 +479,13 @@ def train_in_dir(
loss_fn = keras.losses.MeanSquaredError()
model = create_model(
- rewrite, input_shape, output_shape, optimizer, loss_fn, model_is_quantized
+ rewrite,
+ input_shape,
+ output_shape,
+ optimizer,
+ loss_fn,
+ model_is_quantized,
+ rewrite_specific_params=rewrite_specific_params,
)
logger.info(model.summary())
@@ -462,11 +551,9 @@ def train_in_dir(
steps_per_epoch,
post_process=True,
)
-
- # Placeholder for now, will be parametrized later (MLIA-1114)
- # rewrite.check_optimization( # type: ignore[attr-defined]
- # model, number_of_clusters=32
- # )
+ rewrite.check_optimization( # type: ignore[attr-defined]
+ model, **rewrite_specific_params if rewrite_specific_params else {}
+ )
if model_is_quantized and is_qat:
model = rewrite.preserved_quantize(model) # type: ignore[attr-defined]
checkpoints = (
@@ -501,11 +588,10 @@ def train_in_dir(
loss_fn,
steps_per_epoch,
)
- # Placeholder for now, will be parametrized later (MLIA-1114)
- # rewrite.check_optimization( # type: ignore[attr-defined]
- # model, number_of_clusters=32
- # )
+ rewrite.check_optimization( # type: ignore[attr-defined]
+ model, **rewrite_specific_params if rewrite_specific_params else {}
+ )
teacher.close()
return output_filenames
@@ -528,9 +614,13 @@ def create_model( # pylint: disable=too-many-arguments
loss_fn: Callable,
model_is_quantized: bool,
model_to_load_from: keras.model | None = None,
+ rewrite_specific_params: dict | None = None,
) -> keras.Model:
"""Create a model, optionally from another."""
- model = rewrite(input_shape, output_shape)
+ if rewrite_specific_params:
+ model = rewrite(input_shape, output_shape, **rewrite_specific_params)
+ else:
+ model = rewrite(input_shape, output_shape)
if model_is_quantized:
model = rewrite.quantize(model) # type: ignore[attr-defined]
model = model_compile(model, optimizer=optimizer, loss_fn=loss_fn)
@@ -558,6 +648,7 @@ def model_fit( # pylint: disable=too-many-arguments
loss_fn: Callable,
steps_per_epoch: int,
post_process: bool = False,
+ rewrite_specific_params: dict | None = None,
) -> keras.Model:
"""Train a tflite model."""
steps_so_far = 0
@@ -597,6 +688,7 @@ def model_fit( # pylint: disable=too-many-arguments
loss_fn,
model_is_quantized,
model_to_load_from=model,
+ rewrite_specific_params=rewrite_specific_params,
)
else:
model_to_save = model