aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/graph_edit/record.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit/record.py')
-rw-r--r--src/mlia/nn/rewrite/core/graph_edit/record.py51
1 files changed, 28 insertions, 23 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py
index ae13313..90f3db8 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/record.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/record.py
@@ -1,34 +1,39 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Save subgraph data."""
+# pylint: disable=too-many-locals
+from __future__ import annotations
+
import math
import os
+from pathlib import Path
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
+from rich.progress import track
-tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-
-from tqdm import tqdm
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import (
- NumpyTFReader,
- NumpyTFWriter,
- TFLiteModel,
- numpytf_count,
-)
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFWriter
from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
def record_model(
- input_filename,
- model_filename,
- output_filename,
- batch_size=None,
- show_progress=False,
- num_procs=1,
- num_threads=0,
-):
- """num_procs: 0 => detect real cores on system
- num_threads: 0 => TFLite impl. specific setting, usually 3"""
+ input_filename: str | Path,
+ model_filename: str | Path,
+ output_filename: str | Path,
+ batch_size: int = 0,
+ show_progress: bool = False,
+ num_procs: int = 1,
+ num_threads: int = 0,
+) -> None:
+ """Model recorder.
+
+ num_procs: 0 => detect real cores on system
+ num_threads: 0 => TFLite impl. specific setting, usually 3
+ """
model = ParallelTFLiteModel(model_filename, num_procs, num_threads, batch_size)
if not batch_size:
batch_size = (
@@ -36,10 +41,10 @@ def record_model(
) # automatically batch to the minimum effective size if not specified
total = numpytf_count(input_filename)
- dataset = NumpyTFReader(input_filename)
+ dataset = numpytf_read(input_filename)
if batch_size > 1:
- # Collapse batch-size 1 items into batch-size n. I regret using batch-size 1 items in tfrecs now.
+ # Collapse batch-size 1 items into batch-size n.
dataset = dataset.map(
lambda d: {k: tf.squeeze(v, axis=0) for k, v in d.items()}
)
@@ -48,7 +53,7 @@ def record_model(
with NumpyTFWriter(output_filename) as writer:
for _, named_x in enumerate(
- tqdm(dataset.as_numpy_iterator(), total=total, disable=not show_progress)
+ track(dataset.as_numpy_iterator(), total=total, disable=not show_progress)
):
named_y = model(named_x)
if batch_size > 1: