Diffstat (limited to 'src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py | 138 |
1 file changed, 26 insertions, 112 deletions
diff --git a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
index 9229810..38ac1ed 100644
--- a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
+++ b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
@@ -6,55 +6,56 @@ from __future__ import annotations
 import json
 import os
 import random
-import tempfile
-from collections import defaultdict
+from functools import lru_cache
 from pathlib import Path
 from typing import Any
 from typing import Callable
 
-import numpy as np
 import tensorflow as tf
-from tensorflow.lite.python import interpreter as interpreter_wrapper
 
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
 
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
 
 
-def make_decode_fn(filename: str) -> Callable:
-    """Make decode filename."""
+def decode_fn(record_bytes: Any, type_map: dict) -> dict:
+    """Decode the given bytes into a name-tensor dict assuming the given type."""
+    parse_dict = {
+        name: tf.io.FixedLenFeature([], tf.string) for name in type_map.keys()
+    }
+    example = tf.io.parse_single_example(record_bytes, parse_dict)
+    features = {
+        n: tf.io.parse_tensor(example[n], tf.as_dtype(t)) for n, t in type_map.items()
+    }
+    return features
 
-    def decode_fn(record_bytes: Any, type_map: dict) -> dict:
-        parse_dict = {
-            name: tf.io.FixedLenFeature([], tf.string) for name in type_map.keys()
-        }
-        example = tf.io.parse_single_example(record_bytes, parse_dict)
-        features = {
-            n: tf.io.parse_tensor(example[n], tf.as_dtype(t))
-            for n, t in type_map.items()
-        }
-        return features
 
+def make_decode_fn(filename: str, model_filename: str | Path | None = None) -> Callable:
+    """Make decode filename."""
     meta_filename = filename + ".meta"
-    with open(meta_filename, encoding="utf-8") as file:
-        type_map = json.load(file)["type_map"]
+    try:
+        with open(meta_filename, encoding="utf-8") as file:
+            type_map = json.load(file)["type_map"]
+    except FileNotFoundError:
+        if not model_filename:
+            raise
+        # No .meta file: derive the type map from the model's input details.
+        interpreter = tf.lite.Interpreter(model_path=str(model_filename))
+        type_map = {
+            detail["name"]: detail["dtype"]
+            for detail in interpreter.get_input_details()
+        }
     return lambda record_bytes: decode_fn(record_bytes, type_map)
 
 
 def numpytf_read(filename: str | Path) -> Any:
     """Read TFRecord dataset."""
-    decode_fn = make_decode_fn(str(filename))
+    decode = make_decode_fn(str(filename))
     dataset = tf.data.TFRecordDataset(str(filename))
-    return dataset.map(decode_fn)
+    return dataset.map(decode)
 
 
-def numpytf_count(filename: str | Path) -> Any:
+@lru_cache
+def numpytf_count(filename: str | Path) -> int:
     """Return count from TFRecord file."""
     meta_filename = f"{filename}.meta"
-    with open(meta_filename, encoding="utf-8") as file:
-        return json.load(file)["count"]
+    try:
+        with open(meta_filename, encoding="utf-8") as file:
+            return int(json.load(file)["count"])
+    except FileNotFoundError:
+        raw_dataset = tf.data.TFRecordDataset(filename)
+        return sum(1 for _ in raw_dataset)
 
 
 class NumpyTFWriter:
@@ -101,93 +102,6 @@ class NumpyTFWriter:
         self.writer.close()
 
 
-class TFLiteModel:
-    """A representation of a TFLite Model."""
-
-    def __init__(
-        self,
-        filename: str,
-        batch_size: int | None = None,
-        num_threads: int | None = None,
-    ) -> None:
-        """Initiate a TFLite Model."""
-        if not num_threads:
-            num_threads = None
-        if not batch_size:
-            self.interpreter = interpreter_wrapper.Interpreter(
-                model_path=filename, num_threads=num_threads
-            )
-        else:  # if a batch size is specified, modify the TFLite model to use this size
-            with tempfile.TemporaryDirectory() as tmp:
-                flatbuffer = load(filename)
-                for subgraph in flatbuffer.subgraphs:
-                    for tensor in list(subgraph.inputs) + list(subgraph.outputs):
-                        subgraph.tensors[tensor].shape = np.array(
-                            [batch_size] + list(subgraph.tensors[tensor].shape[1:]),
-                            dtype=np.int32,
-                        )
-                tempname = os.path.join(tmp, "rewrite_tmp.tflite")
-                save(flatbuffer, tempname)
-                self.interpreter = interpreter_wrapper.Interpreter(
-                    model_path=tempname, num_threads=num_threads
-                )
-
-        try:
-            self.interpreter.allocate_tensors()
-        except RuntimeError:
-            self.interpreter = interpreter_wrapper.Interpreter(
-                model_path=filename, num_threads=num_threads
-            )
-            self.interpreter.allocate_tensors()
-
-        # Get input and output tensors.
-        self.input_details = self.interpreter.get_input_details()
-        self.output_details = self.interpreter.get_output_details()
-        details = list(self.input_details) + list(self.output_details)
-        self.handle_from_name = {d["name"]: d["index"] for d in details}
-        self.shape_from_name = {d["name"]: d["shape"] for d in details}
-        self.batch_size = next(iter(self.shape_from_name.values()))[0]
-
-    def __call__(self, named_input: dict) -> dict:
-        """Execute the model on one or a batch of named inputs \
-        (a dict of name: numpy array)."""
-        input_len = next(iter(named_input.values())).shape[0]
-        full_steps = input_len // self.batch_size
-        remainder = input_len % self.batch_size
-
-        named_ys = defaultdict(list)
-        for i in range(full_steps):
-            for name, x_batch in named_input.items():
-                x_tensor = x_batch[i : i + self.batch_size]  # noqa: E203
-                self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
-            self.interpreter.invoke()
-            for output_detail in self.output_details:
-                named_ys[output_detail["name"]].append(
-                    self.interpreter.get_tensor(output_detail["index"])
-                )
-        if remainder:
-            for name, x_batch in named_input.items():
-                x_tensor = np.zeros(  # pylint: disable=invalid-name
-                    self.shape_from_name[name]
-                ).astype(x_batch.dtype)
-                x_tensor[:remainder] = x_batch[-remainder:]
-                self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
-            self.interpreter.invoke()
-            for output_detail in self.output_details:
-                named_ys[output_detail["name"]].append(
-                    self.interpreter.get_tensor(output_detail["index"])[:remainder]
-                )
-        return {k: np.concatenate(v) for k, v in named_ys.items()}
-
-    def input_tensors(self) -> list:
-        """Return name from input details."""
-        return [d["name"] for d in self.input_details]
-
-    def output_tensors(self) -> list:
-        """Return name from output details."""
-        return [d["name"] for d in self.output_details]
-
-
 def sample_tfrec(input_file: str, k: int, output_file: str) -> None:
     """Count, read and write TFRecord input and output data."""
     total = numpytf_count(input_file)
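
For context, a short usage sketch of the APIs this change touches. The file names are hypothetical, and it assumes NumpyTFWriter is a context manager whose write() takes a dict of name to numpy array and which emits the ".meta" side-car on close; only its close() call is visible in this hunk.

from pathlib import Path

import numpy as np
import tensorflow as tf

from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFWriter
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read

recfile = Path("example.tfrec")  # hypothetical path

# Write two records, each holding a single named tensor; the writer is
# assumed to store the type map and count in example.tfrec.meta on close.
with NumpyTFWriter(recfile) as writer:
    for _ in range(2):
        writer.write({"input": np.zeros((1, 8), dtype=np.float32)})

# Fast path: the count comes straight from the JSON metadata.
assert numpytf_count(str(recfile)) == 2

# Read path: each record is decoded back into a name -> tensor dict using
# the type map stored in the metadata.
for record in numpytf_read(recfile):
    print({name: tensor.shape for name, tensor in record.items()})

# Fallback path added by this change: a TFRecord without a .meta side-car
# is counted by iterating the raw records instead of raising.
raw_file = Path("raw.tfrec")  # hypothetical path
with tf.io.TFRecordWriter(str(raw_file)) as raw_writer:
    for _ in range(3):
        raw_writer.write(tf.io.serialize_tensor(tf.zeros((1, 8))).numpy())
assert numpytf_count(str(raw_file)) == 3

One caveat of the new @lru_cache: numpytf_count memoises per exact argument, so an equivalent str and Path are cached as separate entries, and a cached count is not refreshed if the file grows later in the same process.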