diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit/record.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/record.py | 45 |
1 files changed, 38 insertions, 7 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py index 90f3db8..f85433d 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/record.py +++ b/src/mlia/nn/rewrite/core/graph_edit/record.py @@ -6,6 +6,7 @@ from __future__ import annotations import math import os +from contextlib import ExitStack from pathlib import Path import tensorflow as tf @@ -15,11 +16,22 @@ from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFWriter from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel +from mlia.nn.tensorflow.config import NameToTensorMap os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +DEQUANT_SUFFIX = "_dequant" + + +def dequantized_path(filename: str | Path) -> Path: + """Append the de-quantization suffix to the given filename.""" + path = Path(filename) + path = Path(path.parent, f"{path.stem}{DEQUANT_SUFFIX}{path.suffix}") + return path + + def record_model( input_filename: str | Path, model_filename: str | Path, @@ -28,11 +40,14 @@ def record_model( show_progress: bool = False, num_procs: int = 1, num_threads: int = 0, + dequantize_output: bool = False, ) -> None: """Model recorder. num_procs: 0 => detect real cores on system num_threads: 0 => TFLite impl. specific setting, usually 3 + + dequantize: True => de-quantize the recorded output before saving """ model = ParallelTFLiteModel(model_filename, num_procs, num_threads, batch_size) if not batch_size: @@ -51,22 +66,38 @@ def record_model( dataset = dataset.batch(batch_size, drop_remainder=False) total = int(math.ceil(total / batch_size)) - with NumpyTFWriter(output_filename) as writer: - for _, named_x in enumerate( - track(dataset.as_numpy_iterator(), total=total, disable=not show_progress) - ): - named_y = model(named_x) + with ExitStack() as stack: + writer = stack.enter_context(NumpyTFWriter(output_filename)) + writer_dequant = None + if dequantize_output: + dequant_path = dequantized_path(output_filename) + writer_dequant = stack.enter_context(NumpyTFWriter(dequant_path)) + + def write(writer: NumpyTFWriter, data: NameToTensorMap) -> None: + """Write the data using the given NumpyTFWriter instance.""" if batch_size > 1: for i in range(batch_size): # Expand the batches and recreate each dict as a # batch-size 1 item for the tfrec output recreated_dict = { k: v[i : i + 1] # noqa: E203 - for k, v in named_y.items() + for k, v in data.items() if i < v.shape[0] } if recreated_dict: writer.write(recreated_dict) else: - writer.write(named_y) + writer.write(data) + + for _, named_x in enumerate( + track(dataset.as_numpy_iterator(), total=total, disable=not show_progress) + ): + named_y = model(named_x) + write(writer, named_y) + + if dequantize_output: + assert writer_dequant + named_y_dequant = model.dequantize_outputs(named_y) + write(writer_dequant, named_y_dequant) + model.close() |