aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py')
-rw-r--r--src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py172
1 files changed, 172 insertions, 0 deletions
diff --git a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
new file mode 100644
index 0000000..ac3e875
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
@@ -0,0 +1,172 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+import json
+import os
+import random
+import tempfile
+from collections import defaultdict
+
+import numpy as np
+
+from mlia.nn.rewrite.core.utils.utils import load
+from mlia.nn.rewrite.core.utils.utils import save
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+from tensorflow.lite.python import interpreter as interpreter_wrapper
+
+
+def make_decode_fn(filename):
+ def decode_fn(record_bytes, type_map):
+ parse_dict = {
+ name: tf.io.FixedLenFeature([], tf.string) for name in type_map.keys()
+ }
+ example = tf.io.parse_single_example(record_bytes, parse_dict)
+ features = {
+ n: tf.io.parse_tensor(example[n], tf.as_dtype(t))
+ for n, t in type_map.items()
+ }
+ return features
+
+ meta_filename = filename + ".meta"
+ with open(meta_filename) as f:
+ type_map = json.load(f)["type_map"]
+ return lambda record_bytes: decode_fn(record_bytes, type_map)
+
+
+def NumpyTFReader(filename):
+ decode_fn = make_decode_fn(filename)
+ dataset = tf.data.TFRecordDataset(filename)
+ return dataset.map(decode_fn)
+
+
+def numpytf_count(filename):
+ meta_filename = filename + ".meta"
+ with open(meta_filename) as f:
+ return json.load(f)["count"]
+
+
+class NumpyTFWriter:
+ def __init__(self, filename):
+ self.filename = filename
+ self.meta_filename = filename + ".meta"
+ self.writer = tf.io.TFRecordWriter(filename)
+ self.type_map = {}
+ self.count = 0
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type, value, traceback):
+ self.close()
+
+ def __del__(self):
+ self.close()
+
+ def write(self, array_dict):
+ type_map = {n: str(a.dtype.name) for n, a in array_dict.items()}
+ self.type_map.update(type_map)
+ self.count += 1
+
+ feature = {
+ n: tf.train.Feature(
+ bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(a).numpy()])
+ )
+ for n, a in array_dict.items()
+ }
+ example = tf.train.Example(features=tf.train.Features(feature=feature))
+ self.writer.write(example.SerializeToString())
+
+ def close(self):
+ with open(self.meta_filename, "w") as f:
+ meta = {"type_map": self.type_map, "count": self.count}
+ json.dump(meta, f)
+ self.writer.close()
+
+
+class TFLiteModel:
+ def __init__(self, filename, batch_size=None, num_threads=None):
+ if num_threads == 0:
+ num_threads = None
+ if batch_size == None:
+ self.interpreter = interpreter_wrapper.Interpreter(
+ model_path=filename, num_threads=num_threads
+ )
+ else: # if a batch size is specified, modify the TFLite model to use this size
+ with tempfile.TemporaryDirectory() as tmp:
+ fb = load(filename)
+ for sg in fb.subgraphs:
+ for t in list(sg.inputs) + list(sg.outputs):
+ sg.tensors[t].shape = np.array(
+ [batch_size] + list(sg.tensors[t].shape[1:]), dtype=np.int32
+ )
+ tempname = os.path.join(tmp, "rewrite_tmp.tflite")
+ save(fb, tempname)
+ self.interpreter = interpreter_wrapper.Interpreter(
+ model_path=tempname, num_threads=num_threads
+ )
+
+ try:
+ self.interpreter.allocate_tensors()
+ except RuntimeError:
+ self.interpreter = interpreter_wrapper.Interpreter(
+ model_path=filename, num_threads=num_threads
+ )
+ self.interpreter.allocate_tensors()
+
+ # Get input and output tensors.
+ self.input_details = self.interpreter.get_input_details()
+ self.output_details = self.interpreter.get_output_details()
+ details = list(self.input_details) + list(self.output_details)
+ self.handle_from_name = {d["name"]: d["index"] for d in details}
+ self.shape_from_name = {d["name"]: d["shape"] for d in details}
+ self.batch_size = next(iter(self.shape_from_name.values()))[0]
+
+ def __call__(self, named_input):
+ """Execute the model on one or a batch of named inputs (a dict of name: numpy array)"""
+ input_len = next(iter(named_input.values())).shape[0]
+ full_steps = input_len // self.batch_size
+ remainder = input_len % self.batch_size
+
+ named_ys = defaultdict(list)
+ for i in range(full_steps):
+ for name, x_batch in named_input.items():
+ x = x_batch[i : i + self.batch_size]
+ self.interpreter.set_tensor(self.handle_from_name[name], x)
+ self.interpreter.invoke()
+ for d in self.output_details:
+ named_ys[d["name"]].append(self.interpreter.get_tensor(d["index"]))
+ if remainder:
+ for name, x_batch in named_input.items():
+ x = np.zeros(self.shape_from_name[name]).astype(x_batch.dtype)
+ x[:remainder] = x_batch[-remainder:]
+ self.interpreter.set_tensor(self.handle_from_name[name], x)
+ self.interpreter.invoke()
+ for d in self.output_details:
+ named_ys[d["name"]].append(
+ self.interpreter.get_tensor(d["index"])[:remainder]
+ )
+ return {k: np.concatenate(v) for k, v in named_ys.items()}
+
+ def input_tensors(self):
+ return [d["name"] for d in self.input_details]
+
+ def output_tensors(self):
+ return [d["name"] for d in self.output_details]
+
+
+def sample_tfrec(input_file, k, output_file):
+ total = numpytf_count(input_file)
+ next = sorted(random.sample(range(total), k=k), reverse=True)
+
+ reader = NumpyTFReader(input_file)
+ with NumpyTFWriter(output_file) as writer:
+ for i, data in enumerate(reader):
+ if i == next[-1]:
+ next.pop()
+ writer.write(data)
+ if not next:
+ break