Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit/record.py')
 src/mlia/nn/rewrite/core/graph_edit/record.py | 73 ++++++++++++++++++++++++++
 1 file changed, 73 insertions(+), 0 deletions(-)
diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py
new file mode 100644
index 0000000..03cd3f9
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/graph_edit/record.py
@@ -0,0 +1,73 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Record the outputs of a TFLite model over a dataset of input tensors."""
+import math
+import os
+
+# Suppress TensorFlow C++ logging; this must be set before tensorflow is imported.
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+from tqdm import tqdm
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import (
+ NumpyTFReader,
+ NumpyTFWriter,
+ numpytf_count,
+)
+from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
+
+
+def record_model(
+ input_filename,
+ model_filename,
+ output_filename,
+ batch_size=None,
+ show_progress=False,
+ num_procs=1,
+ num_threads=0,
+):
+ """num_procs: 0 => detect real cores on system
+ num_threads: 0 => TFLite impl. specific setting, usually 3"""
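+    # Distribute inference over num_procs worker processes, each running the model.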
+ model = ParallelTFLiteModel(model_filename, num_procs, num_threads, batch_size)
+    if not batch_size:
+        # Automatically batch to the minimum effective size if not specified.
+        batch_size = model.num_procs * model.batch_size
+
+ total = numpytf_count(input_filename)
+ dataset = NumpyTFReader(input_filename)
+ writer = NumpyTFWriter(output_filename)
+
+ if batch_size > 1:
+        # The tfrec files store one item per record (batch size 1), so collapse
+        # them into batches of size n for inference.
+ dataset = dataset.map(
+ lambda d: {k: tf.squeeze(v, axis=0) for k, v in d.items()}
+ )
+ dataset = dataset.batch(batch_size, drop_remainder=False)
+ total = int(math.ceil(total / batch_size))
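+        # Example: with batch_size=4 and 10 records, tf.squeeze turns each
+        # {k: (1, ...)} item into {k: (...)}, and batching then yields batches
+        # of shapes (4, ...), (4, ...) and (2, ...), i.e. ceil(10/4) = 3 batches.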
+
+    for named_x in tqdm(
+        dataset.as_numpy_iterator(), total=total, disable=not show_progress
+    ):
+        named_y = model(named_x)
+        if batch_size > 1:
+            # Split each output batch back into batch-size 1 items for the tfrec
+            # output; the final batch may be short (drop_remainder=False), so
+            # skip indices past the end of it.
+            for i in range(batch_size):
+                item = {k: v[i : i + 1] for k, v in named_y.items() if i < v.shape[0]}
+                if item:
+                    writer.write(item)
+        else:
+            writer.write(named_y)
+    model.close()
+    writer.close()  # flush and close the output tfrec file
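
A minimal usage sketch for the new entry point (file names are hypothetical, and it
assumes the input tfrec stores one named tensor dict per record, keyed to match the
model's input names):

    from mlia.nn.rewrite.core.graph_edit.record import record_model

    record_model(
        input_filename="inputs.tfrec",    # hypothetical: recorded input tensors
        model_filename="model.tflite",    # hypothetical: TFLite model to run
        output_filename="outputs.tfrec",  # outputs, written one item per record
        batch_size=None,                  # None => num_procs * model batch size
        show_progress=True,               # display a tqdm progress bar
        num_procs=0,                      # 0 => detect real cores on the system
        num_threads=0,                    # 0 => TFLite default, usually 3
    )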