diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/utils/parallel.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/utils/parallel.py | 90 |
1 files changed, 58 insertions, 32 deletions
diff --git a/src/mlia/nn/rewrite/core/utils/parallel.py b/src/mlia/nn/rewrite/core/utils/parallel.py index b1a2914..d930a1e 100644 --- a/src/mlia/nn/rewrite/core/utils/parallel.py +++ b/src/mlia/nn/rewrite/core/utils/parallel.py @@ -1,28 +1,45 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 +"""Parallelize a TFLiteModel.""" +from __future__ import annotations + +import logging import math import os from collections import defaultdict from multiprocessing import cpu_count from multiprocessing import Pool +from pathlib import Path +from typing import Any import numpy as np - -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" import tensorflow as tf -tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) - from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +logger = logging.getLogger(__name__) + class ParallelTFLiteModel(TFLiteModel): - def __init__(self, filename, num_procs=1, num_threads=0, batch_size=None): - """num_procs: 0 => detect real cores on system - num_threads: 0 => TFLite impl. specific setting, usually 3 - batch_size: None => automatic (num_procs or file-determined) - """ + """A parallel version of a TFLiteModel. + + num_procs: 0 => detect real cores on system + num_threads: 0 => TFLite impl. 
specific setting, usually 3 + batch_size: None => automatic (num_procs or file-determined) + """ + + def __init__( + self, + filename: str | Path, + num_procs: int = 1, + num_threads: int = 0, + batch_size: int | None = None, + ) -> None: + """Initiate a Parallel TFLite Model.""" self.pool = None + filename = str(filename) self.filename = filename if not num_procs: self.num_procs = cpu_count() @@ -37,7 +54,7 @@ class ParallelTFLiteModel(TFLiteModel): local_batch_size = int(math.ceil(batch_size / self.num_procs)) super().__init__(filename, batch_size=local_batch_size) del self.interpreter - self.pool = Pool( + self.pool = Pool( # pylint: disable=consider-using-with processes=self.num_procs, initializer=_pool_create_worker, initargs=[filename, self.batch_size, self.num_threads], @@ -51,15 +68,18 @@ class ParallelTFLiteModel(TFLiteModel): self.partial_batches = 0 self.warned = False - def close(self): + def close(self) -> None: + """Close and terminate pool.""" if self.pool: self.pool.close() self.pool.terminate() - def __del__(self): + def __del__(self) -> None: + """Close instance.""" self.close() - def __call__(self, named_input): + def __call__(self, named_input: dict) -> Any: + """Call instance.""" if self.pool: global_batch_size = next(iter(named_input.values())).shape[0] # Note: self.batch_size comes from superclass and is local batch size @@ -72,19 +92,21 @@ class ParallelTFLiteModel(TFLiteModel): and self.total_batches > 10 and self.partial_batches / self.total_batches >= 0.5 ): - print( - "ParallelTFLiteModel(%s): warning - %.1f%% of batches do not use all %d processes, set batch size to a multiple of this" - % ( - self.filename, - 100 * self.partial_batches / self.total_batches, - self.num_procs, - ) + logger.warning( + "ParallelTFLiteModel(%s): warning - %.1f%% of batches " + "do not use all %d processes, set batch size to " + "a multiple of this.", + self.filename, + 100 * self.partial_batches / self.total_batches, + self.num_procs, ) + self.warned = True 
local_batches = [ { - key: values[i * self.batch_size : (i + 1) * self.batch_size] + key: values[ + i * self.batch_size : (i + 1) * self.batch_size # noqa: E203 + ] for key, values in named_input.items() } for i in range(chunks) @@ -92,22 +114,26 @@ class ParallelTFLiteModel(TFLiteModel): chunk_results = self.pool.map(_pool_run, local_batches) named_ys = defaultdict(list) for chunk in chunk_results: - for k, v in chunk.items(): - named_ys[k].append(v) - return {k: np.concatenate(v) for k, v in named_ys.items()} - else: - return super().__call__(named_input) + for key, value in chunk.items(): + named_ys[key].append(value) + return {key: np.concatenate(value) for key, value in named_ys.items()} + + return super().__call__(named_input) -_local_model = None +_LOCAL_MODEL = None -def _pool_create_worker(filename, local_batch_size=None, num_threads=None): - global _local_model - _local_model = TFLiteModel( +def _pool_create_worker( + filename: str, local_batch_size: int = 0, num_threads: int = 0 +) -> None: + global _LOCAL_MODEL # pylint: disable=global-statement + _LOCAL_MODEL = TFLiteModel( filename, batch_size=local_batch_size, num_threads=num_threads ) -def _pool_run(named_inputs): - return _local_model(named_inputs) +def _pool_run(named_inputs: dict) -> Any: + if _LOCAL_MODEL: + return _LOCAL_MODEL(named_inputs) + raise ValueError("TFLiteModel is not initiated") |