# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
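"""Run a TFLite model over a tfrec dataset and record its outputs to a new tfrec file."""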
import math
import os

# Silence TensorFlow's C++ logging before it is imported.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

from tqdm import tqdm

from mlia.nn.rewrite.core.utils.numpy_tfrecord import (
    NumpyTFReader,
    NumpyTFWriter,
    TFLiteModel,
    numpytf_count,
)
from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel

def record_model(
    input_filename,
    model_filename,
    output_filename,
    batch_size=None,
    show_progress=False,
    num_procs=1,
    num_threads=0,
):
    """Run the given TFLite model over a tfrec dataset, recording the outputs.

    num_procs: 0 => detect the number of physical cores on the system
    num_threads: 0 => use the TFLite implementation's default, usually 3
    """
    model = ParallelTFLiteModel(model_filename, num_procs, num_threads, batch_size)
    if not batch_size:
        # Automatically batch to the minimum effective size if not specified.
        batch_size = model.num_procs * model.batch_size

    total = numpytf_count(input_filename)
    dataset = NumpyTFReader(input_filename)
    if batch_size > 1:
        # The tfrec files store batch-size-1 items: squeeze away the leading
        # axis, then re-batch into batch-size-n items.
        dataset = dataset.map(
            lambda d: {k: tf.squeeze(v, axis=0) for k, v in d.items()}
        )
        dataset = dataset.batch(batch_size, drop_remainder=False)
        total = int(math.ceil(total / batch_size))
    with NumpyTFWriter(output_filename) as writer:
        for named_x in tqdm(
            dataset.as_numpy_iterator(), total=total, disable=not show_progress
        ):
            named_y = model(named_x)
            if batch_size > 1:
                # Expand the batches and recreate each dict as a
                # batch-size-1 item for the tfrec output.
                for i in range(batch_size):
                    recreated_dict = {
                        k: v[i : i + 1]  # noqa: E203
                        for k, v in named_y.items()
                        if i < v.shape[0]  # the final batch may be short
                    }
                    if recreated_dict:
                        writer.write(recreated_dict)
            else:
                writer.write(named_y)
    model.close()
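

if __name__ == "__main__":
    # Minimal usage sketch with hypothetical file names (these are not part
    # of the module): run the model over every record in input.tfrec, letting
    # ParallelTFLiteModel detect the core count (num_procs=0), and write the
    # recorded outputs to output.tfrec.
    record_model(
        "input.tfrec",
        "model.tflite",
        "output.tfrec",
        show_progress=True,
        num_procs=0,
    )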