diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/utils')
-rw-r--r-- | src/mlia/nn/rewrite/core/utils/__init__.py | 1 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py | 140 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/utils/parallel.py | 90 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/utils/utils.py | 22 |
4 files changed, 160 insertions, 93 deletions
diff --git a/src/mlia/nn/rewrite/core/utils/__init__.py b/src/mlia/nn/rewrite/core/utils/__init__.py index 8c1f750..f0b5026 100644 --- a/src/mlia/nn/rewrite/core/utils/__init__.py +++ b/src/mlia/nn/rewrite/core/utils/__init__.py @@ -1,2 +1,3 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 +"""Rewrite core utils module.""" diff --git a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py index 2141003..9229810 100644 --- a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py +++ b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py @@ -1,26 +1,32 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 +"""Numpy TFRecord utils.""" +from __future__ import annotations + import json import os import random import tempfile from collections import defaultdict +from pathlib import Path +from typing import Any +from typing import Callable import numpy as np +import tensorflow as tf +from tensorflow.lite.python import interpreter as interpreter_wrapper from mlia.nn.rewrite.core.utils.utils import load from mlia.nn.rewrite.core.utils.utils import save os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" -import tensorflow as tf - tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) -from tensorflow.lite.python import interpreter as interpreter_wrapper +def make_decode_fn(filename: str) -> Callable: + """Make decode filename.""" -def make_decode_fn(filename): - def decode_fn(record_bytes, type_map): + def decode_fn(record_bytes: Any, type_map: dict) -> dict: parse_dict = { name: tf.io.FixedLenFeature([], tf.string) for name in type_map.keys() } @@ -32,38 +38,48 @@ def make_decode_fn(filename): return features meta_filename = filename + ".meta" - with open(meta_filename) as f: - type_map = json.load(f)["type_map"] + with open(meta_filename, encoding="utf-8") as file: + type_map = json.load(file)["type_map"] 
return lambda record_bytes: decode_fn(record_bytes, type_map) -def NumpyTFReader(filename): - decode_fn = make_decode_fn(filename) - dataset = tf.data.TFRecordDataset(filename) +def numpytf_read(filename: str | Path) -> Any: + """Read TFRecord dataset.""" + decode_fn = make_decode_fn(str(filename)) + dataset = tf.data.TFRecordDataset(str(filename)) return dataset.map(decode_fn) -def numpytf_count(filename): - meta_filename = filename + ".meta" - with open(meta_filename) as f: - return json.load(f)["count"] +def numpytf_count(filename: str | Path) -> Any: + """Return count from TFRecord file.""" + meta_filename = f"{filename}.meta" + with open(meta_filename, encoding="utf-8") as file: + return json.load(file)["count"] class NumpyTFWriter: - def __init__(self, filename): + """Numpy TF serializer.""" + + def __init__(self, filename: str | Path) -> None: + """Initiate a Numpy TF Serializer.""" self.filename = filename - self.meta_filename = filename + ".meta" - self.writer = tf.io.TFRecordWriter(filename) - self.type_map = {} + self.meta_filename = f"{filename}.meta" + self.writer = tf.io.TFRecordWriter(str(filename)) + self.type_map: dict = {} self.count = 0 - def __enter__(self): + def __enter__(self) -> Any: + """Enter instance.""" return self - def __exit__(self, type, value, traceback): + def __exit__( + self, exception_type: Any, exception_value: Any, exception_traceback: Any + ) -> None: + """Close instance.""" self.close() - def write(self, array_dict): + def write(self, array_dict: dict) -> None: + """Write array dict.""" type_map = {n: str(a.dtype.name) for n, a in array_dict.items()} self.type_map.update(type_map) self.count += 1 @@ -77,31 +93,41 @@ class NumpyTFWriter: example = tf.train.Example(features=tf.train.Features(feature=feature)) self.writer.write(example.SerializeToString()) - def close(self): - with open(self.meta_filename, "w") as f: + def close(self) -> None: + """Close NumpyTFWriter.""" + with open(self.meta_filename, "w", encoding="utf-8") 
as file: meta = {"type_map": self.type_map, "count": self.count} - json.dump(meta, f) + json.dump(meta, file) self.writer.close() class TFLiteModel: - def __init__(self, filename, batch_size=None, num_threads=None): - if num_threads == 0: + """A representation of a TFLite Model.""" + + def __init__( + self, + filename: str, + batch_size: int | None = None, + num_threads: int | None = None, + ) -> None: + """Initiate a TFLite Model.""" + if not num_threads: num_threads = None - if batch_size == None: + if not batch_size: self.interpreter = interpreter_wrapper.Interpreter( model_path=filename, num_threads=num_threads ) else: # if a batch size is specified, modify the TFLite model to use this size with tempfile.TemporaryDirectory() as tmp: - fb = load(filename) - for sg in fb.subgraphs: - for t in list(sg.inputs) + list(sg.outputs): - sg.tensors[t].shape = np.array( - [batch_size] + list(sg.tensors[t].shape[1:]), dtype=np.int32 + flatbuffer = load(filename) + for subgraph in flatbuffer.subgraphs: + for tensor in list(subgraph.inputs) + list(subgraph.outputs): + subgraph.tensors[tensor].shape = np.array( + [batch_size] + list(subgraph.tensors[tensor].shape[1:]), + dtype=np.int32, ) tempname = os.path.join(tmp, "rewrite_tmp.tflite") - save(fb, tempname) + save(flatbuffer, tempname) self.interpreter = interpreter_wrapper.Interpreter( model_path=tempname, num_threads=num_threads ) @@ -122,8 +148,9 @@ class TFLiteModel: self.shape_from_name = {d["name"]: d["shape"] for d in details} self.batch_size = next(iter(self.shape_from_name.values()))[0] - def __call__(self, named_input): - """Execute the model on one or a batch of named inputs (a dict of name: numpy array)""" + def __call__(self, named_input: dict) -> dict: + """Execute the model on one or a batch of named inputs \ + (a dict of name: numpy array).""" input_len = next(iter(named_input.values())).shape[0] full_steps = input_len // self.batch_size remainder = input_len % self.batch_size @@ -131,39 +158,46 @@ class 
TFLiteModel: named_ys = defaultdict(list) for i in range(full_steps): for name, x_batch in named_input.items(): - x = x_batch[i : i + self.batch_size] - self.interpreter.set_tensor(self.handle_from_name[name], x) + x_tensor = x_batch[i : i + self.batch_size] # noqa: E203 + self.interpreter.set_tensor(self.handle_from_name[name], x_tensor) self.interpreter.invoke() - for d in self.output_details: - named_ys[d["name"]].append(self.interpreter.get_tensor(d["index"])) + for output_detail in self.output_details: + named_ys[output_detail["name"]].append( + self.interpreter.get_tensor(output_detail["index"]) + ) if remainder: for name, x_batch in named_input.items(): - x = np.zeros(self.shape_from_name[name]).astype(x_batch.dtype) - x[:remainder] = x_batch[-remainder:] - self.interpreter.set_tensor(self.handle_from_name[name], x) + x_tensor = np.zeros( # pylint: disable=invalid-name + self.shape_from_name[name] + ).astype(x_batch.dtype) + x_tensor[:remainder] = x_batch[-remainder:] + self.interpreter.set_tensor(self.handle_from_name[name], x_tensor) self.interpreter.invoke() - for d in self.output_details: - named_ys[d["name"]].append( - self.interpreter.get_tensor(d["index"])[:remainder] + for output_detail in self.output_details: + named_ys[output_detail["name"]].append( + self.interpreter.get_tensor(output_detail["index"])[:remainder] ) return {k: np.concatenate(v) for k, v in named_ys.items()} - def input_tensors(self): + def input_tensors(self) -> list: + """Return name from input details.""" return [d["name"] for d in self.input_details] - def output_tensors(self): + def output_tensors(self) -> list: + """Return name from output details.""" return [d["name"] for d in self.output_details] -def sample_tfrec(input_file, k, output_file): +def sample_tfrec(input_file: str, k: int, output_file: str) -> None: + """Count, read and write TFRecord input and output data.""" total = numpytf_count(input_file) - next = sorted(random.sample(range(total), k=k), reverse=True) + 
next_sample = sorted(random.sample(range(total), k=k), reverse=True) - reader = NumpyTFReader(input_file) + reader = numpytf_read(input_file) with NumpyTFWriter(output_file) as writer: for i, data in enumerate(reader): - if i == next[-1]: - next.pop() + if i == next_sample[-1]: + next_sample.pop() writer.write(data) - if not next: + if not next_sample: break diff --git a/src/mlia/nn/rewrite/core/utils/parallel.py b/src/mlia/nn/rewrite/core/utils/parallel.py index b1a2914..d930a1e 100644 --- a/src/mlia/nn/rewrite/core/utils/parallel.py +++ b/src/mlia/nn/rewrite/core/utils/parallel.py @@ -1,28 +1,45 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 +"""Parallelize a TFLiteModel.""" +from __future__ import annotations + +import logging import math import os from collections import defaultdict from multiprocessing import cpu_count from multiprocessing import Pool +from pathlib import Path +from typing import Any import numpy as np - -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" import tensorflow as tf -tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) - from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +logger = logging.getLogger(__name__) + class ParallelTFLiteModel(TFLiteModel): - def __init__(self, filename, num_procs=1, num_threads=0, batch_size=None): - """num_procs: 0 => detect real cores on system - num_threads: 0 => TFLite impl. specific setting, usually 3 - batch_size: None => automatic (num_procs or file-determined) - """ + """A parallel version of a TFLiteModel. + + num_procs: 0 => detect real cores on system + num_threads: 0 => TFLite impl. 
specific setting, usually 3 + batch_size: None => automatic (num_procs or file-determined) + """ + + def __init__( + self, + filename: str | Path, + num_procs: int = 1, + num_threads: int = 0, + batch_size: int | None = None, + ) -> None: + """Initiate a Parallel TFLite Model.""" self.pool = None + filename = str(filename) self.filename = filename if not num_procs: self.num_procs = cpu_count() @@ -37,7 +54,7 @@ class ParallelTFLiteModel(TFLiteModel): local_batch_size = int(math.ceil(batch_size / self.num_procs)) super().__init__(filename, batch_size=local_batch_size) del self.interpreter - self.pool = Pool( + self.pool = Pool( # pylint: disable=consider-using-with processes=self.num_procs, initializer=_pool_create_worker, initargs=[filename, self.batch_size, self.num_threads], @@ -51,15 +68,18 @@ class ParallelTFLiteModel(TFLiteModel): self.partial_batches = 0 self.warned = False - def close(self): + def close(self) -> None: + """Close and terminate pool.""" if self.pool: self.pool.close() self.pool.terminate() - def __del__(self): + def __del__(self) -> None: + """Close instance.""" self.close() - def __call__(self, named_input): + def __call__(self, named_input: dict) -> Any: + """Call instance.""" if self.pool: global_batch_size = next(iter(named_input.values())).shape[0] # Note: self.batch_size comes from superclass and is local batch size @@ -72,19 +92,21 @@ class ParallelTFLiteModel(TFLiteModel): and self.total_batches > 10 and self.partial_batches / self.total_batches >= 0.5 ): - print( - "ParallelTFLiteModel(%s): warning - %.1f%% of batches do not use all %d processes, set batch size to a multiple of this" - % ( - self.filename, - 100 * self.partial_batches / self.total_batches, - self.num_procs, - ) + logger.warning( + "ParallelTFLiteModel(%s): warning - %.1f of batches " + "do not use all %d processes, set batch size to " + "a multiple of this.", + self.filename, + 100 * self.partial_batches / self.total_batches, + self.num_procs, ) self.warned = True 
local_batches = [ { - key: values[i * self.batch_size : (i + 1) * self.batch_size] + key: values[ + i * self.batch_size : (i + 1) * self.batch_size # noqa: E203 + ] for key, values in named_input.items() } for i in range(chunks) @@ -92,22 +114,26 @@ class ParallelTFLiteModel(TFLiteModel): chunk_results = self.pool.map(_pool_run, local_batches) named_ys = defaultdict(list) for chunk in chunk_results: - for k, v in chunk.items(): - named_ys[k].append(v) - return {k: np.concatenate(v) for k, v in named_ys.items()} - else: - return super().__call__(named_input) + for key, value in chunk.items(): + named_ys[key].append(value) + return {key: np.concatenate(value) for key, value in named_ys.items()} + + return super().__call__(named_input) -_local_model = None +_LOCAL_MODEL = None -def _pool_create_worker(filename, local_batch_size=None, num_threads=None): - global _local_model - _local_model = TFLiteModel( +def _pool_create_worker( + filename: str, local_batch_size: int = 0, num_threads: int = 0 +) -> None: + global _LOCAL_MODEL # pylint: disable=global-statement + _LOCAL_MODEL = TFLiteModel( filename, batch_size=local_batch_size, num_threads=num_threads ) -def _pool_run(named_inputs): - return _local_model(named_inputs) +def _pool_run(named_inputs: dict) -> Any: + if _LOCAL_MODEL: + return _LOCAL_MODEL(named_inputs) + raise ValueError("TFLiteModel is not initiated") diff --git a/src/mlia/nn/rewrite/core/utils/utils.py b/src/mlia/nn/rewrite/core/utils/utils.py index d1ed322..ddf0cc2 100644 --- a/src/mlia/nn/rewrite/core/utils/utils.py +++ b/src/mlia/nn/rewrite/core/utils/utils.py @@ -1,22 +1,28 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. 
# SPDX-License-Identifier: Apache-2.0 -import os +"""Model and file system utilites.""" +from __future__ import annotations + +from pathlib import Path import flatbuffers -from tensorflow.lite.python import schema_py_generated as schema_fb +from tensorflow.lite.python.schema_py_generated import Model +from tensorflow.lite.python.schema_py_generated import ModelT -def load(input_tflite_file): - if not os.path.exists(input_tflite_file): - raise FileNotFoundError("TFLite file not found at %r\n" % input_tflite_file) +def load(input_tflite_file: str | Path) -> ModelT: + """Load a flatbuffer model from file.""" + if not Path(input_tflite_file).exists(): + raise FileNotFoundError(f"TFLite file not found at {input_tflite_file}\n") with open(input_tflite_file, "rb") as file_handle: file_data = bytearray(file_handle.read()) - model_obj = schema_fb.Model.GetRootAsModel(file_data, 0) - model = schema_fb.ModelT.InitFromObj(model_obj) + model_obj = Model.GetRootAsModel(file_data, 0) + model = ModelT.InitFromObj(model_obj) return model -def save(model, output_tflite_file): +def save(model: ModelT, output_tflite_file: str | Path) -> None: + """Save a flatbuffer model to a given file.""" builder = flatbuffers.Builder(1024) # Initial size of the buffer, which # will grow automatically if needed model_offset = model.Pack(builder) |