aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r--src/mlia/nn/rewrite/core/train.py119
1 files changed, 97 insertions, 22 deletions
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 88efa23..570968a 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -34,13 +34,13 @@ from mlia.nn.rewrite.core.graph_edit.record import record_model
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
+from mlia.nn.rewrite.library.helper_functions import ACTIVATION_FUNCTION_LIST
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
from mlia.nn.tensorflow.tflite_graph import load_fb
from mlia.nn.tensorflow.tflite_graph import save_fb
from mlia.utils.logging import log_action
-
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
logger = logging.getLogger(__name__)
@@ -62,7 +62,7 @@ LEARNING_RATE_SCHEDULES = get_args(LearningRateSchedule)
class TrainingParameters:
"""Define default parameters for the training."""
- augmentations: tuple[float | None, float | None] = AUGMENTATION_PRESETS["gaussian"]
+ augmentations: tuple[float | None, float | None] = AUGMENTATION_PRESETS["none"]
batch_size: int = 32
steps: int = 48000
learning_rate: float = 1e-3
@@ -83,6 +83,8 @@ def train( # pylint: disable=too-many-arguments
input_tensors: list,
output_tensors: list,
train_params: TrainingParameters = TrainingParameters(),
+ rewrite_specific_params: dict | None = None,
+ detect_activation_function: bool = False,
) -> Any:
"""Extract and train a model, and return the results."""
if unmodified_model:
@@ -122,6 +124,8 @@ def train( # pylint: disable=too-many-arguments
rewrite=rewrite,
is_qat=is_qat,
train_params=train_params,
+ rewrite_specific_params=rewrite_specific_params,
+ detect_activation_function=detect_activation_function,
)
for i, filename in enumerate(tflite_filenames):
@@ -147,7 +151,8 @@ def train( # pylint: disable=too-many-arguments
# Assess the output diff between the parts after the rewrite subgraph
# in original and optimized model
optimized_end_path = Path(train_dir, "optimized_end.tfrec")
- end_path = Path(train_dir, "end.tfrec")
+ optimized_end_path_dequant = Path(train_dir, "optimized_end_dequant.tfrec")
+ end_path = Path(train_dir, "end_dequant.tfrec")
record_model(
str(input_tfrec),
@@ -155,8 +160,10 @@ def train( # pylint: disable=too-many-arguments
optimized_end_path,
num_procs=train_params.num_procs,
num_threads=train_params.num_threads,
+ dequantize_output=True,
)
- mae, nrmse = diff_stats(end_path, str(optimized_end_path))
+
+ mae, nrmse = diff_stats(end_path, optimized_end_path_dequant)
if unmodified_model_dir:
cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup()
@@ -179,24 +186,27 @@ def eval_in_dir(
model_input = (
model_input_path
if model_input_path.exists()
- else ExtractPaths.tfrec.input(target_dir, False)
+ else ExtractPaths.tfrec.input(target_dir, True)
)
output = (
model_output_path
if model_output_path.exists()
- else ExtractPaths.tfrec.output(target_dir, False)
+ else ExtractPaths.tfrec.output(target_dir, True)
)
with tempfile.TemporaryDirectory() as tmp_dir:
predict = Path(tmp_dir, "predict.tfrec")
+ predict_dequant = Path(tmp_dir, "predict_dequant.tfrec")
record_model(
str(model_input),
new_part,
str(predict),
num_procs=num_procs,
num_threads=num_threads,
+ dequantize_output=True,
+ quantize_input=True,
)
- mae, nrmse = diff_stats(str(output), str(predict))
+ mae, nrmse = diff_stats(str(output), predict_dequant)
return mae, nrmse
@@ -249,7 +259,7 @@ def set_up_data_pipeline(
augmentations: tuple[float | None, float | None],
steps: int,
batch_size: int = 32,
-) -> tf.data.Dataset:
+) -> tuple[tf.data.Dataset, int]:
"""Create a data pipeline for training of the replacement model."""
_check_model_compatibility(teacher, replace)
@@ -340,7 +350,42 @@ def set_up_data_pipeline(
dataset = dataset.map(restore_shapes)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
- return dataset
+ return dataset, steps_per_epoch
+
+
def detect_activation_from_rewrite_function(model_path: str) -> str:
    """Return the most common activation function found in a rewrite model.

    Scan the tensor names of the TFLite model at *model_path* for any of the
    known activation names in ACTIVATION_FUNCTION_LIST. A name only counts as
    a match when the activation token is at the end of the tensor name or is
    immediately followed by ";" — this stops a shorter name matching inside a
    longer one (e.g. "relu" inside "relu6").

    Falls back to "relu" when no activation function can be detected.

    :param model_path: path to the extracted rewrite TFLite model.
    :return: lower-case activation function name (e.g. "relu").
    """
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    act_func_match_list = []
    for tensor_details in interpreter.get_tensor_details():
        tensor_name = tensor_details["name"].lower()
        for act_func in ACTIVATION_FUNCTION_LIST:
            act_func_idx = tensor_name.find(act_func)
            if act_func_idx == -1:
                continue
            end_idx = act_func_idx + len(act_func)
            # Accept the match only at a token boundary: end of the tensor
            # name or a ";" separator (TFLite fuses op names with ";").
            if end_idx == len(tensor_name) or tensor_name[end_idx] == ";":
                act_func_match_list.append(tensor_name[act_func_idx:end_idx])
    if not act_func_match_list:
        logger.info(
            "No activation function specified, setting activation function to ReLU"
        )
        return "relu"
    # BUG FIX: the original used key=act_func_match.count (str.count bound to
    # the default "relu"), which ranked candidates by how often they occur as
    # a substring of "relu" rather than by frequency in the match list. Rank
    # by occurrence count in the collected matches instead.
    act_func_match = max(set(act_func_match_list), key=act_func_match_list.count)
    logger.info(
        "No activation function specified, "
        "setting activation function to most "
        "common activation detected in rewrite graph: %s",
        act_func_match,
    )
    return act_func_match
def train_in_dir(
@@ -350,6 +395,8 @@ def train_in_dir(
rewrite: Callable,
is_qat: bool,
train_params: TrainingParameters = TrainingParameters(),
+ rewrite_specific_params: dict | None = None,
+ detect_activation_function: bool = False,
) -> list[str]:
"""Train a replacement for replace.tflite using the input.tfrec \
and output.tfrec in train_dir.
@@ -366,6 +413,18 @@ def train_in_dir(
)
replace = TFLiteModel(ExtractPaths.tflite.replace(train_dir))
+ if detect_activation_function and (
+ rewrite_specific_params is None
+ or "activation" not in list(rewrite_specific_params.keys())
+ ):
+ detected_activation_function = detect_activation_from_rewrite_function(
+ ExtractPaths.tflite.replace(train_dir).as_posix()
+ )
+ if rewrite_specific_params:
+ rewrite_specific_params["activation"] = detected_activation_function
+ else:
+ rewrite_specific_params = {"activation": detected_activation_function}
+
input_name, output_name = _get_io_tensors(teacher)
model_is_quantized = replace.is_tensor_quantized(name=input_name)
@@ -373,7 +432,7 @@ def train_in_dir(
if model_is_quantized:
replace.check_datatypes(np.int8)
- dataset = set_up_data_pipeline(
+ dataset, steps_per_epoch = set_up_data_pipeline(
teacher,
replace,
train_dir,
@@ -390,7 +449,13 @@ def train_in_dir(
loss_fn = keras.losses.MeanSquaredError()
model = create_model(
- rewrite, input_shape, output_shape, optimizer, loss_fn, model_is_quantized
+ rewrite,
+ input_shape,
+ output_shape,
+ optimizer,
+ loss_fn,
+ model_is_quantized,
+ rewrite_specific_params=rewrite_specific_params,
)
logger.info(model.summary())
@@ -453,13 +518,12 @@ def train_in_dir(
input_shape,
output_shape,
loss_fn,
+ steps_per_epoch,
post_process=True,
)
-
- # Placeholder for now, will be parametrized later (MLIA-1114)
- # rewrite.check_optimization( # type: ignore[attr-defined]
- # model, number_of_clusters=32
- # )
+ rewrite.check_optimization( # type: ignore[attr-defined]
+ model, **rewrite_specific_params if rewrite_specific_params else {}
+ )
if model_is_quantized and is_qat:
model = rewrite.preserved_quantize(model) # type: ignore[attr-defined]
checkpoints = (
@@ -492,12 +556,12 @@ def train_in_dir(
input_shape,
output_shape,
loss_fn,
+ steps_per_epoch,
)
- # Placeholder for now, will be parametrized later (MLIA-1114)
- # rewrite.check_optimization( # type: ignore[attr-defined]
- # model, number_of_clusters=32
- # )
+ rewrite.check_optimization( # type: ignore[attr-defined]
+ model, **rewrite_specific_params if rewrite_specific_params else {}
+ )
teacher.close()
return output_filenames
@@ -520,9 +584,13 @@ def create_model( # pylint: disable=too-many-arguments
loss_fn: Callable,
model_is_quantized: bool,
model_to_load_from: keras.model | None = None,
+ rewrite_specific_params: dict | None = None,
) -> keras.Model:
"""Create a model, optionally from another."""
- model = rewrite(input_shape, output_shape)
+ if rewrite_specific_params:
+ model = rewrite(input_shape, output_shape, **rewrite_specific_params)
+ else:
+ model = rewrite(input_shape, output_shape)
if model_is_quantized:
model = rewrite.quantize(model) # type: ignore[attr-defined]
model = model_compile(model, optimizer=optimizer, loss_fn=loss_fn)
@@ -548,7 +616,9 @@ def model_fit( # pylint: disable=too-many-arguments
input_shape: int,
output_shape: int,
loss_fn: Callable,
+ steps_per_epoch: int,
post_process: bool = False,
+ rewrite_specific_params: dict | None = None,
) -> keras.Model:
"""Train a tflite model."""
steps_so_far = 0
@@ -588,13 +658,18 @@ def model_fit( # pylint: disable=too-many-arguments
loss_fn,
model_is_quantized,
model_to_load_from=model,
+ rewrite_specific_params=rewrite_specific_params,
)
else:
model_to_save = model
else:
checkpoint_filename = str(output_filename)
+ logger.info("Evaluate final Keras Model using %d steps", steps_per_epoch)
+ model.evaluate(
+ dataset,
+ steps=steps_per_epoch,
+ )
model_to_save = model
-
with log_action(
f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}"
):