# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Utilities for running a TFLite model across multiple processes."""
import math
import os
from collections import defaultdict
from multiprocessing import Pool

import numpy as np
from psutil import cpu_count

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel
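
# Design note: each worker process loads its own private copy of the model
# (see _pool_create_worker below); __call__ slices the incoming global batch
# into per-process chunks, maps them over the pool with _pool_run and
# concatenates the outputs per output name.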


class ParallelTFLiteModel(TFLiteModel):
    """Run a TFLite model, splitting each batch across multiple processes."""

    def __init__(self, filename, num_procs=1, num_threads=0, batch_size=None):
        """Load the model and, for num_procs > 1, start a pool of workers.

        num_procs: 0 => detect the number of physical cores on the system
        num_threads: 0 => TFLite implementation-specific setting, usually 3
        batch_size: None => automatic (num_procs or file-determined)
        """
        self.pool = None
        self.filename = filename
        if not num_procs:
            self.num_procs = cpu_count(logical=False)
        else:
            self.num_procs = int(num_procs)

        self.num_threads = num_threads

        if self.num_procs > 1:
            if not batch_size:
                batch_size = self.num_procs  # default to minimum effective batch size
            local_batch_size = int(math.ceil(batch_size / self.num_procs))
            super().__init__(filename, batch_size=local_batch_size)
            # The interpreter created by the superclass is only needed in the
            # serial path; in the pooled path each worker creates its own.
            del self.interpreter
            self.pool = Pool(
                processes=self.num_procs,
                initializer=_pool_create_worker,
                initargs=[filename, self.batch_size, self.num_threads],
            )
        else:  # num_procs == 1: use the serial implementation directly,
            # avoiding multiprocessing overhead
            super().__init__(
                filename, batch_size=batch_size, num_threads=self.num_threads
            )

        self.total_batches = 0
        self.partial_batches = 0
        self.warned = False

    def close(self):
        """Shut down the worker pool, if one was created."""
        if self.pool:
            self.pool.close()
            self.pool.terminate()

    def __del__(self):
        self.close()

    def __call__(self, named_input):
        """Run inference, fanning the batch out across the pool if present."""
        if self.pool:
            global_batch_size = next(iter(named_input.values())).shape[0]
            # self.batch_size comes from the superclass and is the local
            # (per-process) batch size.
            chunks = int(math.ceil(global_batch_size / self.batch_size))
            self.total_batches += 1
            if chunks != self.num_procs:
                self.partial_batches += 1
            if (
                not self.warned
                and self.total_batches > 10
                and self.partial_batches / self.total_batches >= 0.5
            ):
                print(
                    f"ParallelTFLiteModel({self.filename}): warning - "
                    f"{100 * self.partial_batches / self.total_batches:.1f}% "
                    f"of batches do not use all {self.num_procs} processes; "
                    "set the batch size to a multiple of the process count"
                )
                self.warned = True

            # Slice the global batch into per-process local batches.
            local_batches = [
                {
                    key: values[i * self.batch_size : (i + 1) * self.batch_size]
                    for key, values in named_input.items()
                }
                for i in range(chunks)
            ]
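            # Worked example (illustrative numbers): with num_procs=4 and
            # batch_size=16, the local batch size is ceil(16 / 4) = 4. A
            # global batch of 10 rows then yields ceil(10 / 4) = 3 chunks of
            # 4, 4 and 2 rows, so only 3 of the 4 workers receive data and
            # the batch counts as partial.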
            chunk_results = self.pool.map(_pool_run, local_batches)
            # Reassemble the per-chunk outputs into one array per output name.
            named_ys = defaultdict(list)
            for chunk in chunk_results:
                for key, value in chunk.items():
                    named_ys[key].append(value)
            return {key: np.concatenate(values) for key, values in named_ys.items()}
        # Serial path: no worker pool was created.
        return super().__call__(named_input)


# Per-process model instance, created by _pool_create_worker in each worker.
_local_model = None


def _pool_create_worker(filename, local_batch_size=None, num_threads=None):
    """Pool initializer: load a private copy of the model in this worker."""
    global _local_model
    _local_model = TFLiteModel(
        filename, batch_size=local_batch_size, num_threads=num_threads
    )


def _pool_run(named_inputs):
    """Run this worker's model on one chunk of the global batch."""
    return _local_model(named_inputs)
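

if __name__ == "__main__":
    # Hedged usage sketch, not part of the module API: "model.tflite" and the
    # input name "input" are illustrative placeholders whose real values
    # depend on the TFLite file being loaded, and the input shape must match
    # that model. num_procs=0 auto-detects the physical core count.
    parallel_model = ParallelTFLiteModel(
        "model.tflite", num_procs=0, batch_size=32
    )
    results = parallel_model(
        {"input": np.zeros((32, 224, 224, 3), dtype=np.float32)}
    )
    print({name: value.shape for name, value in results.items()})
    parallel_model.close()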