about summary refs log tree commit diff
path: root/src/mlia/nn/rewrite/core/utils/parallel.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/utils/parallel.py')
-rw-r--r--  src/mlia/nn/rewrite/core/utils/parallel.py  90
1 file changed, 58 insertions, 32 deletions
diff --git a/src/mlia/nn/rewrite/core/utils/parallel.py b/src/mlia/nn/rewrite/core/utils/parallel.py
index b1a2914..d930a1e 100644
--- a/src/mlia/nn/rewrite/core/utils/parallel.py
+++ b/src/mlia/nn/rewrite/core/utils/parallel.py
@@ -1,28 +1,45 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Parallelize a TFLiteModel."""
+from __future__ import annotations
+
+import logging
import math
import os
from collections import defaultdict
from multiprocessing import cpu_count
from multiprocessing import Pool
+from pathlib import Path
+from typing import Any
import numpy as np
-
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
-tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-
from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+logger = logging.getLogger(__name__)
+
class ParallelTFLiteModel(TFLiteModel):
- def __init__(self, filename, num_procs=1, num_threads=0, batch_size=None):
- """num_procs: 0 => detect real cores on system
- num_threads: 0 => TFLite impl. specific setting, usually 3
- batch_size: None => automatic (num_procs or file-determined)
- """
+ """A parallel version of a TFLiteModel.
+
+ num_procs: 0 => detect real cores on system
+ num_threads: 0 => TFLite impl. specific setting, usually 3
+ batch_size: None => automatic (num_procs or file-determined)
+ """
+
+ def __init__(
+ self,
+ filename: str | Path,
+ num_procs: int = 1,
+ num_threads: int = 0,
+ batch_size: int | None = None,
+ ) -> None:
+ """Initiate a Parallel TFLite Model."""
self.pool = None
+ filename = str(filename)
self.filename = filename
if not num_procs:
self.num_procs = cpu_count()
@@ -37,7 +54,7 @@ class ParallelTFLiteModel(TFLiteModel):
local_batch_size = int(math.ceil(batch_size / self.num_procs))
super().__init__(filename, batch_size=local_batch_size)
del self.interpreter
- self.pool = Pool(
+ self.pool = Pool( # pylint: disable=consider-using-with
processes=self.num_procs,
initializer=_pool_create_worker,
initargs=[filename, self.batch_size, self.num_threads],
@@ -51,15 +68,18 @@ class ParallelTFLiteModel(TFLiteModel):
self.partial_batches = 0
self.warned = False
- def close(self):
+ def close(self) -> None:
+ """Close and terminate pool."""
if self.pool:
self.pool.close()
self.pool.terminate()
- def __del__(self):
+ def __del__(self) -> None:
+ """Close instance."""
self.close()
- def __call__(self, named_input):
+ def __call__(self, named_input: dict) -> Any:
+ """Call instance."""
if self.pool:
global_batch_size = next(iter(named_input.values())).shape[0]
# Note: self.batch_size comes from superclass and is local batch size
@@ -72,19 +92,21 @@ class ParallelTFLiteModel(TFLiteModel):
and self.total_batches > 10
and self.partial_batches / self.total_batches >= 0.5
):
- print(
- "ParallelTFLiteModel(%s): warning - %.1f%% of batches do not use all %d processes, set batch size to a multiple of this"
- % (
- self.filename,
- 100 * self.partial_batches / self.total_batches,
- self.num_procs,
- )
+ logger.warning(
+            "ParallelTFLiteModel(%s): warning - %.1f%% of batches "
+ "do not use all %d processes, set batch size to "
+ "a multiple of this.",
+ self.filename,
+ 100 * self.partial_batches / self.total_batches,
+ self.num_procs,
)
self.warned = True
local_batches = [
{
- key: values[i * self.batch_size : (i + 1) * self.batch_size]
+ key: values[
+ i * self.batch_size : (i + 1) * self.batch_size # noqa: E203
+ ]
for key, values in named_input.items()
}
for i in range(chunks)
@@ -92,22 +114,26 @@ class ParallelTFLiteModel(TFLiteModel):
chunk_results = self.pool.map(_pool_run, local_batches)
named_ys = defaultdict(list)
for chunk in chunk_results:
- for k, v in chunk.items():
- named_ys[k].append(v)
- return {k: np.concatenate(v) for k, v in named_ys.items()}
- else:
- return super().__call__(named_input)
+ for key, value in chunk.items():
+ named_ys[key].append(value)
+ return {key: np.concatenate(value) for key, value in named_ys.items()}
+
+ return super().__call__(named_input)
-_local_model = None
+_LOCAL_MODEL = None
-def _pool_create_worker(filename, local_batch_size=None, num_threads=None):
- global _local_model
- _local_model = TFLiteModel(
+def _pool_create_worker(
+ filename: str, local_batch_size: int = 0, num_threads: int = 0
+) -> None:
+ global _LOCAL_MODEL # pylint: disable=global-statement
+ _LOCAL_MODEL = TFLiteModel(
filename, batch_size=local_batch_size, num_threads=num_threads
)
-def _pool_run(named_inputs):
- return _local_model(named_inputs)
+def _pool_run(named_inputs: dict) -> Any:
+ if _LOCAL_MODEL:
+ return _LOCAL_MODEL(named_inputs)
+ raise ValueError("TFLiteModel is not initiated")