diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core')
-rw-r--r-- | src/mlia/nn/rewrite/core/extract.py | 61 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/cut.py | 16 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/record.py | 45 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/train.py | 67 |
4 files changed, 152 insertions, 37 deletions
diff --git a/src/mlia/nn/rewrite/core/extract.py b/src/mlia/nn/rewrite/core/extract.py index f609955..4fcf735 100644 --- a/src/mlia/nn/rewrite/core/extract.py +++ b/src/mlia/nn/rewrite/core/extract.py @@ -2,19 +2,62 @@ # SPDX-License-Identifier: Apache-2.0 """Extract module.""" # pylint: disable=too-many-arguments, too-many-locals +from __future__ import annotations + import os +from functools import partial +from pathlib import Path import tensorflow as tf from tensorflow.lite.python.schema_py_generated import SubGraphT from mlia.nn.rewrite.core.graph_edit.cut import cut_model +from mlia.nn.rewrite.core.graph_edit.record import dequantized_path from mlia.nn.rewrite.core.graph_edit.record import record_model - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +def _get_path( + ext: str, name: str, dir_path: str | Path, model_is_quantized: bool = False +) -> Path: + """Create a file path for extracted files.""" + path = Path(dir_path, f"{name}{ext}") + return dequantized_path(path) if model_is_quantized else path + + +class TFLitePaths: # pylint: disable=too-few-public-methods + """Provide safe access to TensorFlow Lite file paths.""" + + _get_path_tflite = partial(_get_path, ".tflite") + + start = partial(_get_path_tflite, "start") + replace = partial(_get_path_tflite, "replace") + end = partial(_get_path_tflite, "end") + + +class TFRecordPaths: # pylint: disable=too-few-public-methods + """Provide safe access to tfrec file paths.""" + + _get_path_tfrec = partial(_get_path, ".tfrec") + + input = partial(_get_path_tfrec, "input") + output = partial(_get_path_tfrec, "output") + end = partial(_get_path_tfrec, "end") + + +class ExtractPaths: # pylint: disable=too-few-public-methods + """Get paths to extract files. + + This is meant to be the single source of truth regarding all file names + created by the extract() function in an output directory. + """ + + tflite = TFLitePaths + tfrec = TFRecordPaths + + def extract( output_path: str, model_file: str, @@ -26,6 +69,7 @@ def extract( show_progress: bool = False, num_procs: int = 1, num_threads: int = 0, + dequantize_output: bool = False, ) -> None: """Extract a model after cut and record.""" try: @@ -33,7 +77,7 @@ def extract( except FileExistsError: pass - start_file = os.path.join(output_path, "start.tflite") + start_file = ExtractPaths.tflite.start(output_path) cut_model( model_file, input_names=None, @@ -42,7 +86,7 @@ def extract( output_file=start_file, ) - input_tfrec = os.path.join(output_path, "input.tfrec") + input_tfrec = ExtractPaths.tfrec.input(output_path) record_model( input_filename, start_file, @@ -50,9 +94,10 @@ def extract( show_progress=show_progress, num_procs=num_procs, num_threads=num_threads, + dequantize_output=dequantize_output, ) - replace_file = os.path.join(output_path, "replace.tflite") + replace_file = ExtractPaths.tflite.replace(output_path) cut_model( model_file, input_names=input_names, @@ -61,7 +106,7 @@ def extract( output_file=replace_file, ) - end_file = os.path.join(output_path, "end.tflite") + end_file = ExtractPaths.tflite.end(output_path) cut_model( model_file, input_names=output_names, @@ -71,7 +116,7 @@ def extract( ) if not skip_outputs: - output_tfrec = os.path.join(output_path, "output.tfrec") + output_tfrec = ExtractPaths.tfrec.output(output_path) record_model( input_tfrec, replace_file, @@ -79,9 +124,10 @@ def extract( show_progress=show_progress, num_procs=num_procs, num_threads=num_threads, + dequantize_output=dequantize_output, ) - end_tfrec = os.path.join(output_path, "end.tfrec") + end_tfrec = ExtractPaths.tfrec.end(output_path) record_model( output_tfrec, end_file, @@ -89,4 +135,5 @@ def extract( show_progress=show_progress, num_procs=num_procs, num_threads=num_threads, + dequantize_output=dequantize_output, ) diff --git a/src/mlia/nn/rewrite/core/graph_edit/cut.py b/src/mlia/nn/rewrite/core/graph_edit/cut.py index 13a5268..53d5389 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/cut.py +++ b/src/mlia/nn/rewrite/core/graph_edit/cut.py @@ -1,9 +1,11 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Cut module.""" +from __future__ import annotations + import os from collections import defaultdict -from typing import Optional +from pathlib import Path import tensorflow as tf from tensorflow.lite.python.schema_py_generated import ModelT @@ -25,8 +27,8 @@ def tensors_by_name(subgraph: SubGraphT, names: list) -> list: def cut_subgraph( subgraph: SubGraphT, - input_tensor_names: Optional[list], - output_tensor_names: Optional[list], + input_tensor_names: list | None, + output_tensor_names: list | None, ) -> None: """Change the global inputs and outputs of a graph to the provided named tensors.""" if input_tensor_names is not None: @@ -131,11 +133,11 @@ def filter_relabel(src_subgraph: SubGraphT, relabel_filter: set) -> tuple: def cut_model( - model_file: str, - input_names: Optional[list], - output_names: Optional[list], + model_file: str | Path, + input_names: list | None, + output_names: list | None, subgraph_index: int, - output_file: str, + output_file: str | Path, ) -> None: """Cut subgraphs and simplify a given model.""" model = load_fb(model_file) diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py index 90f3db8..f85433d 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/record.py +++ b/src/mlia/nn/rewrite/core/graph_edit/record.py @@ -6,6 +6,7 @@ from __future__ import annotations import math import os +from contextlib import ExitStack from pathlib import Path import tensorflow as tf @@ -15,11 +16,22 @@ from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFWriter from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel +from mlia.nn.tensorflow.config import NameToTensorMap os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +DEQUANT_SUFFIX = "_dequant" + + +def dequantized_path(filename: str | Path) -> Path: + """Append the de-quantization suffix to the given filename.""" + path = Path(filename) + path = Path(path.parent, f"{path.stem}{DEQUANT_SUFFIX}{path.suffix}") + return path + + def record_model( input_filename: str | Path, model_filename: str | Path, @@ -28,11 +40,14 @@ def record_model( show_progress: bool = False, num_procs: int = 1, num_threads: int = 0, + dequantize_output: bool = False, ) -> None: """Model recorder. num_procs: 0 => detect real cores on system num_threads: 0 => TFLite impl. specific setting, usually 3 + + dequantize: True => de-quantize the recorded output before saving """ model = ParallelTFLiteModel(model_filename, num_procs, num_threads, batch_size) if not batch_size: @@ -51,22 +66,38 @@ def record_model( dataset = dataset.batch(batch_size, drop_remainder=False) total = int(math.ceil(total / batch_size)) - with NumpyTFWriter(output_filename) as writer: - for _, named_x in enumerate( - track(dataset.as_numpy_iterator(), total=total, disable=not show_progress) - ): - named_y = model(named_x) + with ExitStack() as stack: + writer = stack.enter_context(NumpyTFWriter(output_filename)) + writer_dequant = None + if dequantize_output: + dequant_path = dequantized_path(output_filename) + writer_dequant = stack.enter_context(NumpyTFWriter(dequant_path)) + + def write(writer: NumpyTFWriter, data: NameToTensorMap) -> None: + """Write the data using the given NumpyTFWriter instance.""" if batch_size > 1: for i in range(batch_size): # Expand the batches and recreate each dict as a # batch-size 1 item for the tfrec output recreated_dict = { k: v[i : i + 1] # noqa: E203 - for k, v in named_y.items() + for k, v in data.items() if i < v.shape[0] } if recreated_dict: writer.write(recreated_dict) else: - writer.write(named_y) + writer.write(data) + + for _, named_x in enumerate( + track(dataset.as_numpy_iterator(), total=total, disable=not show_progress) + ): + named_y = model(named_x) + write(writer, named_y) + + if dequantize_output: + assert writer_dequant + named_y_dequant = model.dequantize_outputs(named_y) + write(writer_dequant, named_y_dequant) + model.close() diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index 82af747..6345f07 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -22,9 +22,11 @@ from typing import Literal import numpy as np import tensorflow as tf +import tensorflow_model_optimization as tfmot from numpy.random import Generator from mlia.nn.rewrite.core.extract import extract +from mlia.nn.rewrite.core.extract import ExtractPaths from mlia.nn.rewrite.core.graph_edit.diff import diff_stats from mlia.nn.rewrite.core.graph_edit.join import join_models from mlia.nn.rewrite.core.graph_edit.record import record_model @@ -34,6 +36,7 @@ from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel from mlia.nn.tensorflow.config import TFLiteModel from mlia.nn.tensorflow.tflite_graph import load_fb from mlia.nn.tensorflow.tflite_graph import save_fb +from mlia.nn.tensorflow.utils import get_tflite_converter from mlia.utils.logging import log_action @@ -91,6 +94,7 @@ def train( input_tfrec, input_tensors, output_tensors, + dequantize_output=True, ) else: unmodified_model_dir = None @@ -106,6 +110,7 @@ def train( output_tensors, num_procs=train_params.num_procs, num_threads=train_params.num_threads, + dequantize_output=True, ) tflite_filenames = train_in_dir( @@ -160,7 +165,10 @@ def train( def eval_in_dir( - target_dir: str, new_part: str, num_procs: int = 1, num_threads: int = 0 + target_dir: str, + new_part: str, + num_procs: int = 1, + num_threads: int = 0, ) -> tuple: """Evaluate a model in a given directory.""" model_input_path = Path(target_dir, "input_orig.tfrec") @@ -168,12 +176,12 @@ def eval_in_dir( model_input = ( model_input_path if model_input_path.exists() - else Path(target_dir, "input.tfrec") + else ExtractPaths.tfrec.input(target_dir, False) ) output = ( model_output_path if model_output_path.exists() - else Path(target_dir, "output.tfrec") + else ExtractPaths.tfrec.output(target_dir, False) ) with tempfile.TemporaryDirectory() as tmp_dir: @@ -194,8 +202,8 @@ def join_in_dir(model_dir: str, new_part: str, output_model: str) -> None: """Join two models in a given directory.""" with tempfile.TemporaryDirectory() as tmp_dir: new_end = Path(tmp_dir, "new_end.tflite") - join_models(new_part, Path(model_dir, "end.tflite"), new_end) - join_models(Path(model_dir, "start.tflite"), new_end, output_model) + join_models(new_part, ExtractPaths.tflite.end(model_dir), new_end) + join_models(ExtractPaths.tflite.start(model_dir), new_end, output_model) def _get_io_tensors(model: TFLiteModel) -> tuple[str, str]: @@ -244,7 +252,9 @@ def set_up_data_pipeline( input_name, output_name = _get_io_tensors(teacher) - input_filename = Path(train_dir, "input.tfrec") + model_is_quantized = replace.is_tensor_quantized(name=input_name) + + input_filename = ExtractPaths.tfrec.input(train_dir, model_is_quantized) total = numpytf_count(str(input_filename)) dict_inputs = numpytf_read(str(input_filename)) @@ -264,13 +274,13 @@ def set_up_data_pipeline( if any(augmentations): # Map the teacher inputs here because the augmentation stage passes these # through a TFLite model to get the outputs - teacher_outputs = numpytf_read(str(Path(teacher_dir, "input.tfrec"))).map( - lambda d: tf.squeeze(d[input_name], axis=0) - ) + teacher_outputs = numpytf_read( + str(ExtractPaths.tfrec.input(teacher_dir, model_is_quantized)) + ).map(lambda d: tf.squeeze(d[input_name], axis=0)) else: - teacher_outputs = numpytf_read(str(Path(teacher_dir, "output.tfrec"))).map( - lambda d: tf.squeeze(d[output_name], axis=0) - ) + teacher_outputs = numpytf_read( + str(ExtractPaths.tfrec.output(teacher_dir, model_is_quantized)) + ).map(lambda d: tf.squeeze(d[output_name], axis=0)) dataset = tf.data.Dataset.zip((inputs, teacher_outputs)) if epochs > 1: @@ -285,7 +295,23 @@ def set_up_data_pipeline( ) -> tuple: """Return results of train and teach based on augmentations.""" augmented_train = augment_train({input_name: train})[input_name] - augmented_teach = teacher(augment_teacher({input_name: teach}))[output_name] + # If augmentation of the input is enabled, we need to re-generate + # the reference output by running the augmented data through the + # teacher model. + if model_is_quantized: + # If the input model is quantized we have to additionally + # - quantize the augmented data before running it through the + # (quantized) teacher model + # - de-quantize the output for the training. + augmented_teach = teacher.dequantize_outputs( + teacher( + teacher.quantize_inputs(augment_teacher({input_name: teach})) + ) + )[output_name] + else: + augmented_teach = teacher(augment_teacher({input_name: teach}))[ + output_name + ] return (augmented_train, augmented_teach) dataset = dataset.map( @@ -329,15 +355,20 @@ def train_in_dir( """ teacher_dir = baseline_dir if baseline_dir else train_dir teacher = ParallelTFLiteModel( - f"{teacher_dir}/replace.tflite", + ExtractPaths.tflite.replace(teacher_dir), train_params.num_procs, train_params.num_threads, batch_size=train_params.batch_size, ) - replace = TFLiteModel(f"{train_dir}/replace.tflite") + replace = TFLiteModel(ExtractPaths.tflite.replace(train_dir)) input_name, output_name = _get_io_tensors(teacher) + model_is_quantized = replace.is_tensor_quantized(name=input_name) + + if model_is_quantized: + replace.check_datatypes(np.int8) + dataset = set_up_data_pipeline( teacher, replace, @@ -354,6 +385,8 @@ def train_in_dir( optimizer = tf.keras.optimizers.Nadam(learning_rate=train_params.learning_rate) loss_fn = tf.keras.losses.MeanSquaredError() + if model_is_quantized: + model = tfmot.quantization.keras.quantize_model(model) model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"]) logger.info(model.summary()) @@ -432,6 +465,7 @@ def train_in_dir( replace.shape_from_name[input_name], output_name, replace.shape_from_name[output_name], + model_is_quantized, ) output_filenames.append(checkpoint_filename) @@ -446,6 +480,7 @@ def save_as_tflite( input_shape: list, output_name: str, output_shape: list, + model_is_quantized: bool = False, ) -> None: """Save Keras model as TFLite file.""" @@ -464,7 +499,7 @@ def save_as_tflite( keras_model.input.set_shape(orig_shape) with fixed_input(keras_model, input_shape) as fixed_model: - converter = tf.lite.TFLiteConverter.from_keras_model(fixed_model) + converter = get_tflite_converter(fixed_model, quantized=model_is_quantized) tflite_model = converter.convert() with open(filename, "wb") as file: |