diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit/record.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/record.py | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py index f85433d..7d9f219 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/record.py +++ b/src/mlia/nn/rewrite/core/graph_edit/record.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Save subgraph data.""" # pylint: disable=too-many-locals @@ -32,7 +32,7 @@ def dequantized_path(filename: str | Path) -> Path: return path -def record_model( +def record_model( # pylint: disable=too-many-arguments input_filename: str | Path, model_filename: str | Path, output_filename: str | Path, @@ -41,6 +41,7 @@ def record_model( num_procs: int = 1, num_threads: int = 0, dequantize_output: bool = False, + quantize_input: bool = False, ) -> None: """Model recorder. @@ -92,7 +93,10 @@ def record_model( for _, named_x in enumerate( track(dataset.as_numpy_iterator(), total=total, disable=not show_progress) ): - named_y = model(named_x) + if quantize_input: + named_y = model(model.quantize_inputs(named_x)) + else: + named_y = model(named_x) write(writer, named_y) if dequantize_output: |