Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r--  src/mlia/nn/rewrite/core/train.py  118
1 file changed, 89 insertions(+), 29 deletions(-)
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 4b9821c..88efa23 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -73,15 +73,15 @@ class TrainingParameters:
checkpoint_at: list | None = None
-def train(
+def train( # pylint: disable=too-many-arguments
source_model: str,
unmodified_model: Any,
output_model: str,
input_tfrec: str,
rewrite: Callable,
+ is_qat: bool,
input_tensors: list,
output_tensors: list,
- is_qat: bool,
train_params: TrainingParameters = TrainingParameters(),
) -> Any:
"""Extract and train a model, and return the results."""
@@ -383,15 +383,15 @@ def train_in_dir(
)
input_shape = teacher.shape_from_name[input_name][1:]
+
output_shape = teacher.shape_from_name[output_name][1:]
- model = rewrite(input_shape, output_shape)
optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
loss_fn = keras.losses.MeanSquaredError()
- if model_is_quantized:
- model = rewrite.quantize(model) # type: ignore[attr-defined]
- model = model_compile(model, optimizer, loss_fn)
+ model = create_model(
+ rewrite, input_shape, output_shape, optimizer, loss_fn, model_is_quantized
+ )
logger.info(model.summary())
@@ -432,16 +432,14 @@ def train_in_dir(
callbacks = []
callbacks.extend(rewrite.training_callbacks()) # type: ignore[attr-defined]
-
- output_filenames = [] # type: list[str]
+ output_filenames: list = []
checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [
train_params.steps
]
-
model, output_filenames = model_fit(
model,
train_params,
- checkpoints.copy(),
+ checkpoints,
optimizer,
dataset,
callbacks,
@@ -452,22 +450,35 @@ def train_in_dir(
output_name,
model_is_quantized,
output_filenames,
+ input_shape,
+ output_shape,
+ loss_fn,
+ post_process=True,
)
+ # Placeholder for now, will be parametrized later (MLIA-1114)
+ # rewrite.check_optimization( # type: ignore[attr-defined]
+ # model, number_of_clusters=32
+ # )
if model_is_quantized and is_qat:
- model = rewrite.pruning_preserved_quantization( # type: ignore[attr-defined]
- model,
- )
+ model = rewrite.preserved_quantize(model) # type: ignore[attr-defined]
+ checkpoints = (
+ train_params.checkpoint_at if train_params.checkpoint_at else []
+ ) + [train_params.steps]
+ output_filenames = []
+
+ if len(rewrite.training_callbacks()) > 0 and set( # type: ignore[attr-defined]
+ rewrite.training_callbacks() # type: ignore[attr-defined]
+ ).issubset(callbacks):
+ callbacks.pop(-1)
+
optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
model = model_compile(model, optimizer, loss_fn)
- callbacks.pop(-1)
- output_filenames = []
-
model, output_filenames = model_fit(
model,
train_params,
- checkpoints.copy(),
+ checkpoints,
optimizer,
dataset,
callbacks,
@@ -478,22 +489,50 @@ def train_in_dir(
output_name,
model_is_quantized,
output_filenames,
+ input_shape,
+ output_shape,
+ loss_fn,
)
+ # Placeholder for now, will be parametrized later (MLIA-1114)
+ # rewrite.check_optimization( # type: ignore[attr-defined]
+ # model, number_of_clusters=32
+ # )
teacher.close()
return output_filenames
def model_compile(
- model: tf.keras.Model, optimizer: tf.keras.optimizers, loss_fn: tf.keras.losses
-) -> tf.keras.Model:
+ model: keras.Model,
+ optimizer: keras.optimizers.Nadam,
+ loss_fn: keras.losses.Loss,
+) -> keras.Model:
"""Compiles a tflite model."""
model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
return model
-def model_fit(
- model: tf.keras.Model,
+def create_model( # pylint: disable=too-many-arguments
+ rewrite: Callable,
+ input_shape: int,
+ output_shape: int,
+ optimizer: Callable,
+ loss_fn: Callable,
+ model_is_quantized: bool,
+ model_to_load_from: keras.Model | None = None,
+) -> keras.Model:
+ """Create a model, optionally from another."""
+ model = rewrite(input_shape, output_shape)
+ if model_is_quantized:
+ model = rewrite.quantize(model) # type: ignore[attr-defined]
+ model = model_compile(model, optimizer=optimizer, loss_fn=loss_fn)
+ if model_to_load_from:
+ model.set_weights(model_to_load_from.get_weights())
+ return model
+
+
+def model_fit( # pylint: disable=too-many-arguments
+ model: keras.Model,
train_params: TrainingParameters,
checkpoints: list,
optimizer: tf.optimizers.Nadam,
@@ -506,8 +545,12 @@ def model_fit(
output_name: str,
model_is_quantized: bool,
output_filenames: list,
-) -> tuple[tf.keras.Model, list]:
- """Train the model."""
+ input_shape: int,
+ output_shape: int,
+ loss_fn: Callable,
+ post_process: bool = False,
+) -> tuple[keras.Model, list]:
+ """Train a tflite model."""
steps_so_far = 0
while steps_so_far < train_params.steps:
steps_to_train = checkpoints.pop(0) - steps_so_far
@@ -534,18 +577,34 @@ def model_fit(
checkpoint_filename = (
filename_dir + "/" + filename + (f"_@{steps_so_far}") + ext
)
+ # When post-processing, the clustering/pruning layers are stripped below,
+ # so copy the model before saving so that training can continue
+ if post_process:
+ model_to_save = create_model(
+ rewrite,
+ input_shape,
+ output_shape,
+ optimizer,
+ loss_fn,
+ model_is_quantized,
+ model_to_load_from=model,
+ )
+ else:
+ model_to_save = model
else:
checkpoint_filename = str(output_filename)
-
- if steps_so_far == train_params.steps:
- model = rewrite.post_process(model) # type: ignore[attr-defined]
+ model_to_save = model
with log_action(
f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}"
):
+ if post_process:
+ model_to_save = rewrite.post_process( # type: ignore[attr-defined]
+ model_to_save
+ )
save_as_tflite(
- model,
- str(checkpoint_filename),
+ model_to_save,
+ checkpoint_filename,
input_name,
replace.shape_from_name[input_name],
output_name,
@@ -553,7 +612,8 @@ def model_fit(
model_is_quantized,
)
output_filenames.append(checkpoint_filename)
- return model, output_filenames
+
+ return model_to_save, output_filenames
def save_as_tflite(
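
For context, a minimal sketch of how the new create_model() helper is meant to be used by the post-processing path in this patch: a fresh model is built, the trained weights are copied into it, and only the copy is post-processed, so the original keeps its clustering/pruning wrappers and training can continue. The wrapper function below is illustrative only; it assumes the create_model() helper and the rewrite object's post_process() method introduced above.

# Sketch only: snapshot_for_saving() is not part of the module; it assumes the
# create_model() helper and rewrite.post_process() shown in the diff above.
from mlia.nn.rewrite.core.train import create_model

def snapshot_for_saving(
    rewrite, model, input_shape, output_shape, optimizer, loss_fn, model_is_quantized
):
    # Rebuild the architecture and copy the trained weights into the new instance
    # (create_model() calls set_weights() when model_to_load_from is given).
    model_to_save = create_model(
        rewrite,
        input_shape,
        output_shape,
        optimizer,
        loss_fn,
        model_is_quantized,
        model_to_load_from=model,
    )
    # Strip the clustering/pruning wrappers from the copy only.
    return rewrite.post_process(model_to_save)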