diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/train.py | 34 |
1 files changed, 23 insertions, 11 deletions
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index 4204978..e99c7e9 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -83,6 +83,7 @@ def train( # pylint: disable=too-many-arguments input_tensors: list, output_tensors: list, train_params: TrainingParameters = TrainingParameters(), + rewrite_specific_params: dict | None = None, ) -> Any: """Extract and train a model, and return the results.""" if unmodified_model: @@ -122,6 +123,7 @@ def train( # pylint: disable=too-many-arguments rewrite=rewrite, is_qat=is_qat, train_params=train_params, + rewrite_specific_params=rewrite_specific_params, ) for i, filename in enumerate(tflite_filenames): @@ -356,6 +358,7 @@ def train_in_dir( rewrite: Callable, is_qat: bool, train_params: TrainingParameters = TrainingParameters(), + rewrite_specific_params: dict | None = None, ) -> list[str]: """Train a replacement for replace.tflite using the input.tfrec \ and output.tfrec in train_dir. @@ -396,7 +399,13 @@ def train_in_dir( loss_fn = keras.losses.MeanSquaredError() model = create_model( - rewrite, input_shape, output_shape, optimizer, loss_fn, model_is_quantized + rewrite, + input_shape, + output_shape, + optimizer, + loss_fn, + model_is_quantized, + rewrite_specific_params=rewrite_specific_params, ) logger.info(model.summary()) @@ -462,11 +471,9 @@ def train_in_dir( steps_per_epoch, post_process=True, ) - - # Placeholder for now, will be parametrized later (MLIA-1114) - # rewrite.check_optimization( # type: ignore[attr-defined] - # model, number_of_clusters=32 - # ) + rewrite.check_optimization( # type: ignore[attr-defined] + model, **rewrite_specific_params if rewrite_specific_params else {} + ) if model_is_quantized and is_qat: model = rewrite.preserved_quantize(model) # type: ignore[attr-defined] checkpoints = ( @@ -501,11 +508,10 @@ def train_in_dir( loss_fn, steps_per_epoch, ) - # Placeholder for now, will be parametrized later (MLIA-1114) - # rewrite.check_optimization( # type: ignore[attr-defined] - # model, number_of_clusters=32 - # ) + rewrite.check_optimization( # type: ignore[attr-defined] + model, **rewrite_specific_params if rewrite_specific_params else {} + ) teacher.close() return output_filenames @@ -528,9 +534,13 @@ def create_model( # pylint: disable=too-many-arguments loss_fn: Callable, model_is_quantized: bool, model_to_load_from: keras.model | None = None, + rewrite_specific_params: dict | None = None, ) -> keras.Model: """Create a model, optionally from another.""" - model = rewrite(input_shape, output_shape) + if rewrite_specific_params: + model = rewrite(input_shape, output_shape, **rewrite_specific_params) + else: + model = rewrite(input_shape, output_shape) if model_is_quantized: model = rewrite.quantize(model) # type: ignore[attr-defined] model = model_compile(model, optimizer=optimizer, loss_fn=loss_fn) @@ -558,6 +568,7 @@ def model_fit( # pylint: disable=too-many-arguments loss_fn: Callable, steps_per_epoch: int, post_process: bool = False, + rewrite_specific_params: dict | None = None, ) -> keras.Model: """Train a tflite model.""" steps_so_far = 0 @@ -597,6 +608,7 @@ def model_fit( # pylint: disable=too-many-arguments loss_fn, model_is_quantized, model_to_load_from=model, + rewrite_specific_params=rewrite_specific_params, ) else: model_to_save = model |