diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py | 172 |
1 files changed, 172 insertions, 0 deletions
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Helpers for storing named numpy arrays in TFRecord files and running TFLite models."""
import json
import os
import random
import tempfile
from collections import defaultdict

import numpy as np

from mlia.nn.rewrite.core.utils.utils import load
from mlia.nn.rewrite.core.utils.utils import save

# Reduce TensorFlow's native log output; set before the import below.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

from tensorflow.lite.python import interpreter as interpreter_wrapper


def make_decode_fn(filename):
    """Return a function that decodes one serialized record of `filename`.

    The tensor names and dtypes are read from the "type_map" entry of the
    JSON sidecar file `filename + ".meta"`. The returned callable takes the
    raw record bytes and returns a dict mapping each name to its parsed tensor.
    """
    meta_filename = filename + ".meta"
    with open(meta_filename, encoding="utf-8") as f:
        type_map = json.load(f)["type_map"]

    # Every feature is a scalar string that holds one serialized tensor.
    parse_dict = {name: tf.io.FixedLenFeature([], tf.string) for name in type_map}

    def decode_fn(record_bytes):
        example = tf.io.parse_single_example(record_bytes, parse_dict)
        return {
            name: tf.io.parse_tensor(example[name], tf.as_dtype(dtype_name))
            for name, dtype_name in type_map.items()
        }

    return decode_fn


def NumpyTFReader(filename):
    """Return a `tf.data` dataset yielding one dict of named tensors per record."""
    decode_fn = make_decode_fn(filename)
    dataset = tf.data.TFRecordDataset(filename)
    return dataset.map(decode_fn)


def numpytf_count(filename):
    """Return the record count stored in the `.meta` sidecar of `filename`."""
    meta_filename = filename + ".meta"
    with open(meta_filename, encoding="utf-8") as f:
        return json.load(f)["count"]


class NumpyTFWriter:
    """Write dicts of named arrays to a TFRecord file plus a JSON `.meta` sidecar.

    The sidecar stores the dtype name of every tensor name written
    ("type_map") and the number of records written ("count"). It is written
    when the writer is closed. Usable as a context manager.
    """

    def __init__(self, filename):
        # Marked closed until the record writer exists, so that __del__ on a
        # partially constructed instance does nothing.
        self._closed = True
        self.filename = filename
        self.meta_filename = filename + ".meta"
        self.type_map = {}
        self.count = 0
        self.writer = tf.io.TFRecordWriter(filename)
        self._closed = False

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()

    def __del__(self):
        self.close()

    def write(self, array_dict):
        """Append one record built from `array_dict` (name -> array/tensor)."""
        type_map = {n: str(a.dtype.name) for n, a in array_dict.items()}
        self.type_map.update(type_map)
        self.count += 1

        feature = {
            n: tf.train.Feature(
                bytes_list=tf.train.BytesList(
                    value=[tf.io.serialize_tensor(a).numpy()]
                )
            )
            for n, a in array_dict.items()
        }
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        self.writer.write(example.SerializeToString())

    def close(self):
        """Write the `.meta` sidecar and close the record file.

        Safe to call more than once: only the first call has an effect, so
        the `__exit__` + `__del__` sequence does not rewrite the sidecar or
        close the underlying writer twice.
        """
        if getattr(self, "_closed", True):
            return
        self._closed = True
        try:
            with open(self.meta_filename, "w", encoding="utf-8") as f:
                meta = {"type_map": self.type_map, "count": self.count}
                json.dump(meta, f)
        finally:
            # Close the record file even if the sidecar could not be written.
            self.writer.close()


class TFLiteModel:
    """Run a TFLite model on dicts of named numpy arrays, batch by batch."""

    def __init__(self, filename, batch_size=None, num_threads=None):
        """Load the model at `filename` and allocate its tensors.

        If `batch_size` is given, the first dimension of every subgraph
        input and output tensor is set to it in a temporary copy of the
        model, and that copy is loaded. If allocating tensors for the
        interpreter raises `RuntimeError`, the unmodified model is loaded
        instead. `num_threads == 0` is treated as `None`.
        """
        if num_threads == 0:
            num_threads = None
        if batch_size is None:
            self.interpreter = interpreter_wrapper.Interpreter(
                model_path=filename, num_threads=num_threads
            )
        else:  # if a batch size is specified, modify the TFLite model to use this size
            with tempfile.TemporaryDirectory() as tmp:
                fb = load(filename)
                for sg in fb.subgraphs:
                    for t in list(sg.inputs) + list(sg.outputs):
                        sg.tensors[t].shape = np.array(
                            [batch_size] + list(sg.tensors[t].shape[1:]),
                            dtype=np.int32,
                        )
                tempname = os.path.join(tmp, "rewrite_tmp.tflite")
                save(fb, tempname)
                self.interpreter = interpreter_wrapper.Interpreter(
                    model_path=tempname, num_threads=num_threads
                )

        try:
            self.interpreter.allocate_tensors()
        except RuntimeError:
            # Fall back to the model as stored on disk.
            self.interpreter = interpreter_wrapper.Interpreter(
                model_path=filename, num_threads=num_threads
            )
            self.interpreter.allocate_tensors()

        # Get input and output tensors.
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()
        details = list(self.input_details) + list(self.output_details)
        self.handle_from_name = {d["name"]: d["index"] for d in details}
        self.shape_from_name = {d["name"]: d["shape"] for d in details}
        # The effective batch size is the leading dimension of the first
        # listed tensor of the interpreter that was actually loaded.
        self.batch_size = next(iter(self.shape_from_name.values()))[0]

    def __call__(self, named_input):
        """Execute the model on one or a batch of named inputs (a dict of name: numpy array)"""
        input_len = next(iter(named_input.values())).shape[0]
        full_steps, remainder = divmod(input_len, self.batch_size)

        named_ys = defaultdict(list)
        for step in range(full_steps):
            # Consecutive, non-overlapping windows of batch_size items.
            start = step * self.batch_size
            stop = start + self.batch_size
            for name, x_batch in named_input.items():
                self.interpreter.set_tensor(
                    self.handle_from_name[name], x_batch[start:stop]
                )
            self.interpreter.invoke()
            for d in self.output_details:
                named_ys[d["name"]].append(self.interpreter.get_tensor(d["index"]))
        if remainder:
            # Zero-pad the trailing partial batch to the model's input shape
            # and keep only the outputs of the real items.
            for name, x_batch in named_input.items():
                x = np.zeros(self.shape_from_name[name], dtype=x_batch.dtype)
                x[:remainder] = x_batch[-remainder:]
                self.interpreter.set_tensor(self.handle_from_name[name], x)
            self.interpreter.invoke()
            for d in self.output_details:
                named_ys[d["name"]].append(
                    self.interpreter.get_tensor(d["index"])[:remainder]
                )
        return {k: np.concatenate(v) for k, v in named_ys.items()}

    def input_tensors(self):
        """Return the names of the model's input tensors."""
        return [d["name"] for d in self.input_details]

    def output_tensors(self):
        """Return the names of the model's output tensors."""
        return [d["name"] for d in self.output_details]


def sample_tfrec(input_file, k, output_file):
    """Copy `k` randomly chosen records of `input_file` to `output_file`.

    Records keep their original relative order. `random.sample` raises
    `ValueError` if `k` is negative or larger than the stored record count.
    With `k == 0` an empty output (and its `.meta` sidecar) is written.
    """
    total = numpytf_count(input_file)
    # Descending order, so the next index to pick is always the last element.
    pending = sorted(random.sample(range(total), k=k), reverse=True)

    reader = NumpyTFReader(input_file)
    with NumpyTFWriter(output_file) as writer:
        if pending:
            for i, data in enumerate(reader):
                if i == pending[-1]:
                    pending.pop()
                    writer.write(data)
                    if not pending:
                        break