about summary refs log tree commit diff
path: root/src/mlia/nn/rewrite/core/graph_edit
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit')
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/cut.py     16
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/record.py  45
2 files changed, 47 insertions, 14 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/cut.py b/src/mlia/nn/rewrite/core/graph_edit/cut.py
index 13a5268..53d5389 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/cut.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/cut.py
@@ -1,9 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Cut module."""
+from __future__ import annotations
+
import os
from collections import defaultdict
-from typing import Optional
+from pathlib import Path
import tensorflow as tf
from tensorflow.lite.python.schema_py_generated import ModelT
@@ -25,8 +27,8 @@ def tensors_by_name(subgraph: SubGraphT, names: list) -> list:
def cut_subgraph(
subgraph: SubGraphT,
- input_tensor_names: Optional[list],
- output_tensor_names: Optional[list],
+ input_tensor_names: list | None,
+ output_tensor_names: list | None,
) -> None:
"""Change the global inputs and outputs of a graph to the provided named tensors."""
if input_tensor_names is not None:
@@ -131,11 +133,11 @@ def filter_relabel(src_subgraph: SubGraphT, relabel_filter: set) -> tuple:
def cut_model(
- model_file: str,
- input_names: Optional[list],
- output_names: Optional[list],
+ model_file: str | Path,
+ input_names: list | None,
+ output_names: list | None,
subgraph_index: int,
- output_file: str,
+ output_file: str | Path,
) -> None:
"""Cut subgraphs and simplify a given model."""
model = load_fb(model_file)
diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py
index 90f3db8..f85433d 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/record.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/record.py
@@ -6,6 +6,7 @@ from __future__ import annotations
import math
import os
+from contextlib import ExitStack
from pathlib import Path
import tensorflow as tf
@@ -15,11 +16,22 @@ from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFWriter
from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
+from mlia.nn.tensorflow.config import NameToTensorMap
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+DEQUANT_SUFFIX = "_dequant"
+
+
+def dequantized_path(filename: str | Path) -> Path:
+ """Append the de-quantization suffix to the given filename."""
+ path = Path(filename)
+ path = Path(path.parent, f"{path.stem}{DEQUANT_SUFFIX}{path.suffix}")
+ return path
+
+
def record_model(
input_filename: str | Path,
model_filename: str | Path,
@@ -28,11 +40,14 @@ def record_model(
show_progress: bool = False,
num_procs: int = 1,
num_threads: int = 0,
+ dequantize_output: bool = False,
) -> None:
"""Model recorder.
num_procs: 0 => detect real cores on system
num_threads: 0 => TFLite impl. specific setting, usually 3
+
+    dequantize_output: True => de-quantize the recorded output before saving
"""
model = ParallelTFLiteModel(model_filename, num_procs, num_threads, batch_size)
if not batch_size:
@@ -51,22 +66,38 @@ def record_model(
dataset = dataset.batch(batch_size, drop_remainder=False)
total = int(math.ceil(total / batch_size))
- with NumpyTFWriter(output_filename) as writer:
- for _, named_x in enumerate(
- track(dataset.as_numpy_iterator(), total=total, disable=not show_progress)
- ):
- named_y = model(named_x)
+ with ExitStack() as stack:
+ writer = stack.enter_context(NumpyTFWriter(output_filename))
+ writer_dequant = None
+ if dequantize_output:
+ dequant_path = dequantized_path(output_filename)
+ writer_dequant = stack.enter_context(NumpyTFWriter(dequant_path))
+
+ def write(writer: NumpyTFWriter, data: NameToTensorMap) -> None:
+ """Write the data using the given NumpyTFWriter instance."""
if batch_size > 1:
for i in range(batch_size):
# Expand the batches and recreate each dict as a
# batch-size 1 item for the tfrec output
recreated_dict = {
k: v[i : i + 1] # noqa: E203
- for k, v in named_y.items()
+ for k, v in data.items()
if i < v.shape[0]
}
if recreated_dict:
writer.write(recreated_dict)
else:
- writer.write(named_y)
+ writer.write(data)
+
+ for _, named_x in enumerate(
+ track(dataset.as_numpy_iterator(), total=total, disable=not show_progress)
+ ):
+ named_y = model(named_x)
+ write(writer, named_y)
+
+ if dequantize_output:
+ assert writer_dequant
+ named_y_dequant = model.dequantize_outputs(named_y)
+ write(writer_dequant, named_y_dequant)
+
model.close()