1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
|
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Parallelize a TFLiteModel."""
from __future__ import annotations
import logging
import math
import os
from collections import defaultdict
from multiprocessing import cpu_count
from multiprocessing import Pool
from pathlib import Path
from typing import Any
import numpy as np
import tensorflow as tf
from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel
# Ask TensorFlow's native backend to keep its log output to a minimum.
# NOTE(review): this is assigned after `import tensorflow`; confirm the
# variable is still honoured when set at this point.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# Restrict TensorFlow's Python-side (compat.v1) logging to errors only.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class ParallelTFLiteModel(TFLiteModel):
    """A parallel version of a TFLiteModel.

    When more than one process is requested, every call splits the incoming
    batch into chunks of the local (per-process) batch size, evaluates the
    chunks in a pool of worker processes and concatenates the results.
    With a single process the plain TFLiteModel implementation is used.

    num_procs: 0 => detect real cores on system
    num_threads: 0 => TFLite impl. specific setting, usually 3
    batch_size: None => automatic (num_procs or file-determined)
    """

    def __init__(
        self,
        filename: str | Path,
        num_procs: int = 1,
        num_threads: int = 0,
        batch_size: int | None = None,
    ) -> None:
        """Initiate a Parallel TFLite Model.

        Args:
            filename: Path of the TFLite model file; stored as a string.
            num_procs: Number of worker processes. A falsy value (0) selects
                ``multiprocessing.cpu_count()``.
            num_threads: Thread count forwarded to each TFLiteModel.
            batch_size: Global batch size. In parallel mode a falsy value
                defaults to ``num_procs`` and the value is divided (rounded
                up) across the processes to get the local batch size.
        """
        # Set first so that close()/__del__ are safe even if a later step
        # of the constructor raises.
        self.pool = None
        filename = str(filename)
        self.filename = filename
        if not num_procs:
            self.num_procs = cpu_count()
        else:
            self.num_procs = int(num_procs)

        self.num_threads = num_threads

        if self.num_procs > 1:
            if not batch_size:
                batch_size = self.num_procs  # default to min effective batch size
            local_batch_size = int(math.ceil(batch_size / self.num_procs))
            super().__init__(filename, batch_size=local_batch_size)
            # The parent process never runs inference itself in parallel
            # mode; each worker builds its own model in _pool_create_worker.
            del self.interpreter
            self.pool = Pool(  # pylint: disable=consider-using-with
                processes=self.num_procs,
                initializer=_pool_create_worker,
                initargs=[filename, self.batch_size, self.num_threads],
            )
        else:  # fall back to serial implementation for max performance
            super().__init__(
                filename, batch_size=batch_size, num_threads=self.num_threads
            )

        # Bookkeeping for the one-off "batches do not use all processes"
        # warning emitted from __call__.
        self.total_batches = 0
        self.partial_batches = 0
        self.warned = False

    def close(self) -> None:
        """Close and terminate pool."""
        if self.pool:
            self.pool.close()
            self.pool.terminate()

    def __del__(self) -> None:
        """Close instance."""
        self.close()

    def __call__(self, named_input: dict) -> Any:
        """Call instance.

        Args:
            named_input: Mapping of input name to batched values. Each value
                must expose ``shape[0]`` and support slicing along its first
                axis (presumably numpy arrays - confirm against callers).

        Returns:
            In parallel mode, a dict mapping each output name to the
            per-chunk results concatenated with ``np.concatenate``.
            Otherwise whatever the serial TFLiteModel call returns.
        """
        if self.pool:
            global_batch_size = next(iter(named_input.values())).shape[0]
            # Note: self.batch_size comes from superclass and is local batch size
            chunks = int(math.ceil(global_batch_size / self.batch_size))
            self.total_batches += 1
            if chunks != self.num_procs:
                self.partial_batches += 1
            # Warn once, after more than 10 calls, when at least half of the
            # batches did not split into exactly one chunk per process.
            if (
                not self.warned
                and self.total_batches > 10
                and self.partial_batches / self.total_batches >= 0.5
            ):
                # The value is a percentage, so the literal percent sign is
                # escaped as "%%" in the %-style logging format.
                logger.warning(
                    "ParallelTFLiteModel(%s): warning - %.1f%% of batches "
                    "do not use all %d processes, set batch size to "
                    "a multiple of this.",
                    self.filename,
                    100 * self.partial_batches / self.total_batches,
                    self.num_procs,
                )
                self.warned = True

            # One dict per chunk, each holding the matching slice of every
            # named input.
            local_batches = [
                {
                    key: values[
                        i * self.batch_size : (i + 1) * self.batch_size  # noqa: E203
                    ]
                    for key, values in named_input.items()
                }
                for i in range(chunks)
            ]
            chunk_results = self.pool.map(_pool_run, local_batches)
            # Regroup the per-chunk outputs by name, keeping chunk order.
            named_ys = defaultdict(list)
            for chunk in chunk_results:
                for key, value in chunk.items():
                    named_ys[key].append(value)
            return {key: np.concatenate(value) for key, value in named_ys.items()}

        return super().__call__(named_input)
# Per-process model instance: assigned by _pool_create_worker and read by
# _pool_run; stays None until _pool_create_worker has run in this process.
_LOCAL_MODEL = None
def _pool_create_worker(
    filename: str, local_batch_size: int = 0, num_threads: int = 0
) -> None:
    """Create the TFLiteModel owned by the current process.

    Passed as the ``initializer`` of the pool built by ParallelTFLiteModel;
    the resulting model is stored in the module-level ``_LOCAL_MODEL``.
    """
    global _LOCAL_MODEL  # pylint: disable=global-statement
    model_options = {
        "batch_size": local_batch_size,
        "num_threads": num_threads,
    }
    _LOCAL_MODEL = TFLiteModel(filename, **model_options)
def _pool_run(named_inputs: dict) -> Any:
    """Run the process-local model on one chunk of named inputs.

    Args:
        named_inputs: Mapping of input name to the values for this chunk,
            passed unchanged to the model.

    Returns:
        Whatever calling the process-local model returns.

    Raises:
        ValueError: If ``_LOCAL_MODEL`` has not been set in this process.
    """
    # Compare against None explicitly so that an initialised model is never
    # mistaken for a missing one because of its truth value.
    if _LOCAL_MODEL is None:
        raise ValueError("TFLiteModel is not initiated")
    return _LOCAL_MODEL(named_inputs)
|