diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/utils/parallel.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/utils/parallel.py | 113 |
1 file changed, 113 insertions, 0 deletions
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Run a TFLite model over batched inputs in parallel worker processes.

`ParallelTFLiteModel` wraps the project `TFLiteModel`: with more than one
process it splits each incoming (global) batch into per-process chunks and
dispatches them to a `multiprocessing.Pool`; with a single process it falls
back to the plain serial `TFLiteModel` behaviour.
"""
import math
import os
from collections import defaultdict
from multiprocessing import Pool

import numpy as np
from psutil import cpu_count

# Must be set BEFORE tensorflow is imported — it silences TF's C++ logging.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel


class ParallelTFLiteModel(TFLiteModel):
    """A `TFLiteModel` that fans batched inference out across processes.

    In parallel mode (num_procs > 1) each call splits the global batch into
    `num_procs`-sized local chunks, runs them through a process pool of
    worker `TFLiteModel` instances, and concatenates the per-chunk outputs.
    """

    def __init__(self, filename, num_procs=1, num_threads=0, batch_size=None):
        """num_procs: 0 => detect real cores on system
        num_threads: 0 => TFLite impl. specific setting, usually 3
        batch_size: None => automatic (num_procs or file-determined)
        """
        self.pool = None  # set below only when running in parallel mode
        self.filename = filename
        if not num_procs:
            # psutil's physical-core count (hyperthread siblings excluded).
            self.num_procs = cpu_count(logical=False)
        else:
            self.num_procs = int(num_procs)

        self.num_threads = num_threads

        if self.num_procs > 1:
            if not batch_size:
                batch_size = self.num_procs  # default to min effective batch size
            # Each worker handles ceil(batch_size / num_procs) samples per call.
            local_batch_size = int(math.ceil(batch_size / self.num_procs))
            super().__init__(filename, batch_size=local_batch_size)
            # The in-process interpreter is unused in parallel mode (workers
            # hold their own); drop it to free its resources.
            del self.interpreter
            # Each worker process builds its own TFLiteModel via the
            # initializer (interpreters can't be pickled across processes).
            # NOTE(review): self.batch_size is assumed to be set by the
            # TFLiteModel superclass from local_batch_size — confirm there.
            self.pool = Pool(
                processes=self.num_procs,
                initializer=_pool_create_worker,
                initargs=[filename, self.batch_size, self.num_threads],
            )
        else:  # fall back to serial implementation for max performance
            super().__init__(
                filename, batch_size=batch_size, num_threads=self.num_threads
            )

        # Counters for the "partial batch" efficiency warning in __call__.
        self.total_batches = 0
        self.partial_batches = 0
        self.warned = False  # ensures the warning is printed at most once

    def close(self) -> None:
        """Shut down the worker pool, if one was created (idempotent)."""
        if self.pool:
            self.pool.close()
            # terminate() immediately after close() forcibly stops workers
            # rather than waiting for outstanding work.
            self.pool.terminate()

    def __del__(self) -> None:
        # Best-effort cleanup; safe because __init__ sets self.pool first.
        self.close()

    def __call__(self, named_input):
        """Run inference on a dict of name -> batched numpy array.

        In parallel mode the global batch is sliced into per-worker chunks,
        mapped over the pool, and the per-chunk output dicts are concatenated
        along the batch axis. Otherwise defers to the serial superclass.
        """
        if self.pool:
            # Global batch size is taken from the first input's leading axis;
            # all inputs are assumed to share it.
            global_batch_size = next(iter(named_input.values())).shape[0]
            # Note: self.batch_size comes from superclass and is local batch size
            chunks = int(math.ceil(global_batch_size / self.batch_size))
            self.total_batches += 1
            if chunks != self.num_procs:
                # Batch didn't split evenly across all workers.
                self.partial_batches += 1
            # Warn once if, after a warm-up of 10 calls, at least half the
            # batches under-utilise the pool.
            if (
                not self.warned
                and self.total_batches > 10
                and self.partial_batches / self.total_batches >= 0.5
            ):
                print(
                    "ParallelTFLiteModel(%s): warning - %.1f%% of batches do not use all %d processes, set batch size to a multiple of this"
                    % (
                        self.filename,
                        100 * self.partial_batches / self.total_batches,
                        self.num_procs,
                    )
                )
                self.warned = True

            # Slice every named input into per-worker sub-batches.
            local_batches = [
                {
                    key: values[i * self.batch_size : (i + 1) * self.batch_size]
                    for key, values in named_input.items()
                }
                for i in range(chunks)
            ]
            chunk_results = self.pool.map(_pool_run, local_batches)
            # Regroup per-chunk output dicts by output name, then concatenate
            # each name's chunks back into one global-batch array.
            named_ys = defaultdict(list)
            for chunk in chunk_results:
                for k, v in chunk.items():
                    named_ys[k].append(v)
            return {k: np.concatenate(v) for k, v in named_ys.items()}
        else:
            return super().__call__(named_input)


# Per-worker-process model instance, created once by _pool_create_worker and
# reused by every _pool_run call in that process.
_local_model = None


def _pool_create_worker(filename, local_batch_size=None, num_threads=None):
    """Pool initializer: build this worker process's own TFLiteModel."""
    global _local_model
    _local_model = TFLiteModel(
        filename, batch_size=local_batch_size, num_threads=num_threads
    )


def _pool_run(named_inputs):
    """Pool task: run one local sub-batch through this worker's model."""
    return _local_model(named_inputs)