Diffstat (limited to 'src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py')
-rw-r--r--  src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py  140
1 file changed, 87 insertions(+), 53 deletions(-)
diff --git a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
index 2141003..9229810 100644
--- a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
+++ b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
@@ -1,26 +1,32 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Numpy TFRecord utils."""
+from __future__ import annotations
+
import json
import os
import random
import tempfile
from collections import defaultdict
+from pathlib import Path
+from typing import Any
+from typing import Callable
import numpy as np
+import tensorflow as tf
+from tensorflow.lite.python import interpreter as interpreter_wrapper
from mlia.nn.rewrite.core.utils.utils import load
from mlia.nn.rewrite.core.utils.utils import save
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-import tensorflow as tf
-
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-from tensorflow.lite.python import interpreter as interpreter_wrapper
-def make_decode_fn(filename):
-    def decode_fn(record_bytes, type_map):
+def make_decode_fn(filename: str) -> Callable:
+    """Make a decode function for the given TFRecord file."""
+
+    def decode_fn(record_bytes: Any, type_map: dict) -> dict:
parse_dict = {
name: tf.io.FixedLenFeature([], tf.string) for name in type_map.keys()
}
@@ -32,38 +38,48 @@ def make_decode_fn(filename):
return features
meta_filename = filename + ".meta"
- with open(meta_filename) as f:
- type_map = json.load(f)["type_map"]
+ with open(meta_filename, encoding="utf-8") as file:
+ type_map = json.load(file)["type_map"]
return lambda record_bytes: decode_fn(record_bytes, type_map)
-def NumpyTFReader(filename):
- decode_fn = make_decode_fn(filename)
- dataset = tf.data.TFRecordDataset(filename)
+def numpytf_read(filename: str | Path) -> Any:
+ """Read TFRecord dataset."""
+ decode_fn = make_decode_fn(str(filename))
+ dataset = tf.data.TFRecordDataset(str(filename))
return dataset.map(decode_fn)
-def numpytf_count(filename):
- meta_filename = filename + ".meta"
- with open(meta_filename) as f:
- return json.load(f)["count"]
+def numpytf_count(filename: str | Path) -> Any:
+ """Return count from TFRecord file."""
+ meta_filename = f"{filename}.meta"
+ with open(meta_filename, encoding="utf-8") as file:
+ return json.load(file)["count"]
class NumpyTFWriter:
- def __init__(self, filename):
+ """Numpy TF serializer."""
+
+ def __init__(self, filename: str | Path) -> None:
+        """Initialize a Numpy TF serializer."""
self.filename = filename
- self.meta_filename = filename + ".meta"
- self.writer = tf.io.TFRecordWriter(filename)
- self.type_map = {}
+ self.meta_filename = f"{filename}.meta"
+ self.writer = tf.io.TFRecordWriter(str(filename))
+ self.type_map: dict = {}
self.count = 0
- def __enter__(self):
+ def __enter__(self) -> Any:
+        """Enter the context manager."""
return self
- def __exit__(self, type, value, traceback):
+ def __exit__(
+ self, exception_type: Any, exception_value: Any, exception_traceback: Any
+ ) -> None:
+        """Close the writer on exit."""
self.close()
- def write(self, array_dict):
+ def write(self, array_dict: dict) -> None:
+ """Write array dict."""
type_map = {n: str(a.dtype.name) for n, a in array_dict.items()}
self.type_map.update(type_map)
self.count += 1
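
Stepping outside the hunk for a moment: the writer above and the renamed reader form a simple round trip. A minimal usage sketch, assuming the module is importable under the path shown in this diff; the file name, tensor name, and shape are illustrative:

import numpy as np

from mlia.nn.rewrite.core.utils.numpy_tfrecord import (
    NumpyTFWriter,
    numpytf_count,
    numpytf_read,
)

# Write four single-sample records; the writer collects dtype names and the
# record count and dumps both to the "calib.tfrec.meta" sidecar on close.
with NumpyTFWriter("calib.tfrec") as writer:
    for _ in range(4):
        writer.write({"input": np.random.rand(1, 28, 28, 1).astype(np.float32)})

assert numpytf_count("calib.tfrec") == 4
for record in numpytf_read("calib.tfrec"):
    print({name: tensor.shape for name, tensor in record.items()})
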
@@ -77,31 +93,41 @@ class NumpyTFWriter:
example = tf.train.Example(features=tf.train.Features(feature=feature))
self.writer.write(example.SerializeToString())
- def close(self):
- with open(self.meta_filename, "w") as f:
+ def close(self) -> None:
+ """Close NumpyTFWriter."""
+ with open(self.meta_filename, "w", encoding="utf-8") as file:
meta = {"type_map": self.type_map, "count": self.count}
- json.dump(meta, f)
+ json.dump(meta, file)
self.writer.close()
class TFLiteModel:
- def __init__(self, filename, batch_size=None, num_threads=None):
- if num_threads == 0:
+ """A representation of a TFLite Model."""
+
+ def __init__(
+ self,
+ filename: str,
+ batch_size: int | None = None,
+ num_threads: int | None = None,
+ ) -> None:
+        """Initialize a TFLite model."""
+ if not num_threads:
num_threads = None
- if batch_size == None:
+ if not batch_size:
self.interpreter = interpreter_wrapper.Interpreter(
model_path=filename, num_threads=num_threads
)
else: # if a batch size is specified, modify the TFLite model to use this size
with tempfile.TemporaryDirectory() as tmp:
- fb = load(filename)
- for sg in fb.subgraphs:
- for t in list(sg.inputs) + list(sg.outputs):
- sg.tensors[t].shape = np.array(
- [batch_size] + list(sg.tensors[t].shape[1:]), dtype=np.int32
+ flatbuffer = load(filename)
+ for subgraph in flatbuffer.subgraphs:
+ for tensor in list(subgraph.inputs) + list(subgraph.outputs):
+ subgraph.tensors[tensor].shape = np.array(
+ [batch_size] + list(subgraph.tensors[tensor].shape[1:]),
+ dtype=np.int32,
)
tempname = os.path.join(tmp, "rewrite_tmp.tflite")
- save(fb, tempname)
+ save(flatbuffer, tempname)
self.interpreter = interpreter_wrapper.Interpreter(
model_path=tempname, num_threads=num_threads
)
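
When a batch size is given, the branch above patches every input and output tensor shape in a temporary copy of the flatbuffer before the interpreter loads it. A hedged construction sketch; the model path and sizes are illustrative, not part of this change:

from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel

# "model.tflite" is a placeholder path. batch_size=8 rewrites each graph
# input/output shape to [8, ...] in a temporary .tflite file; num_threads=0
# and None both leave threading to the interpreter default.
model = TFLiteModel("model.tflite", batch_size=8, num_threads=4)
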
@@ -122,8 +148,9 @@ class TFLiteModel:
self.shape_from_name = {d["name"]: d["shape"] for d in details}
self.batch_size = next(iter(self.shape_from_name.values()))[0]
- def __call__(self, named_input):
- """Execute the model on one or a batch of named inputs (a dict of name: numpy array)"""
+ def __call__(self, named_input: dict) -> dict:
+ """Execute the model on one or a batch of named inputs \
+ (a dict of name: numpy array)."""
input_len = next(iter(named_input.values())).shape[0]
full_steps = input_len // self.batch_size
remainder = input_len % self.batch_size
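
These two quantities drive the loops in the next hunk: full batches run directly, and any remainder is zero-padded to one final batch whose padded rows are trimmed from the outputs. A worked example of the arithmetic, with illustrative sizes:

input_len, batch_size = 10, 4          # e.g. 10 samples, model rebatched to 4
full_steps = input_len // batch_size   # 2 full batches
remainder = input_len % batch_size     # 2 samples zero-padded into a last batch
assert full_steps * batch_size + remainder == input_len
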
@@ -131,39 +158,46 @@ class TFLiteModel:
named_ys = defaultdict(list)
for i in range(full_steps):
for name, x_batch in named_input.items():
- x = x_batch[i : i + self.batch_size]
- self.interpreter.set_tensor(self.handle_from_name[name], x)
+            x_tensor = x_batch[
+                i * self.batch_size : (i + 1) * self.batch_size  # noqa: E203
+            ]
+ self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
self.interpreter.invoke()
- for d in self.output_details:
- named_ys[d["name"]].append(self.interpreter.get_tensor(d["index"]))
+ for output_detail in self.output_details:
+ named_ys[output_detail["name"]].append(
+ self.interpreter.get_tensor(output_detail["index"])
+ )
if remainder:
for name, x_batch in named_input.items():
- x = np.zeros(self.shape_from_name[name]).astype(x_batch.dtype)
- x[:remainder] = x_batch[-remainder:]
- self.interpreter.set_tensor(self.handle_from_name[name], x)
+ x_tensor = np.zeros( # pylint: disable=invalid-name
+ self.shape_from_name[name]
+ ).astype(x_batch.dtype)
+ x_tensor[:remainder] = x_batch[-remainder:]
+ self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
self.interpreter.invoke()
- for d in self.output_details:
- named_ys[d["name"]].append(
- self.interpreter.get_tensor(d["index"])[:remainder]
+ for output_detail in self.output_details:
+ named_ys[output_detail["name"]].append(
+ self.interpreter.get_tensor(output_detail["index"])[:remainder]
)
return {k: np.concatenate(v) for k, v in named_ys.items()}
- def input_tensors(self):
+ def input_tensors(self) -> list:
+        """Return the names of the input tensors."""
return [d["name"] for d in self.input_details]
- def output_tensors(self):
+ def output_tensors(self) -> list:
+        """Return the names of the output tensors."""
return [d["name"] for d in self.output_details]
-def sample_tfrec(input_file, k, output_file):
+def sample_tfrec(input_file: str, k: int, output_file: str) -> None:
+    """Sample k records at random from input_file into output_file."""
total = numpytf_count(input_file)
- next = sorted(random.sample(range(total), k=k), reverse=True)
+ next_sample = sorted(random.sample(range(total), k=k), reverse=True)
- reader = NumpyTFReader(input_file)
+ reader = numpytf_read(input_file)
with NumpyTFWriter(output_file) as writer:
for i, data in enumerate(reader):
- if i == next[-1]:
- next.pop()
+ if i == next_sample[-1]:
+ next_sample.pop()
writer.write(data)
- if not next:
+ if not next_sample:
break
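
Because the k sampled indices are drawn without replacement, sorted in descending order, and popped from the tail as the reader streams forward, the input file is scanned at most once. A usage sketch; both paths are placeholders:

from mlia.nn.rewrite.core.utils.numpy_tfrecord import sample_tfrec

# Copy a random 100-record subset of the input (plus its .meta sidecar)
# into a new TFRecord in a single streaming pass.
sample_tfrec("full_dataset.tfrec", k=100, output_file="sample.tfrec")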