aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/graph_edit/record.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit/record.py')
-rw-r--r--src/mlia/nn/rewrite/core/graph_edit/record.py10
1 files changed, 7 insertions, 3 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py
index f85433d..7d9f219 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/record.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/record.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Save subgraph data."""
# pylint: disable=too-many-locals
@@ -32,7 +32,7 @@ def dequantized_path(filename: str | Path) -> Path:
return path
-def record_model(
+def record_model( # pylint: disable=too-many-arguments
input_filename: str | Path,
model_filename: str | Path,
output_filename: str | Path,
@@ -41,6 +41,7 @@ def record_model(
num_procs: int = 1,
num_threads: int = 0,
dequantize_output: bool = False,
+ quantize_input: bool = False,
) -> None:
"""Model recorder.
@@ -92,7 +93,10 @@ def record_model(
for _, named_x in enumerate(
track(dataset.as_numpy_iterator(), total=total, disable=not show_progress)
):
- named_y = model(named_x)
+ if quantize_input:
+ named_y = model(model.quantize_inputs(named_x))
+ else:
+ named_y = model(named_x)
write(writer, named_y)
if dequantize_output: