diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/train.py | 119 |
1 file changed, 97 insertions, 22 deletions
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index 88efa23..570968a 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -34,13 +34,13 @@ from mlia.nn.rewrite.core.graph_edit.record import record_model from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel +from mlia.nn.rewrite.library.helper_functions import ACTIVATION_FUNCTION_LIST from mlia.nn.tensorflow.config import TFLiteModel from mlia.nn.tensorflow.tflite_convert import convert_to_tflite from mlia.nn.tensorflow.tflite_graph import load_fb from mlia.nn.tensorflow.tflite_graph import save_fb from mlia.utils.logging import log_action - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) logger = logging.getLogger(__name__) @@ -62,7 +62,7 @@ LEARNING_RATE_SCHEDULES = get_args(LearningRateSchedule) class TrainingParameters: """Define default parameters for the training.""" - augmentations: tuple[float | None, float | None] = AUGMENTATION_PRESETS["gaussian"] + augmentations: tuple[float | None, float | None] = AUGMENTATION_PRESETS["none"] batch_size: int = 32 steps: int = 48000 learning_rate: float = 1e-3 @@ -83,6 +83,8 @@ def train( # pylint: disable=too-many-arguments input_tensors: list, output_tensors: list, train_params: TrainingParameters = TrainingParameters(), + rewrite_specific_params: dict | None = None, + detect_activation_function: bool = False, ) -> Any: """Extract and train a model, and return the results.""" if unmodified_model: @@ -122,6 +124,8 @@ def train( # pylint: disable=too-many-arguments rewrite=rewrite, is_qat=is_qat, train_params=train_params, + rewrite_specific_params=rewrite_specific_params, + detect_activation_function=detect_activation_function, ) for i, filename in enumerate(tflite_filenames): @@ -147,7 +151,8 @@ 
def train( # pylint: disable=too-many-arguments # Assess the output diff between the parts after the rewrite subgraph # in original and optimized model optimized_end_path = Path(train_dir, "optimized_end.tfrec") - end_path = Path(train_dir, "end.tfrec") + optimized_end_path_dequant = Path(train_dir, "optimized_end_dequant.tfrec") + end_path = Path(train_dir, "end_dequant.tfrec") record_model( str(input_tfrec), @@ -155,8 +160,10 @@ def train( # pylint: disable=too-many-arguments optimized_end_path, num_procs=train_params.num_procs, num_threads=train_params.num_threads, + dequantize_output=True, ) - mae, nrmse = diff_stats(end_path, str(optimized_end_path)) + + mae, nrmse = diff_stats(end_path, optimized_end_path_dequant) if unmodified_model_dir: cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup() @@ -179,24 +186,27 @@ def eval_in_dir( model_input = ( model_input_path if model_input_path.exists() - else ExtractPaths.tfrec.input(target_dir, False) + else ExtractPaths.tfrec.input(target_dir, True) ) output = ( model_output_path if model_output_path.exists() - else ExtractPaths.tfrec.output(target_dir, False) + else ExtractPaths.tfrec.output(target_dir, True) ) with tempfile.TemporaryDirectory() as tmp_dir: predict = Path(tmp_dir, "predict.tfrec") + predict_dequant = Path(tmp_dir, "predict_dequant.tfrec") record_model( str(model_input), new_part, str(predict), num_procs=num_procs, num_threads=num_threads, + dequantize_output=True, + quantize_input=True, ) - mae, nrmse = diff_stats(str(output), str(predict)) + mae, nrmse = diff_stats(str(output), predict_dequant) return mae, nrmse @@ -249,7 +259,7 @@ def set_up_data_pipeline( augmentations: tuple[float | None, float | None], steps: int, batch_size: int = 32, -) -> tf.data.Dataset: +) -> tuple[tf.data.Dataset, int]: """Create a data pipeline for training of the replacement model.""" _check_model_compatibility(teacher, replace) @@ -340,7 +350,42 @@ def set_up_data_pipeline( dataset = 
dataset.map(restore_shapes) dataset = dataset.prefetch(tf.data.AUTOTUNE) - return dataset + return dataset, steps_per_epoch + + +def detect_activation_from_rewrite_function(model_path: str) -> str: + """Given a rewrite model, choose the most common activation function.""" + interpreter = tf.lite.Interpreter(model_path=model_path) + interpreter.allocate_tensors() + act_func_match_list = [] + for tensor_details in interpreter.get_tensor_details(): + for act_func in ACTIVATION_FUNCTION_LIST: + tensor_name = tensor_details["name"].lower() + if act_func in tensor_name: + act_func_idx = tensor_name.index(act_func) + if ( + len(tensor_name) == act_func_idx + len(act_func) + or tensor_name[act_func_idx + len(act_func)] == ";" + ): + act_func_match_list.append( + tensor_name[ + act_func_idx : act_func_idx + len(act_func) # noqa: E203 + ] + ) + act_func_match = "relu" + if len(act_func_match_list) == 0: + logger.info( + "No activation function specified, setting activation function to ReLU" + ) + else: + act_func_match = max(set(act_func_match_list), key=act_func_match.count) + logger.info( + "No activation function specified, " + "setting activation function to most " + "common activation detected in rewrite graph: %s", + act_func_match, + ) + return act_func_match def train_in_dir( @@ -350,6 +395,8 @@ def train_in_dir( rewrite: Callable, is_qat: bool, train_params: TrainingParameters = TrainingParameters(), + rewrite_specific_params: dict | None = None, + detect_activation_function: bool = False, ) -> list[str]: """Train a replacement for replace.tflite using the input.tfrec \ and output.tfrec in train_dir. 
@@ -366,6 +413,18 @@ def train_in_dir( ) replace = TFLiteModel(ExtractPaths.tflite.replace(train_dir)) + if detect_activation_function and ( + rewrite_specific_params is None + or "activation" not in list(rewrite_specific_params.keys()) + ): + detected_activation_function = detect_activation_from_rewrite_function( + ExtractPaths.tflite.replace(train_dir).as_posix() + ) + if rewrite_specific_params: + rewrite_specific_params["activation"] = detected_activation_function + else: + rewrite_specific_params = {"activation": detected_activation_function} + input_name, output_name = _get_io_tensors(teacher) model_is_quantized = replace.is_tensor_quantized(name=input_name) @@ -373,7 +432,7 @@ def train_in_dir( if model_is_quantized: replace.check_datatypes(np.int8) - dataset = set_up_data_pipeline( + dataset, steps_per_epoch = set_up_data_pipeline( teacher, replace, train_dir, @@ -390,7 +449,13 @@ def train_in_dir( loss_fn = keras.losses.MeanSquaredError() model = create_model( - rewrite, input_shape, output_shape, optimizer, loss_fn, model_is_quantized + rewrite, + input_shape, + output_shape, + optimizer, + loss_fn, + model_is_quantized, + rewrite_specific_params=rewrite_specific_params, ) logger.info(model.summary()) @@ -453,13 +518,12 @@ def train_in_dir( input_shape, output_shape, loss_fn, + steps_per_epoch, post_process=True, ) - - # Placeholder for now, will be parametrized later (MLIA-1114) - # rewrite.check_optimization( # type: ignore[attr-defined] - # model, number_of_clusters=32 - # ) + rewrite.check_optimization( # type: ignore[attr-defined] + model, **rewrite_specific_params if rewrite_specific_params else {} + ) if model_is_quantized and is_qat: model = rewrite.preserved_quantize(model) # type: ignore[attr-defined] checkpoints = ( @@ -492,12 +556,12 @@ def train_in_dir( input_shape, output_shape, loss_fn, + steps_per_epoch, ) - # Placeholder for now, will be parametrized later (MLIA-1114) - # rewrite.check_optimization( # type: ignore[attr-defined] - # model, 
number_of_clusters=32 - # ) + rewrite.check_optimization( # type: ignore[attr-defined] + model, **rewrite_specific_params if rewrite_specific_params else {} + ) teacher.close() return output_filenames @@ -520,9 +584,13 @@ def create_model( # pylint: disable=too-many-arguments loss_fn: Callable, model_is_quantized: bool, model_to_load_from: keras.model | None = None, + rewrite_specific_params: dict | None = None, ) -> keras.Model: """Create a model, optionally from another.""" - model = rewrite(input_shape, output_shape) + if rewrite_specific_params: + model = rewrite(input_shape, output_shape, **rewrite_specific_params) + else: + model = rewrite(input_shape, output_shape) if model_is_quantized: model = rewrite.quantize(model) # type: ignore[attr-defined] model = model_compile(model, optimizer=optimizer, loss_fn=loss_fn) @@ -548,7 +616,9 @@ def model_fit( # pylint: disable=too-many-arguments input_shape: int, output_shape: int, loss_fn: Callable, + steps_per_epoch: int, post_process: bool = False, + rewrite_specific_params: dict | None = None, ) -> keras.Model: """Train a tflite model.""" steps_so_far = 0 @@ -588,13 +658,18 @@ def model_fit( # pylint: disable=too-many-arguments loss_fn, model_is_quantized, model_to_load_from=model, + rewrite_specific_params=rewrite_specific_params, ) else: model_to_save = model else: checkpoint_filename = str(output_filename) + logger.info("Evaluate final Keras Model using %d steps", steps_per_epoch) + model.evaluate( + dataset, + steps=steps_per_epoch, + ) model_to_save = model - with log_action( f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}" ): |