aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/graph_edit/record.py
diff options
context:
space:
mode:
authorNathan Bailey <nathan.bailey@arm.com>2024-04-08 15:42:11 +0100
committerNathan Bailey <nathan.bailey@arm.com>2024-05-07 11:48:11 +0000
commit0999ba0f4381ce1e2e8b06a932bfe693692223e2 (patch)
treecdc1061b885f9bfa7548b5c9bc569750a7ba957c /src/mlia/nn/rewrite/core/graph_edit/record.py
parent7565b9a5460ddc7d74eb74dfb6d0c4264f99ebaf (diff)
downloadmlia-0999ba0f4381ce1e2e8b06a932bfe693692223e2.tar.gz
fix: Fixes MAE discrepancies of rewrites
- Adds model.evaluate stage to retrieve correct MAE - Sets augmentations to none by default - Enables MAE calculations using dequantized data (if needed) Resolves: MLIA-972 Signed-off-by: Nathan Bailey <nathan.bailey@arm.com> Change-Id: Ia838b2d86caf7b10ad6e4d87bf6aa9e27c80bb72
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit/record.py')
-rw-r--r--src/mlia/nn/rewrite/core/graph_edit/record.py10
1 files changed, 7 insertions, 3 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py
index f85433d..7d9f219 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/record.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/record.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Save subgraph data."""
# pylint: disable=too-many-locals
@@ -32,7 +32,7 @@ def dequantized_path(filename: str | Path) -> Path:
return path
-def record_model(
+def record_model( # pylint: disable=too-many-arguments
input_filename: str | Path,
model_filename: str | Path,
output_filename: str | Path,
@@ -41,6 +41,7 @@ def record_model(
num_procs: int = 1,
num_threads: int = 0,
dequantize_output: bool = False,
+ quantize_input: bool = False,
) -> None:
"""Model recorder.
@@ -92,7 +93,10 @@ def record_model(
for _, named_x in enumerate(
track(dataset.as_numpy_iterator(), total=total, disable=not show_progress)
):
- named_y = model(named_x)
+ if quantize_input:
+ named_y = model(model.quantize_inputs(named_x))
+ else:
+ named_y = model(named_x)
write(writer, named_y)
if dequantize_output: