aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r--src/mlia/nn/rewrite/core/train.py34
1 files changed, 23 insertions, 11 deletions
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 4204978..e99c7e9 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -83,6 +83,7 @@ def train( # pylint: disable=too-many-arguments
input_tensors: list,
output_tensors: list,
train_params: TrainingParameters = TrainingParameters(),
+ rewrite_specific_params: dict | None = None,
) -> Any:
"""Extract and train a model, and return the results."""
if unmodified_model:
@@ -122,6 +123,7 @@ def train( # pylint: disable=too-many-arguments
rewrite=rewrite,
is_qat=is_qat,
train_params=train_params,
+ rewrite_specific_params=rewrite_specific_params,
)
for i, filename in enumerate(tflite_filenames):
@@ -356,6 +358,7 @@ def train_in_dir(
rewrite: Callable,
is_qat: bool,
train_params: TrainingParameters = TrainingParameters(),
+ rewrite_specific_params: dict | None = None,
) -> list[str]:
"""Train a replacement for replace.tflite using the input.tfrec \
and output.tfrec in train_dir.
@@ -396,7 +399,13 @@ def train_in_dir(
loss_fn = keras.losses.MeanSquaredError()
model = create_model(
- rewrite, input_shape, output_shape, optimizer, loss_fn, model_is_quantized
+ rewrite,
+ input_shape,
+ output_shape,
+ optimizer,
+ loss_fn,
+ model_is_quantized,
+ rewrite_specific_params=rewrite_specific_params,
)
logger.info(model.summary())
@@ -462,11 +471,9 @@ def train_in_dir(
steps_per_epoch,
post_process=True,
)
-
- # Placeholder for now, will be parametrized later (MLIA-1114)
- # rewrite.check_optimization( # type: ignore[attr-defined]
- # model, number_of_clusters=32
- # )
+ rewrite.check_optimization( # type: ignore[attr-defined]
+ model, **rewrite_specific_params if rewrite_specific_params else {}
+ )
if model_is_quantized and is_qat:
model = rewrite.preserved_quantize(model) # type: ignore[attr-defined]
checkpoints = (
@@ -501,11 +508,10 @@ def train_in_dir(
loss_fn,
steps_per_epoch,
)
- # Placeholder for now, will be parametrized later (MLIA-1114)
- # rewrite.check_optimization( # type: ignore[attr-defined]
- # model, number_of_clusters=32
- # )
+ rewrite.check_optimization( # type: ignore[attr-defined]
+ model, **rewrite_specific_params if rewrite_specific_params else {}
+ )
teacher.close()
return output_filenames
@@ -528,9 +534,13 @@ def create_model( # pylint: disable=too-many-arguments
loss_fn: Callable,
model_is_quantized: bool,
model_to_load_from: keras.model | None = None,
+ rewrite_specific_params: dict | None = None,
) -> keras.Model:
"""Create a model, optionally from another."""
- model = rewrite(input_shape, output_shape)
+ if rewrite_specific_params:
+ model = rewrite(input_shape, output_shape, **rewrite_specific_params)
+ else:
+ model = rewrite(input_shape, output_shape)
if model_is_quantized:
model = rewrite.quantize(model) # type: ignore[attr-defined]
model = model_compile(model, optimizer=optimizer, loss_fn=loss_fn)
@@ -558,6 +568,7 @@ def model_fit( # pylint: disable=too-many-arguments
loss_fn: Callable,
steps_per_epoch: int,
post_process: bool = False,
+ rewrite_specific_params: dict | None = None,
) -> keras.Model:
"""Train a tflite model."""
steps_so_far = 0
@@ -597,6 +608,7 @@ def model_fit( # pylint: disable=too-many-arguments
loss_fn,
model_is_quantized,
model_to_load_from=model,
+ rewrite_specific_params=rewrite_specific_params,
)
else:
model_to_save = model